Initial commit

This commit is contained in:
Charlotte Som 2024-11-25 02:25:03 +02:00
commit d883c9b10b
10 changed files with 1379 additions and 0 deletions

9
.editorconfig Normal file
View file

@ -0,0 +1,9 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = false
insert_final_newline = true

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

1055
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

18
Cargo.toml Normal file
View file

@ -0,0 +1,18 @@
[package]
name = "cerulea_relay"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.93"
fastwebsockets = { version = "0.8.0", features = ["hyper", "unstable-split", "upgrade"] }
http-body-util = "0.1.2"
hyper = { version = "1.5.1", features = ["client", "full", "http1", "http2", "server"] }
hyper-util = { version = "0.1.10", features = ["tokio", "server", "client", "http1", "http2"] }
pin-project-lite = "0.2.15"
qstring = "0.7.2"
tap = "1.0.1"
tokio = { version = "1.41.1", features = ["full"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
uuid = { version = "1.11.0", features = ["v4"] }

8
README.md Normal file
View file

@ -0,0 +1,8 @@
# cerulea-relay
Realtime relay (1hr backfill window) for PDSes with fewer than 1000 repos.
The idea is that we can have much larger limits if we scale down the volume of the network.
- Large block sizes
- Large record size limit
- etcetcetc

4
src/lib.rs Normal file
View file

@ -0,0 +1,4 @@
pub mod prelude;
pub mod relay_subscription;
pub mod server;

21
src/main.rs Normal file
View file

@ -0,0 +1,21 @@
use anyhow::Result;
use std::{net::SocketAddr, sync::Arc};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use cerulea_relay::server::{self, RelayServer};
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(fmt::layer())
.with(EnvFilter::from_default_env())
.init();
let server = Arc::new(RelayServer::default());
// TODO: scrape some dudes
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
server::listen(Arc::clone(&server), addr).await?;
Ok(())
}

1
src/prelude.rs Normal file
View file

@ -0,0 +1 @@
pub use tap::prelude::*;

183
src/relay_subscription.rs Normal file
View file

@ -0,0 +1,183 @@
use crate::prelude::*;
use std::sync::Arc;
use anyhow::Result;
use fastwebsockets::{
upgrade::{is_upgrade_request, upgrade},
FragmentCollectorRead, Frame, OpCode, Payload, WebSocket, WebSocketError, WebSocketWrite,
};
use hyper::{
body::{Bytes, Incoming},
upgrade::Upgraded,
Request, Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use qstring::QString;
use tokio::{
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
sync::broadcast::{error::RecvError, Receiver},
};
use uuid::Uuid;
use crate::server::{empty, full, RelayServer, ServerResponse};
enum Operation<'f> {
NoOp,
WriteBlock(Bytes),
WriteFrame(Frame<'f>),
ExitWithFrame(Frame<'f>),
Exit,
}
struct RelaySubscription {
id: Uuid,
running: bool,
}
type WSRead = FragmentCollectorRead<OwnedReadHalf>;
type WSWrite = WebSocketWrite<OwnedWriteHalf>;
impl RelaySubscription {
fn create(mut ws: WebSocket<TokioIo<Upgraded>>) -> (Self, WSRead, WSWrite) {
ws.set_auto_close(false);
ws.set_auto_pong(false);
let (ws_rx, ws_tx) = ws.split(|stream| {
let upgraded = stream.into_inner();
let parts = upgraded.downcast::<TokioIo<TcpStream>>().expect("uhhhh");
let (read, write) = parts.io.into_inner().into_split();
(read, write)
});
let ws_rx = FragmentCollectorRead::new(ws_rx);
let sub = RelaySubscription {
id: Uuid::new_v4(),
running: true,
};
(sub, ws_rx, ws_tx)
}
async fn dispatch_operation(&mut self, ws_tx: &mut WSWrite, op: Operation<'_>) {
if let Err(e) = match op {
Operation::NoOp => return,
Operation::WriteBlock(bytes) => {
ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&bytes)))
.await
}
Operation::WriteFrame(frame) => ws_tx.write_frame(frame).await,
Operation::ExitWithFrame(frame) => {
let _ = ws_tx.write_frame(frame).await;
self.running = false;
return;
}
Operation::Exit => {
self.running = false;
return;
}
} {
tracing::warn!("Encountered error: {:?}", e);
self.running = false;
}
}
}
async fn read_frame<'f>(ws_rx: &mut WSRead) -> Operation<'f> {
match ws_rx
.read_frame::<_, WebSocketError>(&mut move |_| async {
unreachable!() // it'll be fiiiine :3
})
.await
{
Ok(frame) if frame.opcode == OpCode::Ping => {
Operation::WriteFrame(Frame::pong(frame.payload))
}
Ok(frame) if frame.opcode == OpCode::Close => {
Operation::ExitWithFrame(Frame::close_raw(frame.payload))
}
Ok(_frame) => {
Operation::NoOp // discard
}
Err(_e) => Operation::Exit,
}
}
async fn rebroadcast_block<'f>(block_rx: &mut Receiver<Bytes>) -> Operation<'f> {
match block_rx.recv().await {
Ok(block) => Operation::WriteBlock(block),
Err(RecvError::Closed) => Operation::ExitWithFrame(Frame::close(1001, b"Going away")),
Err(RecvError::Lagged(_)) => {
Operation::ExitWithFrame(Frame::close(1008, b"Client too slow"))
}
}
}
async fn run_subscription(
server: Arc<RelayServer>,
req: Request<Incoming>,
ws: WebSocket<TokioIo<Upgraded>>,
) {
let query = req.uri().query().map(QString::from);
let cursor: Option<usize> = query
.as_ref()
.and_then(|q| q.get("cursor"))
.and_then(|s| s.parse().ok());
let (mut sub, mut ws_rx, mut ws_tx) = RelaySubscription::create(ws);
tracing::debug!(id = %sub.id, "subscription started");
if let Some(_cursor) = cursor {
tracing::debug!(id = %sub.id, "filling from event cache");
// TODO: cursor catchup
tracing::debug!(id = %sub.id, "subscription live-tailing");
}
// live tailing:
let mut block_rx = server.block_tx.subscribe();
while sub.running {
let op = tokio::select! {
biased;
op = rebroadcast_block(&mut block_rx) => op,
op = read_frame(&mut ws_rx) => op,
};
sub.dispatch_operation(&mut ws_tx, op).await;
}
tracing::debug!(id = %sub.id, "subscription ended");
}
pub async fn handle_subscription(
server: Arc<RelayServer>,
mut req: Request<Incoming>,
) -> Result<ServerResponse> {
if !is_upgrade_request(&req) {
return Response::builder()
.status(StatusCode::UPGRADE_REQUIRED)
.body(full("Upgrade Required"))?
.pipe(Ok);
}
let (res, ws_fut) = upgrade(&mut req)?;
tokio::task::spawn(async move {
let ws = match ws_fut.await {
Ok(ws) => ws,
Err(e) => {
tracing::warn!("error upgrading WebSocket: {e:?}");
return;
}
};
run_subscription(server, req, ws).await;
});
let (head, _) = res.into_parts();
Ok(Response::from_parts(head, empty()))
}

