184 lines
5.2 KiB
Rust
184 lines
5.2 KiB
Rust
use crate::RelayServer;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::Result;
|
|
use fastwebsockets::{
|
|
upgrade::{is_upgrade_request, upgrade},
|
|
FragmentCollectorRead, Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketWrite,
|
|
};
|
|
use hyper::{
|
|
body::{Bytes, Incoming},
|
|
upgrade::Upgraded,
|
|
Request, Response, StatusCode,
|
|
};
|
|
use hyper_util::rt::TokioIo;
|
|
use qstring::QString;
|
|
use tokio::{
|
|
net::{
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
TcpStream,
|
|
},
|
|
sync::broadcast::{error::RecvError, Receiver},
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use crate::http::{body_empty, body_full, ServerResponse};
|
|
|
|
enum Operation<'f> {
|
|
NoOp,
|
|
WriteBlock(Bytes),
|
|
WriteFrame(Frame<'f>),
|
|
ExitWithFrame(Frame<'f>),
|
|
Exit,
|
|
}
|
|
|
|
struct RelaySubscription {
|
|
id: Uuid,
|
|
running: bool,
|
|
}
|
|
|
|
type WSRead = FragmentCollectorRead<OwnedReadHalf>;
|
|
type WSWrite = WebSocketWrite<OwnedWriteHalf>;
|
|
|
|
impl RelaySubscription {
|
|
fn create(mut ws: WebSocket<TokioIo<Upgraded>>) -> (Self, WSRead, WSWrite) {
|
|
ws.set_auto_close(false);
|
|
ws.set_auto_pong(false);
|
|
|
|
let (ws_rx, ws_tx) = ws.split(|stream| {
|
|
let upgraded = stream.into_inner();
|
|
let parts = upgraded
|
|
.downcast::<TokioIo<TcpStream>>()
|
|
.expect("HTTP stream should be a TokioIo<TcpStream> !");
|
|
let (read, write) = parts.io.into_inner().into_split();
|
|
(read, write)
|
|
});
|
|
let ws_rx = FragmentCollectorRead::new(ws_rx);
|
|
|
|
let sub = RelaySubscription {
|
|
id: Uuid::new_v4(),
|
|
running: true,
|
|
};
|
|
|
|
(sub, ws_rx, ws_tx)
|
|
}
|
|
|
|
async fn dispatch_operation(&mut self, ws_tx: &mut WSWrite, op: Operation<'_>) {
|
|
if let Err(e) = match op {
|
|
Operation::NoOp => return,
|
|
Operation::WriteBlock(bytes) => {
|
|
ws_tx
|
|
.write_frame(Frame::binary(Payload::Borrowed(&bytes)))
|
|
.await
|
|
}
|
|
Operation::WriteFrame(frame) => ws_tx.write_frame(frame).await,
|
|
Operation::ExitWithFrame(frame) => {
|
|
let _ = ws_tx.write_frame(frame).await;
|
|
self.running = false;
|
|
return;
|
|
}
|
|
Operation::Exit => {
|
|
self.running = false;
|
|
return;
|
|
}
|
|
} {
|
|
tracing::warn!("Encountered error: {:?}", e);
|
|
self.running = false;
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn read_frame<'f>(ws_rx: &mut WSRead) -> Operation<'f> {
|
|
match ws_rx
|
|
.read_frame::<_, WebSocketError>(&mut move |_| async {
|
|
unreachable!() // it'll be fiiiine :3
|
|
})
|
|
.await
|
|
{
|
|
Ok(frame) if frame.opcode == OpCode::Ping => {
|
|
Operation::WriteFrame(Frame::pong(frame.payload))
|
|
}
|
|
Ok(frame) if frame.opcode == OpCode::Close => {
|
|
Operation::ExitWithFrame(Frame::close_raw(frame.payload))
|
|
}
|
|
Ok(_frame) => {
|
|
Operation::NoOp // discard
|
|
}
|
|
Err(_e) => Operation::Exit,
|
|
}
|
|
}
|
|
|
|
async fn rebroadcast_block<'f>(block_rx: &mut Receiver<Bytes>) -> Operation<'f> {
|
|
match block_rx.recv().await {
|
|
Ok(block) => Operation::WriteBlock(block),
|
|
Err(RecvError::Closed) => Operation::ExitWithFrame(Frame::close(1001, b"Going away")),
|
|
Err(RecvError::Lagged(_)) => {
|
|
Operation::ExitWithFrame(Frame::close(1008, b"Client too slow"))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn run_subscription(
|
|
server: Arc<RelayServer>,
|
|
req: Request<Incoming>,
|
|
ws: WebSocket<TokioIo<Upgraded>>,
|
|
) {
|
|
let query = req.uri().query().map(QString::from);
|
|
let cursor: Option<usize> = query
|
|
.as_ref()
|
|
.and_then(|q| q.get("cursor"))
|
|
.and_then(|s| s.parse().ok());
|
|
|
|
let (mut sub, mut ws_rx, mut ws_tx) = RelaySubscription::create(ws);
|
|
|
|
tracing::debug!(id = %sub.id, "subscription started");
|
|
|
|
if let Some(_cursor) = cursor {
|
|
tracing::debug!(id = %sub.id, "filling from event cache");
|
|
|
|
// TODO: cursor catchup (read from server db history)
|
|
|
|
tracing::debug!(id = %sub.id, "subscription live-tailing");
|
|
}
|
|
|
|
// live tailing:
|
|
let mut raw_block_rx = server.raw_block_tx.subscribe();
|
|
while sub.running {
|
|
let op = tokio::select! {
|
|
biased;
|
|
op = rebroadcast_block(&mut raw_block_rx) => op,
|
|
op = read_frame(&mut ws_rx) => op,
|
|
};
|
|
sub.dispatch_operation(&mut ws_tx, op).await;
|
|
}
|
|
|
|
tracing::debug!(id = %sub.id, "subscription ended");
|
|
}
|
|
|
|
pub async fn handle_subscription(
|
|
server: Arc<RelayServer>,
|
|
mut req: Request<Incoming>,
|
|
) -> Result<ServerResponse> {
|
|
if !is_upgrade_request(&req) {
|
|
return Ok(Response::builder()
|
|
.status(StatusCode::UPGRADE_REQUIRED)
|
|
.body(body_full("Upgrade Required"))?);
|
|
}
|
|
|
|
let (res, ws_fut) = upgrade(&mut req)?;
|
|
tokio::task::spawn(async move {
|
|
let ws = match ws_fut.await {
|
|
Ok(ws) => ws,
|
|
Err(e) => {
|
|
tracing::warn!("error upgrading WebSocket: {e:?}");
|
|
return;
|
|
}
|
|
};
|
|
|
|
run_subscription(server, req, ws).await;
|
|
});
|
|
|
|
let (head, _) = res.into_parts();
|
|
Ok(Response::from_parts(head, body_empty()))
|
|
}
|