From f0e5dc7428b3d8621e6e00aadcb92708a9ff5e31 Mon Sep 17 00:00:00 2001 From: Charlotte Som Date: Mon, 25 Nov 2024 19:45:43 +0200 Subject: [PATCH] properly validate pds repo authority for commits, use atrium types --- Cargo.lock | 254 ++++++++++++++++++++++++++++++++++++-- Cargo.toml | 4 +- src/http.rs | 18 ++- src/indexer.rs | 148 ++++++++++++---------- src/lib.rs | 21 ++-- src/main.rs | 11 +- src/prelude.rs | 1 - src/relay_subscription.rs | 15 ++- src/sequencer.rs | 64 +++++----- src/tls.rs | 21 ++++ src/user.rs | 97 +++++++++++++++ src/wire_proto.rs | 10 +- 12 files changed, 535 insertions(+), 129 deletions(-) delete mode 100644 src/prelude.rs create mode 100644 src/tls.rs create mode 100644 src/user.rs diff --git a/Cargo.lock b/Cargo.lock index 8276aef..c93f276 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.93" @@ -38,6 +53,40 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atrium-api" +version = "0.24.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c4e5d077f7941ec5484964fba8697b7571a964b0a714e02ae7bc7332833c36b" +dependencies = [ + "atrium-xrpc", + "chrono", + "http", + "ipld-core", + "langtag", + "regex", + "serde", + "serde_bytes", + "serde_json", + "thiserror", + "tokio", + "trait-variant", +] + +[[package]] +name = "atrium-xrpc" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f223b98be2acdd7afe5b867744aee8258413ed09993099de0a036b247db0ec4c" +dependencies = [ + "http", + "serde", + "serde_html_form", + "serde_json", + "thiserror", + "trait-variant", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -141,6 +190,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "byteorder" version = "1.5.0" @@ -152,6 +207,9 @@ name = "bytes" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +dependencies = [ + "serde", +] [[package]] name = "cbor4ii" @@ -178,6 +236,8 @@ name = "cerulea_relay" version = "0.1.0" dependencies = [ "anyhow", + "atrium-api", + "bytes", "fastwebsockets", "http-body-util", "hyper", @@ -188,8 +248,8 @@ dependencies = [ "rustls", "serde", "serde_ipld_dagcbor", + "serde_json", "sled", - "tap", "tokio", "tokio-rustls", "tracing", @@ -213,6 +273,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-targets", +] + [[package]] name = "cid" version = "0.11.1" @@ -247,6 +322,12 @@ dependencies = [ "cc", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "core2" version = "0.4.0" @@ -389,6 +470,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fs2" version = "0.4.3" @@ -611,6 +701,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "indexmap" version = "2.6.0" @@ -665,6 +778,24 @@ dependencies = [ "libc", ] +[[package]] +name = "js-sys" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "langtag" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed60c85f254d6ae8450cec15eedd921efbc4d1bdf6fcf6202b9a58b403f6f805" +dependencies = [ + "serde", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -799,6 +930,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.36.5" @@ -1123,6 +1263,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + [[package]] name = "scopeguard" version = "1.2.0" @@ -1158,6 +1304,19 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "serde_html_form" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de514ef58196f1fc96dcaef80fe6170a1ce6215df9687a93fe8300e773fefc5" +dependencies = [ + "form_urlencoded", + "indexmap", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_ipld_dagcbor" version = "0.6.1" @@ -1170,6 +1329,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_json" +version = "1.0.133" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "sha1" version = "0.10.6" @@ -1287,12 +1458,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "thiserror" version = "1.0.69" @@ -1443,6 +1608,17 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -1515,6 +1691,61 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.89", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" + [[package]] name = "webpki-roots" version = "0.26.7" @@ -1558,6 +1789,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index daeab84..fd2a3e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,8 @@ edition = "2021" [dependencies] anyhow = "1.0.93" +atrium-api = { version = "0.24.8", default-features = false, features = ["tokio"] } +bytes = { version = "1.8.0", features = ["serde"] } fastwebsockets = { version = "0.8.0", features = ["hyper", "unstable-split", "upgrade"] } http-body-util = "0.1.2" hyper = { version = "1.5.1", features = ["client", "full", "http1", "http2", "server"] } @@ -15,8 +17,8 @@ qstring = "0.7.2" rustls = "0.23.18" serde = { version = "1.0.215", features = ["derive"] } serde_ipld_dagcbor = "0.6.1" +serde_json = "1.0.133" sled = { version = "0.34.7", features = ["compression"] } -tap = "1.0.1" tokio = { version = "1.41.1", features = ["full"] } tokio-rustls = "0.26.0" tracing = "0.1.40" diff --git a/src/http.rs b/src/http.rs index 84ffa83..a78e4a5 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,4 @@ -use crate::{prelude::*, relay_subscription::handle_subscription, RelayServer}; +use crate::{relay_subscription::handle_subscription, RelayServer}; use std::{net::SocketAddr, sync::Arc}; @@ -12,11 +12,11 @@ use hyper::{ use hyper_util::rt::TokioIo; use tokio::net::TcpListener; -pub type ServerResponseBody = BoxBody; -pub fn empty() -> ServerResponseBody { +pub type HttpBody = BoxBody; +pub fn body_empty() -> HttpBody { Empty::::new().map_err(|e| match e {}).boxed() } -pub fn full>(chunk: T) -> ServerResponseBody { +pub fn body_full>(chunk: T) -> HttpBody { Full::new(chunk.into()).map_err(|e| match e {}).boxed() } @@ -28,21 +28,19 @@ async fn serve(server: Arc, req: Request) -> Result Response::builder() + (&Method::GET, "/") => Ok(Response::builder() .status(StatusCode::OK) .header("Content-Type", "text/plain") - .body(full("cerulea relay running..."))? - .pipe(Ok), + .body(body_full("cerulea relay running..."))?), (&Method::GET, "/xrpc/com.atproto.sync.subscribeRepos") => { handle_subscription(server, req).await } - _ => Response::builder() + _ => Ok(Response::builder() .status(StatusCode::NOT_FOUND) .header("Content-Type", "text/plain") - .body(full("Not Found"))? - .pipe(Ok), + .body(body_full("Not Found"))?), } } diff --git a/src/indexer.rs b/src/indexer.rs index 1ce5bbc..dd5ac41 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -1,8 +1,10 @@ use std::{io::Cursor, sync::Arc}; -use anyhow::{anyhow, Result}; -use fastwebsockets::{FragmentCollector, OpCode, Payload}; -use hyper::{body::Bytes, header, upgrade::Upgraded, Request}; +use anyhow::{anyhow, bail, Context, Result}; +use atrium_api::com::atproto::sync::subscribe_repos; +use bytes::Bytes; +use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; +use hyper::{header, upgrade::Upgraded, Request, Uri}; use hyper_util::rt::{TokioExecutor, TokioIo}; use ipld_core::ipld::Ipld; use serde_ipld_dagcbor::DecodeError; @@ -16,8 +18,10 @@ use tokio::{ use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector}; use crate::{ - http::empty, - wire_proto::{StreamingEvent, SubscriptionHeader}, + http::body_empty, + tls::open_tls_stream, + user::lookup_user, + wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, RelayServer, }; @@ -26,22 +30,15 @@ async fn create_ws_client( port: u16, path: &str, ) -> Result>> { - let addr = format!("{}:{}", domain, port); - let tcp_stream = TcpStream::connect(&addr).await?; + let tcp_stream = TcpStream::connect((domain, port)).await?; - let root_store = - rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let domain_tls = ServerName::try_from(domain.to_string())?; - let client_config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - let tls_connector = TlsConnector::from(Arc::new(client_config)); - let tls_stream = tls_connector.connect(domain_tls, tcp_stream).await?; + let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?; let req = Request::builder() .method("GET") .uri(format!("wss://{}:{}{}", &domain, port, path)) - .header("Host", &addr) + .header("Host", domain.to_string()) .header(header::UPGRADE, "websocket") .header(header::CONNECTION, "upgrade") .header( @@ -49,7 +46,7 @@ async fn create_ws_client( fastwebsockets::handshake::generate_key(), ) .header("Sec-WebSocket-Version", "13") - .body(empty())?; + .body(body_empty())?; let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream) .await @@ -61,90 +58,115 @@ async fn create_ws_client( } struct DataServerSubscription { + server: Arc, host: String, raw_block_tx: broadcast::Sender, - event_tx: mpsc::Sender, + event_tx: mpsc::Sender, last_seq: Option, } impl DataServerSubscription { - fn new(server: &RelayServer, host: String) -> Self { + fn new(server: Arc, host: String) -> Self { Self { host, raw_block_tx: server.raw_block_tx.clone(), event_tx: server.event_tx.clone(), last_seq: None, + server, } } async fn handle_event(&mut self, frame: Bytes) -> Result<()> { + // TODO: validate if this message is valid to come from this host + let buf: &[u8] = &frame; let mut cursor = Cursor::new(buf); let (header_buf, payload_buf) = match serde_ipld_dagcbor::from_reader::(&mut cursor) { Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize), - _ => return Err(anyhow!("invalid frame type")), + _ => bail!("invalid frame type"), }; - let header = serde_ipld_dagcbor::from_slice::(header_buf)?; - let payload = serde_ipld_dagcbor::from_slice::(payload_buf)?; + let header = serde_ipld_dagcbor::from_slice::(header_buf)?; - let Ipld::Map(payload) = payload else { - return Err(anyhow!( - "invalid payload type (expected Map, got {:?})", - payload.kind() - )); // message payloads must always be objects - }; + let payload = match header.t.as_deref() { + Some("#commit") => { + let payload = + serde_ipld_dagcbor::from_slice::(payload_buf)?; - match header.t.as_deref() { - Some("#commit") | Some("#handle") | Some("#identity") | Some("#account") - | Some("#migrate") | Some("#tombstone") | Some("#labels") => { - if let Some(Ipld::Integer(i)) = payload.get("seq") { - self.last_seq = Some(*i); + let user = lookup_user(&self.server, &payload.repo).await?; + let Some(pds) = user.pds else { + bail!("user has no associated pds? {}", user.did); + }; + let uri: Uri = pds.parse()?; + if uri.authority().map(|a| a.host()) != Some(&self.host) { + bail!( + "commit from non-authoritative pds (got {} expected {})", + self.host, + pds + ); } - // send to sequencer, rewrites `seq` in payload - self.event_tx.send((header, payload)).await?; + StreamEventPayload::Commit(payload) } + Some(t) => { + tracing::warn!("dropped unknown message type '{}'", t); + return Ok(()); + } + None => { + return Ok(()); + // skip ig + } + }; - // Some("#info") | unknown - _ => { - // no need to do sequence numbering :3 we can just emit the raw data - self.raw_block_tx.send(frame)?; - } - } + self.event_tx.send((header, payload)).await?; Ok(()) } } -pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()> { +pub async fn subscribe_to_host(server: Arc, host: String) -> Result<()> { tracing::debug!(%host, "establishing connection"); let mut subscription = DataServerSubscription::new(server, host); - // TODO: reconnect (with backoff) (using cursor) if we lose connection + 'reconnect: loop { + let mut ws = create_ws_client( + &subscription.host, + 443, + "/xrpc/com.atproto.sync.subscribeRepos", + ) + .await?; - let mut ws = create_ws_client( - &subscription.host, - 443, - "/xrpc/com.atproto.sync.subscribeRepos", - ) - .await?; + tracing::debug!(host = %subscription.host, "listening"); - tracing::debug!(host = %subscription.host, "listening"); + loop { + match ws.read_frame().await { + Ok(frame) if frame.opcode == OpCode::Binary => { + let bytes = match frame.payload { + Payload::BorrowedMut(slice) => Bytes::from(&*slice), + Payload::Borrowed(slice) => Bytes::from(slice), + Payload::Owned(vec) => Bytes::from(vec), + Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut), + }; - while let Ok(frame) = ws.read_frame().await { - if frame.opcode == OpCode::Binary { - let bytes = match frame.payload { - Payload::BorrowedMut(slice) => Bytes::from(&*slice), - Payload::Borrowed(slice) => Bytes::from(slice), - Payload::Owned(vec) => Bytes::from(vec), - Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut), - }; - - if let Err(e) = subscription.handle_event(bytes).await { - tracing::error!("error handling event (skipping): {e:?}"); - continue; + if let Err(e) = subscription.handle_event(bytes).await { + tracing::error!("error handling event (skipping): {e:?}"); + } + } + Ok(frame) => { + tracing::warn!("unexpected frame type {:?}", frame.opcode); + } + Err(e) => { + tracing::error!(host = %subscription.host, "{e:?}"); + // TODO: should we try reconnect in every situation? + if let WebSocketError::UnexpectedEOF = e { + tracing::debug!(host = %subscription.host, "reconnecting"); + // TODO: should we sleep at all here + continue 'reconnect; + } else { + break 'reconnect; + } + } } } } @@ -157,13 +179,11 @@ pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()> pub fn index_servers(server: Arc, hosts: &[String]) { // in future we will spider out but right now i just want da stuff from my PDS - // TODO: we should be able to track and close / retry these connections lol - for host in hosts.iter() { let host = host.to_string(); let server = Arc::clone(&server); tokio::task::spawn(async move { - if let Err(e) = subscribe_to_host(&server, host).await { + if let Err(e) = subscribe_to_host(server, host).await { tracing::warn!("encountered error subscribing to PDS: {e:?}"); } }); diff --git a/src/lib.rs b/src/lib.rs index 86b1a1a..f9e7915 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,28 @@ -use hyper::body::Bytes; +use bytes::Bytes; use tokio::sync::{broadcast, mpsc}; -use wire_proto::StreamingEvent; - -pub mod prelude; +use wire_proto::StreamEvent; pub struct RelayServer { pub db: sled::Db, + pub db_history: sled::Tree, + pub db_users: sled::Tree, - pub event_tx: mpsc::Sender, + pub event_tx: mpsc::Sender, pub raw_block_tx: broadcast::Sender, } impl RelayServer { - pub fn new(db: sled::Db, event_tx: mpsc::Sender) -> Self { + pub fn new(db: sled::Db, event_tx: mpsc::Sender) -> Self { let (raw_block_tx, _) = broadcast::channel(128); Self { - db, event_tx, raw_block_tx, + + db_history: db + .open_tree("history") + .expect("failed to open history tree"), + db_users: db.open_tree("users").expect("failed to open users tree"), + db, } } } @@ -26,4 +31,6 @@ pub mod http; pub mod indexer; pub mod relay_subscription; pub mod sequencer; +pub mod tls; +pub mod user; pub mod wire_proto; diff --git a/src/main.rs b/src/main.rs index 3bac25f..41269f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,16 +17,25 @@ async fn main() -> Result<()> { .with(EnvFilter::from_str("cerulea_relay=debug").unwrap()) .init(); - let db = sled::open("data").expect("Failed to open database"); + let db = sled::Config::default() + .path("data") + // TODO: configurable cache capacity + .cache_capacity(1024 * 1024 * 1024) + .use_compression(true) + .open() + .expect("Failed to open database"); let (event_tx, event_rx) = mpsc::channel(128); let server = Arc::new(RelayServer::new(db, event_tx)); + // TODO: configurable static list of hosts / crawler index_servers(Arc::clone(&server), &["pds.bun.how".into()]); start_sequencer(Arc::clone(&server), event_rx); + // TODO: configurable bind address let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + http::listen(server, addr).await?; Ok(()) } diff --git a/src/prelude.rs b/src/prelude.rs deleted file mode 100644 index 4024d7b..0000000 --- a/src/prelude.rs +++ /dev/null @@ -1 +0,0 @@ -pub use tap::prelude::*; diff --git a/src/relay_subscription.rs b/src/relay_subscription.rs index e2898df..af2e243 100644 --- a/src/relay_subscription.rs +++ b/src/relay_subscription.rs @@ -1,4 +1,4 @@ -use crate::{prelude::*, RelayServer}; +use crate::RelayServer; use std::sync::Arc; @@ -23,7 +23,7 @@ use tokio::{ }; use uuid::Uuid; -use crate::http::{empty, full, ServerResponse}; +use crate::http::{body_empty, body_full, ServerResponse}; enum Operation<'f> { NoOp, @@ -143,11 +143,11 @@ async fn run_subscription( } // live tailing: - let mut block_rx = server.raw_block_tx.subscribe(); + let mut raw_block_rx = server.raw_block_tx.subscribe(); while sub.running { let op = tokio::select! { biased; - op = rebroadcast_block(&mut block_rx) => op, + op = rebroadcast_block(&mut raw_block_rx) => op, op = read_frame(&mut ws_rx) => op, }; sub.dispatch_operation(&mut ws_tx, op).await; @@ -161,10 +161,9 @@ pub async fn handle_subscription( mut req: Request, ) -> Result { if !is_upgrade_request(&req) { - return Response::builder() + return Ok(Response::builder() .status(StatusCode::UPGRADE_REQUIRED) - .body(full("Upgrade Required"))? - .pipe(Ok); + .body(body_full("Upgrade Required"))?); } let (res, ws_fut) = upgrade(&mut req)?; @@ -181,5 +180,5 @@ pub async fn handle_subscription( }); let (head, _) = res.into_parts(); - Response::from_parts(head, empty()).pipe(Ok) + Ok(Response::from_parts(head, body_empty())) } diff --git a/src/sequencer.rs b/src/sequencer.rs index 39195ba..7709d9d 100644 --- a/src/sequencer.rs +++ b/src/sequencer.rs @@ -1,33 +1,38 @@ use std::{io::Cursor, sync::Arc}; use anyhow::Result; -use hyper::body::Bytes; -use ipld_core::ipld::Ipld; +use bytes::Bytes; use tokio::sync::mpsc; -use crate::{wire_proto::StreamingEvent, RelayServer}; +use crate::{ + wire_proto::{StreamEvent, StreamEventPayload}, + RelayServer, +}; async fn run_sequencer( server: Arc, - mut event_rx: mpsc::Receiver, + mut event_rx: mpsc::Receiver, ) -> Result<()> { - let db = server.db.clone(); - let mut curr_seq = db - .get(b"curr_seq")? + let mut curr_seq = server + .db + .get(b"history_last_seq")? .map(|v| { - let mut buf = [0u8; 8]; - let len = 8.min(v.len()); + let mut buf = [0u8; 16]; + let len = 16.min(v.len()); buf[..len].copy_from_slice(&v[..len]); - u64::from_le_bytes(buf) + u128::from_le_bytes(buf) }) .unwrap_or_default(); - let events = db.open_tree(b"events")?; - tracing::debug!(seq = %curr_seq, "initial sequence number"); - while let Some((header, mut payload)) = event_rx.recv().await { - let seq_bump = matches!( + while let Some((header, payload)) = event_rx.recv().await { + curr_seq += 1; + server + .db + .insert(b"history_last_seq", &u128::to_le_bytes(curr_seq))?; + + /* if matches!( header.t.as_deref(), Some("#commit") | Some("#handle") @@ -36,31 +41,34 @@ async fn run_sequencer( | Some("#migrate") | Some("#tombstone") | Some("#labels") - ); - - if seq_bump { - curr_seq += 1; + ) { payload.insert("seq".into(), Ipld::Integer(curr_seq as i128)); - db.insert(b"curr_seq", &u64::to_le_bytes(curr_seq))?; - } + } */ let mut cursor = Cursor::new(Vec::with_capacity(1024 * 1024)); serde_ipld_dagcbor::to_writer(&mut cursor, &header)?; - serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + + match payload { + StreamEventPayload::Commit(mut payload) => { + payload.seq = curr_seq as i64; + serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + } + StreamEventPayload::Unknown(payload) => { + serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + } + } let data = Bytes::from(cursor.into_inner()); - if seq_bump { - server.raw_block_tx.send(data.clone())?; - events.insert(u64::to_be_bytes(curr_seq), &*data)?; - } else { - server.raw_block_tx.send(data)?; - } + let _ = server.raw_block_tx.send(data.clone()); + server + .db_history + .insert(u128::to_be_bytes(curr_seq), &*data)?; } Ok(()) } -pub fn start_sequencer(server: Arc, event_rx: mpsc::Receiver) { +pub fn start_sequencer(server: Arc, event_rx: mpsc::Receiver) { tokio::task::spawn(async move { if let Err(e) = run_sequencer(server, event_rx).await { tracing::error!("sequencer error: {e:?}"); diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..3ee8ca4 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,21 @@ +use std::sync::Arc; + +use anyhow::Result; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio_rustls::{client::TlsStream, TlsConnector}; + +pub async fn open_tls_stream( + tcp_stream: TcpStream, + domain_tls: ServerName<'static>, +) -> Result> { + let root_store = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + let tls_connector = TlsConnector::from(Arc::new(client_config)); + + let stream = tls_connector.connect(domain_tls, tcp_stream).await?; + Ok(stream) +} diff --git a/src/user.rs b/src/user.rs new file mode 100644 index 0000000..4eea0c7 --- /dev/null +++ b/src/user.rs @@ -0,0 +1,97 @@ +use anyhow::{bail, Context, Result}; +use atrium_api::did_doc::DidDocument; +use bytes::Buf; +use http_body_util::BodyExt; +use hyper::{client::conn::http1, Request, StatusCode}; +use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpStream; + +use crate::{ + http::{body_empty, HttpBody}, + tls::open_tls_stream, + RelayServer, +}; + +#[derive(Serialize, Deserialize)] +pub struct User { + pub did: String, + pub pds: Option, + #[serde(default)] + pub takedown: bool, + #[serde(default)] + pub tombstone: bool, +} + +pub async fn fetch_user(server: &RelayServer, did: &str) -> Result { + tracing::debug!(%did, "fetching user"); + if did.starts_with("did:plc:") { + // TODO: configurable plc resolver location + let domain = "plc.directory"; + + let tcp_stream = TcpStream::connect((domain, 443)).await?; + let domain_tls: ServerName<'_> = ServerName::try_from(domain.to_string())?; + let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?; + let io = TokioIo::new(tls_stream); + + let req = Request::builder() + .method("GET") + .uri(format!("https://{domain}/{did}")) + .header("Host", domain.to_string()) + .body(body_empty())?; + + let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + tracing::debug!("handshake"); + + let res = sender + .send_request(req) + .await + .context("Failed to send plc request")?; + if res.status() != StatusCode::OK { + bail!("plc directory returned non-200 status"); + } + + tracing::debug!("got response"); + + let body = res.collect().await?.aggregate(); + let did_doc = serde_json::from_reader::<_, DidDocument>(body.reader()) + .context("Failed to parse plc DID doc as JSON")?; + + let user = User { + pds: did_doc.get_pds_endpoint(), + did: did_doc.id, + takedown: false, + tombstone: false, + }; + + store_user(server, &user).await?; + + Ok(user) + } else if did.starts_with("did:web:") { + todo!("resolve did web") + } else { + bail!("unknown did type {did}"); + } +} + +pub async fn lookup_user(server: &RelayServer, did: &str) -> Result { + if let Some(cached_user) = server.db_users.get(did)? { + let cached_user = serde_ipld_dagcbor::from_slice::(&cached_user)?; + return Ok(cached_user); + } + + return fetch_user(server, did).await; +} + +pub async fn store_user(server: &RelayServer, user: &User) -> Result<()> { + let data = serde_ipld_dagcbor::to_vec(&user)?; + server.db_users.insert(&user.did, data)?; + Ok(()) +} diff --git a/src/wire_proto.rs b/src/wire_proto.rs index 42382a6..8eadf3e 100644 --- a/src/wire_proto.rs +++ b/src/wire_proto.rs @@ -1,13 +1,19 @@ use std::collections::BTreeMap; +use atrium_api::com::atproto::sync::subscribe_repos; use ipld_core::ipld::Ipld; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] -pub struct SubscriptionHeader { +pub struct StreamEventHeader { pub op: i64, #[serde(default)] pub t: Option, } -pub type StreamingEvent = (SubscriptionHeader, BTreeMap); +pub enum StreamEventPayload { + Commit(subscribe_repos::Commit), + Unknown(BTreeMap), +} + +pub type StreamEvent = (StreamEventHeader, StreamEventPayload);