79
src/server.rs Normal file
View file

@ -0,0 +1,79 @@
use crate::{prelude::*, relay_subscription::handle_subscription};
use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::{
body::{Bytes, Incoming},
service::service_fn,
Method, Request, Response, StatusCode,
};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
pub struct RelayServer {
pub block_tx: tokio::sync::broadcast::Sender<Bytes>,
}
impl Default for RelayServer {
fn default() -> Self {
let (block_tx, _) = tokio::sync::broadcast::channel::<Bytes>(128);
Self { block_tx }
}
}
pub type ServerResponseBody = BoxBody<Bytes, hyper::Error>;
pub fn empty() -> ServerResponseBody {
Empty::<Bytes>::new().map_err(|e| match e {}).boxed()
}
pub fn full<T: Into<Bytes>>(chunk: T) -> ServerResponseBody {
Full::new(chunk.into()).map_err(|e| match e {}).boxed()
}
pub type ServerResponse = Response<BoxBody<Bytes, hyper::Error>>;
async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<ServerResponse> {
let path = req.uri().path();
tracing::debug!("{}", path);
match (req.method(), path) {
(&Method::GET, "/") => Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/plain")
.body(full("cerulea relay running..."))?
.pipe(Ok),
(&Method::GET, "/xrpc/com.atproto.sync.subscribeRepos") => {
handle_subscription(server, req).await
}
_ => Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(full("Not Found"))?
.pipe(Ok),
}
}
pub async fn listen(server: Arc<RelayServer>, addr: SocketAddr) -> Result<()> {
tracing::info!("Listening on: http://{addr}/ ...");
let listener = TcpListener::bind(addr).await?;
loop {
let (stream, _client_addr) = listener.accept().await?;
let io = TokioIo::new(stream);
let server = Arc::clone(&server);
tokio::task::spawn(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service_fn(move |req| serve(Arc::clone(&server), req)))
.with_upgrades()
.await
{
eprintln!("Error handling connection: {err:?}")
}
});
}
}