relay-legacy/src/relay_subscription.rs

184 lines
5.2 KiB
Rust

use crate::RelayServer;
use std::sync::Arc;
use anyhow::Result;
use fastwebsockets::{
upgrade::{is_upgrade_request, upgrade},
FragmentCollectorRead, Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketWrite,
};
use hyper::{
body::{Bytes, Incoming},
upgrade::Upgraded,
Request, Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use qstring::QString;
use tokio::{
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
sync::broadcast::{error::RecvError, Receiver},
};
use uuid::Uuid;
use crate::http::{body_empty, body_full, ServerResponse};
/// One step of work for a subscription, produced by either the WebSocket
/// reader or the block broadcast receiver and consumed by
/// `RelaySubscription::dispatch_operation`.
enum Operation<'f> {
    /// Nothing to do; the dispatcher returns immediately.
    NoOp,
    /// Send these bytes to the client as a single binary frame.
    WriteBlock(Bytes),
    /// Send an already-built frame to the client.
    WriteFrame(Frame<'f>),
    /// Try to send the frame (a failed write is ignored), then stop the subscription.
    ExitWithFrame(Frame<'f>),
    /// Stop the subscription without sending anything.
    Exit,
}
/// Per-connection state for one WebSocket subscriber.
struct RelaySubscription {
    /// Random (v4) identifier attached to this subscription's log lines.
    id: Uuid,
    /// `true` while the subscription loop should keep going; cleared when an
    /// exit operation is dispatched or a frame write fails.
    running: bool,
}
/// Receiving side of a split subscription socket: the owned TCP read half
/// wrapped in a `FragmentCollectorRead`.
type WSRead = FragmentCollectorRead<OwnedReadHalf>;
/// Sending side of a split subscription socket, backed by the owned TCP write half.
type WSWrite = WebSocketWrite<OwnedWriteHalf>;
impl RelaySubscription {
    /// Builds the subscription state for an upgraded WebSocket and splits the
    /// socket into independent read and write halves.
    ///
    /// Automatic close and pong replies are switched off on the socket before
    /// the split; the caller answers those frames itself.
    ///
    /// # Panics
    ///
    /// Panics if the upgraded connection is not a `TokioIo<TcpStream>`.
    fn create(mut ws: WebSocket<TokioIo<Upgraded>>) -> (Self, WSRead, WSWrite) {
        ws.set_auto_close(false);
        ws.set_auto_pong(false);

        let (raw_rx, ws_tx) = ws.split(|stream| {
            // Recover the concrete TCP stream so it can be split into owned halves.
            let parts = stream
                .into_inner()
                .downcast::<TokioIo<TcpStream>>()
                .expect("HTTP stream should be a TokioIo<TcpStream> !");
            parts.io.into_inner().into_split()
        });

        let subscription = Self {
            id: Uuid::new_v4(),
            running: true,
        };
        (subscription, FragmentCollectorRead::new(raw_rx), ws_tx)
    }

    /// Carries out a single [`Operation`] against the client's write half.
    ///
    /// * `NoOp` does nothing.
    /// * `Exit` and `ExitWithFrame` clear `running`; for the latter the frame
    ///   is written first and a failed write is deliberately ignored.
    /// * `WriteFrame` and `WriteBlock` write to the socket; if that write
    ///   fails the error is logged and `running` is cleared.
    async fn dispatch_operation(&mut self, ws_tx: &mut WSWrite, op: Operation<'_>) {
        let write_result = match op {
            Operation::NoOp => return,
            Operation::Exit => {
                self.running = false;
                return;
            }
            Operation::ExitWithFrame(frame) => {
                // Best effort: the subscription is ending regardless of the outcome.
                let _ = ws_tx.write_frame(frame).await;
                self.running = false;
                return;
            }
            Operation::WriteFrame(frame) => ws_tx.write_frame(frame).await,
            Operation::WriteBlock(bytes) => {
                let frame = Frame::binary(Payload::Borrowed(&bytes));
                ws_tx.write_frame(frame).await
            }
        };

        if let Err(e) = write_result {
            tracing::warn!("Encountered error: {:?}", e);
            self.running = false;
        }
    }
}
/// Reads the next frame from the client and translates it into an [`Operation`].
///
/// * Ping  -> reply with a pong carrying the same payload.
/// * Close -> echo the close payload back and stop the subscription.
/// * Any other frame is discarded (`NoOp`).
/// * A read error stops the subscription without sending anything.
async fn read_frame<'f>(ws_rx: &mut WSRead) -> Operation<'f> {
    let incoming = ws_rx
        .read_frame::<_, WebSocketError>(&mut move |_| async {
            // NOTE(review): presumably this callback is only invoked for the
            // library's automatic pong/close replies, which
            // `RelaySubscription::create` disables -- confirm against the
            // fastwebsockets documentation.
            unreachable!()
        })
        .await;

    let frame = match incoming {
        Ok(frame) => frame,
        Err(_) => return Operation::Exit,
    };

    match frame.opcode {
        OpCode::Ping => Operation::WriteFrame(Frame::pong(frame.payload)),
        OpCode::Close => Operation::ExitWithFrame(Frame::close_raw(frame.payload)),
        // Everything else the client sends is ignored.
        _ => Operation::NoOp,
    }
}
/// Waits for the next broadcast block and turns it into an [`Operation`].
///
/// A received block becomes `WriteBlock`. If the channel is closed the client
/// is sent close code 1001; if this receiver lagged behind the channel the
/// client is sent close code 1008. Both error cases end the subscription.
async fn rebroadcast_block<'f>(block_rx: &mut Receiver<Bytes>) -> Operation<'f> {
    let failure = match block_rx.recv().await {
        Ok(block) => return Operation::WriteBlock(block),
        Err(err) => err,
    };

    let close_frame = match failure {
        RecvError::Closed => Frame::close(1001, b"Going away"),
        RecvError::Lagged(_) => Frame::close(1008, b"Client too slow"),
    };
    Operation::ExitWithFrame(close_frame)
}
/// Drives one subscriber connection until it is told to stop.
///
/// Reads an optional numeric `cursor` query parameter from the request,
/// builds the subscription state from the upgraded socket, then loops:
/// each iteration waits for either a broadcast block or a client frame and
/// dispatches the resulting operation to the client's write half.
async fn run_subscription(
    server: Arc<RelayServer>,
    req: Request<Incoming>,
    ws: WebSocket<TokioIo<Upgraded>>,
) {
    // A missing or non-numeric `cursor` is treated as "no cursor".
    let cursor: Option<usize> = req
        .uri()
        .query()
        .map(QString::from)
        .and_then(|q| q.get("cursor").and_then(|s| s.parse::<usize>().ok()));

    let (mut sub, mut ws_rx, mut ws_tx) = RelaySubscription::create(ws);
    tracing::debug!(id = %sub.id, "subscription started");

    if cursor.is_some() {
        tracing::debug!(id = %sub.id, "filling from event cache");
        // TODO: cursor catchup (read from server db history)
        tracing::debug!(id = %sub.id, "subscription live-tailing");
    }

    // Live tailing: forward every block broadcast from here on.
    let mut raw_block_rx = server.raw_block_tx.subscribe();
    while sub.running {
        // `biased` polls the branches in written order, so a ready block is
        // always taken before a ready client frame.
        // NOTE(review): the losing `read_frame` future is dropped on every
        // iteration where a block wins -- confirm that fastwebsockets'
        // `read_frame` is cancellation-safe.
        let next_op = tokio::select! {
            biased;
            op = rebroadcast_block(&mut raw_block_rx) => op,
            op = read_frame(&mut ws_rx) => op,
        };
        sub.dispatch_operation(&mut ws_tx, next_op).await;
    }

    tracing::debug!(id = %sub.id, "subscription ended");
}
/// HTTP entry point for the subscription endpoint.
///
/// If the request is not a WebSocket upgrade request, responds with
/// `426 Upgrade Required`. Otherwise performs the upgrade handshake, spawns a
/// background task that runs the subscription once the upgrade completes, and
/// returns the handshake response with an empty body.
///
/// # Errors
///
/// Returns an error if the response cannot be built or if the upgrade
/// handshake for the request fails.
pub async fn handle_subscription(
    server: Arc<RelayServer>,
    mut req: Request<Incoming>,
) -> Result<ServerResponse> {
    if !is_upgrade_request(&req) {
        // RFC 9110 section 15.5.22: a 426 response MUST include an `Upgrade`
        // header naming the required protocol; `Upgrade` is hop-by-hop, so it
        // is also listed in `Connection`.
        return Ok(Response::builder()
            .status(StatusCode::UPGRADE_REQUIRED)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .body(body_full("Upgrade Required"))?);
    }

    let (res, ws_fut) = upgrade(&mut req)?;

    // The upgrade only completes after the handshake response has been
    // returned, so the subscription itself runs on a detached task.
    tokio::task::spawn(async move {
        let ws = match ws_fut.await {
            Ok(ws) => ws,
            Err(e) => {
                tracing::warn!("error upgrading WebSocket: {e:?}");
                return;
            }
        };
        run_subscription(server, req, ws).await;
    });

    // Reuse the handshake's status and headers, swapping in this server's body type.
    let (head, _) = res.into_parts();
    Ok(Response::from_parts(head, body_empty()))
}