properly validate pds repo authority for commits, use atrium types

This commit is contained in:
Charlotte Som 2024-11-25 19:45:43 +02:00
parent 0a9b998469
commit f0e5dc7428
12 changed files with 535 additions and 129 deletions

254
Cargo.lock generated
View file

@ -26,6 +26,21 @@ dependencies = [
"memchr",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anyhow"
version = "1.0.93"
@ -38,6 +53,40 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atrium-api"
version = "0.24.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4e5d077f7941ec5484964fba8697b7571a964b0a714e02ae7bc7332833c36b"
dependencies = [
"atrium-xrpc",
"chrono",
"http",
"ipld-core",
"langtag",
"regex",
"serde",
"serde_bytes",
"serde_json",
"thiserror",
"tokio",
"trait-variant",
]
[[package]]
name = "atrium-xrpc"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f223b98be2acdd7afe5b867744aee8258413ed09993099de0a036b247db0ec4c"
dependencies = [
"http",
"serde",
"serde_html_form",
"serde_json",
"thiserror",
"trait-variant",
]
[[package]]
name = "autocfg"
version = "1.4.0"
@ -141,6 +190,12 @@ dependencies = [
"generic-array",
]
[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "byteorder"
version = "1.5.0"
@ -152,6 +207,9 @@ name = "bytes"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
dependencies = [
"serde",
]
[[package]]
name = "cbor4ii"
@ -178,6 +236,8 @@ name = "cerulea_relay"
version = "0.1.0"
dependencies = [
"anyhow",
"atrium-api",
"bytes",
"fastwebsockets",
"http-body-util",
"hyper",
@ -188,8 +248,8 @@ dependencies = [
"rustls",
"serde",
"serde_ipld_dagcbor",
"serde_json",
"sled",
"tap",
"tokio",
"tokio-rustls",
"tracing",
@ -213,6 +273,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets",
]
[[package]]
name = "cid"
version = "0.11.1"
@ -247,6 +322,12 @@ dependencies = [
"cc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "core2"
version = "0.4.0"
@ -389,6 +470,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]]
name = "fs2"
version = "0.4.3"
@ -611,6 +701,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "indexmap"
version = "2.6.0"
@ -665,6 +778,24 @@ dependencies = [
"libc",
]
[[package]]
name = "js-sys"
version = "0.3.72"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9"
dependencies = [
"wasm-bindgen",
]
[[package]]
name = "langtag"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed60c85f254d6ae8450cec15eedd921efbc4d1bdf6fcf6202b9a58b403f6f805"
dependencies = [
"serde",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
@ -799,6 +930,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "object"
version = "0.36.5"
@ -1123,6 +1263,12 @@ dependencies = [
"untrusted",
]
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -1158,6 +1304,19 @@ dependencies = [
"syn 2.0.89",
]
[[package]]
name = "serde_html_form"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de514ef58196f1fc96dcaef80fe6170a1ce6215df9687a93fe8300e773fefc5"
dependencies = [
"form_urlencoded",
"indexmap",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_ipld_dagcbor"
version = "0.6.1"
@ -1170,6 +1329,18 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.133"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "sha1"
version = "0.10.6"
@ -1287,12 +1458,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "thiserror"
version = "1.0.69"
@ -1443,6 +1608,17 @@ dependencies = [
"tracing-log",
]
[[package]]
name = "trait-variant"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.89",
]
[[package]]
name = "try-lock"
version = "0.2.5"
@ -1515,6 +1691,61 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e"
dependencies = [
"cfg-if",
"once_cell",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.89",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.89",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d"
[[package]]
name = "webpki-roots"
version = "0.26.7"
@ -1558,6 +1789,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.52.0"

View file

@ -5,6 +5,8 @@ edition = "2021"
[dependencies]
anyhow = "1.0.93"
atrium-api = { version = "0.24.8", default-features = false, features = ["tokio"] }
bytes = { version = "1.8.0", features = ["serde"] }
fastwebsockets = { version = "0.8.0", features = ["hyper", "unstable-split", "upgrade"] }
http-body-util = "0.1.2"
hyper = { version = "1.5.1", features = ["client", "full", "http1", "http2", "server"] }
@ -15,8 +17,8 @@ qstring = "0.7.2"
rustls = "0.23.18"
serde = { version = "1.0.215", features = ["derive"] }
serde_ipld_dagcbor = "0.6.1"
serde_json = "1.0.133"
sled = { version = "0.34.7", features = ["compression"] }
tap = "1.0.1"
tokio = { version = "1.41.1", features = ["full"] }
tokio-rustls = "0.26.0"
tracing = "0.1.40"

View file

@ -1,4 +1,4 @@
use crate::{prelude::*, relay_subscription::handle_subscription, RelayServer};
use crate::{relay_subscription::handle_subscription, RelayServer};
use std::{net::SocketAddr, sync::Arc};
@ -12,11 +12,11 @@ use hyper::{
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
pub type ServerResponseBody = BoxBody<Bytes, hyper::Error>;
pub fn empty() -> ServerResponseBody {
pub type HttpBody = BoxBody<Bytes, hyper::Error>;
pub fn body_empty() -> HttpBody {
Empty::<Bytes>::new().map_err(|e| match e {}).boxed()
}
pub fn full<T: Into<Bytes>>(chunk: T) -> ServerResponseBody {
pub fn body_full<T: Into<Bytes>>(chunk: T) -> HttpBody {
Full::new(chunk.into()).map_err(|e| match e {}).boxed()
}
@ -28,21 +28,19 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve
tracing::debug!("{}", path);
match (req.method(), path) {
(&Method::GET, "/") => Response::builder()
(&Method::GET, "/") => Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/plain")
.body(full("cerulea relay running..."))?
.pipe(Ok),
.body(body_full("cerulea relay running..."))?),
(&Method::GET, "/xrpc/com.atproto.sync.subscribeRepos") => {
handle_subscription(server, req).await
}
_ => Response::builder()
_ => Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(full("Not Found"))?
.pipe(Ok),
.body(body_full("Not Found"))?),
}
}

View file

@ -1,8 +1,10 @@
use std::{io::Cursor, sync::Arc};
use anyhow::{anyhow, Result};
use fastwebsockets::{FragmentCollector, OpCode, Payload};
use hyper::{body::Bytes, header, upgrade::Upgraded, Request};
use anyhow::{anyhow, bail, Context, Result};
use atrium_api::com::atproto::sync::subscribe_repos;
use bytes::Bytes;
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
use hyper::{header, upgrade::Upgraded, Request, Uri};
use hyper_util::rt::{TokioExecutor, TokioIo};
use ipld_core::ipld::Ipld;
use serde_ipld_dagcbor::DecodeError;
@ -16,8 +18,10 @@ use tokio::{
use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector};
use crate::{
http::empty,
wire_proto::{StreamingEvent, SubscriptionHeader},
http::body_empty,
tls::open_tls_stream,
user::lookup_user,
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
RelayServer,
};
@ -26,22 +30,15 @@ async fn create_ws_client(
port: u16,
path: &str,
) -> Result<FragmentCollector<TokioIo<Upgraded>>> {
let addr = format!("{}:{}", domain, port);
let tcp_stream = TcpStream::connect(&addr).await?;
let tcp_stream = TcpStream::connect((domain, port)).await?;
let root_store =
rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let domain_tls = ServerName::try_from(domain.to_string())?;
let client_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls_connector = TlsConnector::from(Arc::new(client_config));
let tls_stream = tls_connector.connect(domain_tls, tcp_stream).await?;
let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?;
let req = Request::builder()
.method("GET")
.uri(format!("wss://{}:{}{}", &domain, port, path))
.header("Host", &addr)
.header("Host", domain.to_string())
.header(header::UPGRADE, "websocket")
.header(header::CONNECTION, "upgrade")
.header(
@ -49,7 +46,7 @@ async fn create_ws_client(
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13")
.body(empty())?;
.body(body_empty())?;
let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream)
.await
@ -61,69 +58,78 @@ async fn create_ws_client(
}
struct DataServerSubscription {
server: Arc<RelayServer>,
host: String,
raw_block_tx: broadcast::Sender<Bytes>,
event_tx: mpsc::Sender<StreamingEvent>,
event_tx: mpsc::Sender<StreamEvent>,
last_seq: Option<i128>,
}
impl DataServerSubscription {
fn new(server: &RelayServer, host: String) -> Self {
fn new(server: Arc<RelayServer>, host: String) -> Self {
Self {
host,
raw_block_tx: server.raw_block_tx.clone(),
event_tx: server.event_tx.clone(),
last_seq: None,
server,
}
}
async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
// TODO: validate if this message is valid to come from this host
let buf: &[u8] = &frame;
let mut cursor = Cursor::new(buf);
let (header_buf, payload_buf) =
match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize),
_ => return Err(anyhow!("invalid frame type")),
_ => bail!("invalid frame type"),
};
let header = serde_ipld_dagcbor::from_slice::<SubscriptionHeader>(header_buf)?;
let payload = serde_ipld_dagcbor::from_slice::<Ipld>(payload_buf)?;
let header = serde_ipld_dagcbor::from_slice::<StreamEventHeader>(header_buf)?;
let Ipld::Map(payload) = payload else {
return Err(anyhow!(
"invalid payload type (expected Map, got {:?})",
payload.kind()
)); // message payloads must always be objects
let payload = match header.t.as_deref() {
Some("#commit") => {
let payload =
serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
let user = lookup_user(&self.server, &payload.repo).await?;
let Some(pds) = user.pds else {
bail!("user has no associated pds? {}", user.did);
};
match header.t.as_deref() {
Some("#commit") | Some("#handle") | Some("#identity") | Some("#account")
| Some("#migrate") | Some("#tombstone") | Some("#labels") => {
if let Some(Ipld::Integer(i)) = payload.get("seq") {
self.last_seq = Some(*i);
let uri: Uri = pds.parse()?;
if uri.authority().map(|a| a.host()) != Some(&self.host) {
bail!(
"commit from non-authoritative pds (got {} expected {})",
self.host,
pds
);
}
// send to sequencer, rewrites `seq` in payload
StreamEventPayload::Commit(payload)
}
Some(t) => {
tracing::warn!("dropped unknown message type '{}'", t);
return Ok(());
}
None => {
return Ok(());
// skip ig
}
};
self.event_tx.send((header, payload)).await?;
}
// Some("#info") | unknown
_ => {
// no need to do sequence numbering :3 we can just emit the raw data
self.raw_block_tx.send(frame)?;
}
}
Ok(())
}
}
pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()> {
pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result<()> {
tracing::debug!(%host, "establishing connection");
let mut subscription = DataServerSubscription::new(server, host);
// TODO: reconnect (with backoff) (using cursor) if we lose connection
'reconnect: loop {
let mut ws = create_ws_client(
&subscription.host,
443,
@ -133,8 +139,9 @@ pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()>
tracing::debug!(host = %subscription.host, "listening");
while let Ok(frame) = ws.read_frame().await {
if frame.opcode == OpCode::Binary {
loop {
match ws.read_frame().await {
Ok(frame) if frame.opcode == OpCode::Binary => {
let bytes = match frame.payload {
Payload::BorrowedMut(slice) => Bytes::from(&*slice),
Payload::Borrowed(slice) => Bytes::from(slice),
@ -144,7 +151,22 @@ pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()>
if let Err(e) = subscription.handle_event(bytes).await {
tracing::error!("error handling event (skipping): {e:?}");
continue;
}
}
Ok(frame) => {
tracing::warn!("unexpected frame type {:?}", frame.opcode);
}
Err(e) => {
tracing::error!(host = %subscription.host, "{e:?}");
// TODO: should we try reconnect in every situation?
if let WebSocketError::UnexpectedEOF = e {
tracing::debug!(host = %subscription.host, "reconnecting");
// TODO: should we sleep at all here
continue 'reconnect;
} else {
break 'reconnect;
}
}
}
}
}
@ -157,13 +179,11 @@ pub async fn subscribe_to_host(server: &RelayServer, host: String) -> Result<()>
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
// in future we will spider out but right now i just want da stuff from my PDS
// TODO: we should be able to track and close / retry these connections lol
for host in hosts.iter() {
let host = host.to_string();
let server = Arc::clone(&server);
tokio::task::spawn(async move {
if let Err(e) = subscribe_to_host(&server, host).await {
if let Err(e) = subscribe_to_host(server, host).await {
tracing::warn!("encountered error subscribing to PDS: {e:?}");
}
});

View file

@ -1,23 +1,28 @@
use hyper::body::Bytes;
use bytes::Bytes;
use tokio::sync::{broadcast, mpsc};
use wire_proto::StreamingEvent;
pub mod prelude;
use wire_proto::StreamEvent;
pub struct RelayServer {
pub db: sled::Db,
pub db_history: sled::Tree,
pub db_users: sled::Tree,
pub event_tx: mpsc::Sender<StreamingEvent>,
pub event_tx: mpsc::Sender<StreamEvent>,
pub raw_block_tx: broadcast::Sender<Bytes>,
}
impl RelayServer {
pub fn new(db: sled::Db, event_tx: mpsc::Sender<StreamingEvent>) -> Self {
pub fn new(db: sled::Db, event_tx: mpsc::Sender<StreamEvent>) -> Self {
let (raw_block_tx, _) = broadcast::channel(128);
Self {
db,
event_tx,
raw_block_tx,
db_history: db
.open_tree("history")
.expect("failed to open history tree"),
db_users: db.open_tree("users").expect("failed to open users tree"),
db,
}
}
}
@ -26,4 +31,6 @@ pub mod http;
pub mod indexer;
pub mod relay_subscription;
pub mod sequencer;
pub mod tls;
pub mod user;
pub mod wire_proto;

View file

@ -17,16 +17,25 @@ async fn main() -> Result<()> {
.with(EnvFilter::from_str("cerulea_relay=debug").unwrap())
.init();
let db = sled::open("data").expect("Failed to open database");
let db = sled::Config::default()
.path("data")
// TODO: configurable cache capacity
.cache_capacity(1024 * 1024 * 1024)
.use_compression(true)
.open()
.expect("Failed to open database");
let (event_tx, event_rx) = mpsc::channel(128);
let server = Arc::new(RelayServer::new(db, event_tx));
// TODO: configurable static list of hosts / crawler
index_servers(Arc::clone(&server), &["pds.bun.how".into()]);
start_sequencer(Arc::clone(&server), event_rx);
// TODO: configurable bind address
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
http::listen(server, addr).await?;
Ok(())
}

View file

@ -1 +0,0 @@
pub use tap::prelude::*;

View file

@ -1,4 +1,4 @@
use crate::{prelude::*, RelayServer};
use crate::RelayServer;
use std::sync::Arc;
@ -23,7 +23,7 @@ use tokio::{
};
use uuid::Uuid;
use crate::http::{empty, full, ServerResponse};
use crate::http::{body_empty, body_full, ServerResponse};
enum Operation<'f> {
NoOp,
@ -143,11 +143,11 @@ async fn run_subscription(
}
// live tailing:
let mut block_rx = server.raw_block_tx.subscribe();
let mut raw_block_rx = server.raw_block_tx.subscribe();
while sub.running {
let op = tokio::select! {
biased;
op = rebroadcast_block(&mut block_rx) => op,
op = rebroadcast_block(&mut raw_block_rx) => op,
op = read_frame(&mut ws_rx) => op,
};
sub.dispatch_operation(&mut ws_tx, op).await;
@ -161,10 +161,9 @@ pub async fn handle_subscription(
mut req: Request<Incoming>,
) -> Result<ServerResponse> {
if !is_upgrade_request(&req) {
return Response::builder()
return Ok(Response::builder()
.status(StatusCode::UPGRADE_REQUIRED)
.body(full("Upgrade Required"))?
.pipe(Ok);
.body(body_full("Upgrade Required"))?);
}
let (res, ws_fut) = upgrade(&mut req)?;
@ -181,5 +180,5 @@ pub async fn handle_subscription(
});
let (head, _) = res.into_parts();
Response::from_parts(head, empty()).pipe(Ok)
Ok(Response::from_parts(head, body_empty()))
}

View file

@ -1,33 +1,38 @@
use std::{io::Cursor, sync::Arc};
use anyhow::Result;
use hyper::body::Bytes;
use ipld_core::ipld::Ipld;
use bytes::Bytes;
use tokio::sync::mpsc;
use crate::{wire_proto::StreamingEvent, RelayServer};
use crate::{
wire_proto::{StreamEvent, StreamEventPayload},
RelayServer,
};
async fn run_sequencer(
server: Arc<RelayServer>,
mut event_rx: mpsc::Receiver<StreamingEvent>,
mut event_rx: mpsc::Receiver<StreamEvent>,
) -> Result<()> {
let db = server.db.clone();
let mut curr_seq = db
.get(b"curr_seq")?
let mut curr_seq = server
.db
.get(b"history_last_seq")?
.map(|v| {
let mut buf = [0u8; 8];
let len = 8.min(v.len());
let mut buf = [0u8; 16];
let len = 16.min(v.len());
buf[..len].copy_from_slice(&v[..len]);
u64::from_le_bytes(buf)
u128::from_le_bytes(buf)
})
.unwrap_or_default();
let events = db.open_tree(b"events")?;
tracing::debug!(seq = %curr_seq, "initial sequence number");
while let Some((header, mut payload)) = event_rx.recv().await {
let seq_bump = matches!(
while let Some((header, payload)) = event_rx.recv().await {
curr_seq += 1;
server
.db
.insert(b"history_last_seq", &u128::to_le_bytes(curr_seq))?;
/* if matches!(
header.t.as_deref(),
Some("#commit")
| Some("#handle")
@ -36,31 +41,34 @@ async fn run_sequencer(
| Some("#migrate")
| Some("#tombstone")
| Some("#labels")
);
if seq_bump {
curr_seq += 1;
) {
payload.insert("seq".into(), Ipld::Integer(curr_seq as i128));
db.insert(b"curr_seq", &u64::to_le_bytes(curr_seq))?;
}
} */
let mut cursor = Cursor::new(Vec::with_capacity(1024 * 1024));
serde_ipld_dagcbor::to_writer(&mut cursor, &header)?;
match payload {
StreamEventPayload::Commit(mut payload) => {
payload.seq = curr_seq as i64;
serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?;
}
StreamEventPayload::Unknown(payload) => {
serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?;
}
}
let data = Bytes::from(cursor.into_inner());
if seq_bump {
server.raw_block_tx.send(data.clone())?;
events.insert(u64::to_be_bytes(curr_seq), &*data)?;
} else {
server.raw_block_tx.send(data)?;
}
let _ = server.raw_block_tx.send(data.clone());
server
.db_history
.insert(u128::to_be_bytes(curr_seq), &*data)?;
}
Ok(())
}
pub fn start_sequencer(server: Arc<RelayServer>, event_rx: mpsc::Receiver<StreamingEvent>) {
pub fn start_sequencer(server: Arc<RelayServer>, event_rx: mpsc::Receiver<StreamEvent>) {
tokio::task::spawn(async move {
if let Err(e) = run_sequencer(server, event_rx).await {
tracing::error!("sequencer error: {e:?}");

21
src/tls.rs Normal file
View file

@ -0,0 +1,21 @@
use std::sync::Arc;
use anyhow::Result;
use rustls::pki_types::ServerName;
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
/// Performs a client-side TLS handshake over an already-connected TCP stream.
///
/// The trust store is built from the bundled `webpki_roots` anchors and no
/// client certificate is presented. `domain_tls` is the server name handed to
/// the connector for the handshake.
///
/// # Errors
///
/// Returns an error if the TLS handshake fails.
pub async fn open_tls_stream(
    tcp_stream: TcpStream,
    domain_tls: ServerName<'static>,
) -> Result<TlsStream<TcpStream>> {
    // Trust anchors come from the webpki-roots bundle compiled into the binary.
    let trust_anchors =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let config = Arc::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth(),
    );

    let connector = TlsConnector::from(config);
    Ok(connector.connect(domain_tls, tcp_stream).await?)
}

97
src/user.rs Normal file
View file

@ -0,0 +1,97 @@
use anyhow::{bail, Context, Result};
use atrium_api::did_doc::DidDocument;
use bytes::Buf;
use http_body_util::BodyExt;
use hyper::{client::conn::http1, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rustls::pki_types::ServerName;
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use crate::{
http::{body_empty, HttpBody},
tls::open_tls_stream,
RelayServer,
};
/// Cached record describing a repo owner, keyed by DID in the `users` tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
    /// The user's DID, as reported by the resolved DID document.
    pub did: String,
    /// PDS endpoint taken from the DID document, if the document lists one.
    pub pds: Option<String>,
    /// Defaults to `false` when absent from the stored record.
    // NOTE(review): presumably marks a moderation takedown — confirm once it is used.
    #[serde(default)]
    pub takedown: bool,
    /// Defaults to `false` when absent from the stored record.
    // NOTE(review): presumably marks a deleted/tombstoned account — confirm once it is used.
    #[serde(default)]
    pub tombstone: bool,
}
pub async fn fetch_user(server: &RelayServer, did: &str) -> Result<User> {
tracing::debug!(%did, "fetching user");
if did.starts_with("did:plc:") {
// TODO: configurable plc resolver location
let domain = "plc.directory";
let tcp_stream = TcpStream::connect((domain, 443)).await?;
let domain_tls: ServerName<'_> = ServerName::try_from(domain.to_string())?;
let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?;
let io = TokioIo::new(tls_stream);
let req = Request::builder()
.method("GET")
.uri(format!("https://{domain}/{did}"))
.header("Host", domain.to_string())
.body(body_empty())?;
let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});
tracing::debug!("handshake");
let res = sender
.send_request(req)
.await
.context("Failed to send plc request")?;
if res.status() != StatusCode::OK {
bail!("plc directory returned non-200 status");
}
tracing::debug!("got response");
let body = res.collect().await?.aggregate();
let did_doc = serde_json::from_reader::<_, DidDocument>(body.reader())
.context("Failed to parse plc DID doc as JSON")?;
let user = User {
pds: did_doc.get_pds_endpoint(),
did: did_doc.id,
takedown: false,
tombstone: false,
};
store_user(server, &user).await?;
Ok(user)
} else if did.starts_with("did:web:") {
todo!("resolve did web")
} else {
bail!("unknown did type {did}");
}
}
/// Returns the [`User`] for `did`, preferring the local cache.
///
/// A record found in the `db_users` tree is decoded from DAG-CBOR and
/// returned directly; otherwise the lookup falls through to [`fetch_user`].
///
/// # Errors
///
/// Returns an error if the database read fails, if a cached record cannot be
/// decoded, or if [`fetch_user`] fails.
pub async fn lookup_user(server: &RelayServer, did: &str) -> Result<User> {
    match server.db_users.get(did)? {
        Some(raw) => Ok(serde_ipld_dagcbor::from_slice::<User>(&raw)?),
        None => fetch_user(server, did).await,
    }
}
/// Writes `user` into the `db_users` tree, keyed by its DID.
///
/// The record is encoded as DAG-CBOR; any existing entry under the same DID
/// is replaced.
///
/// # Errors
///
/// Returns an error if encoding or the database insert fails.
pub async fn store_user(server: &RelayServer, user: &User) -> Result<()> {
    let encoded = serde_ipld_dagcbor::to_vec(user)?;
    server.db_users.insert(user.did.as_bytes(), encoded)?;
    Ok(())
}

View file

@ -1,13 +1,19 @@
use std::collections::BTreeMap;
use atrium_api::com::atproto::sync::subscribe_repos;
use ipld_core::ipld::Ipld;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize)]
pub struct SubscriptionHeader {
pub struct StreamEventHeader {
pub op: i64,
#[serde(default)]
pub t: Option<String>,
}
pub type StreamingEvent = (SubscriptionHeader, BTreeMap<String, Ipld>);
pub enum StreamEventPayload {
Commit(subscribe_repos::Commit),
Unknown(BTreeMap<String, Ipld>),
}
pub type StreamEvent = (StreamEventHeader, StreamEventPayload);