use crate::RelayServer; use std::sync::Arc; use anyhow::Result; use fastwebsockets::{ upgrade::{is_upgrade_request, upgrade}, FragmentCollectorRead, Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketWrite, }; use hyper::{ body::{Bytes, Incoming}, upgrade::Upgraded, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; use qstring::QString; use tokio::{ net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }, sync::broadcast::{error::RecvError, Receiver}, }; use uuid::Uuid; use crate::http::{body_empty, body_full, ServerResponse}; enum Operation<'f> { NoOp, WriteBlock(Bytes), WriteFrame(Frame<'f>), ExitWithFrame(Frame<'f>), Exit, } struct RelaySubscription { id: Uuid, running: bool, } type WSRead = FragmentCollectorRead; type WSWrite = WebSocketWrite; impl RelaySubscription { fn create(mut ws: WebSocket>) -> (Self, WSRead, WSWrite) { ws.set_auto_close(false); ws.set_auto_pong(false); let (ws_rx, ws_tx) = ws.split(|stream| { let upgraded = stream.into_inner(); let parts = upgraded .downcast::>() .expect("HTTP stream should be a TokioIo !"); let (read, write) = parts.io.into_inner().into_split(); (read, write) }); let ws_rx = FragmentCollectorRead::new(ws_rx); let sub = RelaySubscription { id: Uuid::new_v4(), running: true, }; (sub, ws_rx, ws_tx) } async fn dispatch_operation(&mut self, ws_tx: &mut WSWrite, op: Operation<'_>) { if let Err(e) = match op { Operation::NoOp => return, Operation::WriteBlock(bytes) => { ws_tx .write_frame(Frame::binary(Payload::Borrowed(&bytes))) .await } Operation::WriteFrame(frame) => ws_tx.write_frame(frame).await, Operation::ExitWithFrame(frame) => { let _ = ws_tx.write_frame(frame).await; self.running = false; return; } Operation::Exit => { self.running = false; return; } } { tracing::warn!("Encountered error: {:?}", e); self.running = false; } } } async fn read_frame<'f>(ws_rx: &mut WSRead) -> Operation<'f> { match ws_rx .read_frame::<_, WebSocketError>(&mut move |_| async { unreachable!() // it'll be fiiiine :3 }) .await { Ok(frame) if frame.opcode == OpCode::Ping => { Operation::WriteFrame(Frame::pong(frame.payload)) } Ok(frame) if frame.opcode == OpCode::Close => { Operation::ExitWithFrame(Frame::close_raw(frame.payload)) } Ok(_frame) => { Operation::NoOp // discard } Err(_e) => Operation::Exit, } } async fn rebroadcast_block<'f>(block_rx: &mut Receiver) -> Operation<'f> { match block_rx.recv().await { Ok(block) => Operation::WriteBlock(block), Err(RecvError::Closed) => Operation::ExitWithFrame(Frame::close(1001, b"Going away")), Err(RecvError::Lagged(_)) => { Operation::ExitWithFrame(Frame::close(1008, b"Client too slow")) } } } async fn run_subscription( server: Arc, req: Request, ws: WebSocket>, ) { let query = req.uri().query().map(QString::from); let cursor: Option = query .as_ref() .and_then(|q| q.get("cursor")) .and_then(|s| s.parse().ok()); let (mut sub, mut ws_rx, mut ws_tx) = RelaySubscription::create(ws); tracing::debug!(id = %sub.id, "subscription started"); if let Some(_cursor) = cursor { tracing::debug!(id = %sub.id, "filling from event cache"); // TODO: cursor catchup (read from server db history) tracing::debug!(id = %sub.id, "subscription live-tailing"); } // live tailing: let mut raw_block_rx = server.raw_block_tx.subscribe(); while sub.running { let op = tokio::select! { biased; op = rebroadcast_block(&mut raw_block_rx) => op, op = read_frame(&mut ws_rx) => op, }; sub.dispatch_operation(&mut ws_tx, op).await; } tracing::debug!(id = %sub.id, "subscription ended"); } pub async fn handle_subscription( server: Arc, mut req: Request, ) -> Result { if !is_upgrade_request(&req) { return Ok(Response::builder() .status(StatusCode::UPGRADE_REQUIRED) .body(body_full("Upgrade Required"))?); } let (res, ws_fut) = upgrade(&mut req)?; tokio::task::spawn(async move { let ws = match ws_fut.await { Ok(ws) => ws, Err(e) => { tracing::warn!("error upgrading WebSocket: {e:?}"); return; } }; run_subscription(server, req, ws).await; }); let (head, _) = res.into_parts(); Ok(Response::from_parts(head, body_empty())) }