diff --git a/src/admin.rs b/src/admin.rs index 23ef87b..1c24d33 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -6,12 +6,12 @@ use hyper::{body::Incoming, Request}; use ipld_core::ipld::Ipld; use serde_ipld_dagcbor::DecodeError; -use crate::{http::ServerResponse, wire_proto::StreamEventHeader, RelayServer}; +use crate::{http::ServerResponse, wire_proto::StreamEventHeader, AppState}; -pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> { +pub fn purge_did(app: &AppState, did: &str) -> Result<()> { // drop commits - for event in server.db_history.iter() { + for event in app.db_history.iter() { let (seq, event) = event?; let mut cursor = Cursor::new(&event); let (header_buf, payload_buf) = @@ -37,7 +37,7 @@ pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> { if let Some(event_did) = event_did { if event_did.as_str() == did { - let _ = server.db_history.remove(seq); + let _ = app.db_history.remove(seq); } } } @@ -48,7 +48,7 @@ pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> { // TODO: ban host pub async fn handle_purge_did( - _server: Arc, + _app: Arc, _req: Request, ) -> Result { // TODO: diff --git a/src/app_state.rs b/src/app_state.rs index 1ed0597..b687b07 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -6,7 +6,7 @@ use tokio::sync::{broadcast, mpsc, Mutex}; use crate::wire_proto::StreamEvent; -pub struct RelayServer { +pub struct AppState { pub db: sled::Db, pub db_history: sled::Tree, pub db_users: sled::Tree, @@ -21,7 +21,7 @@ pub struct RelayServer { pub raw_block_tx: broadcast::Sender, } -impl RelayServer { +impl AppState { pub fn new(db: sled::Db, event_tx: mpsc::Sender) -> Self { let (raw_block_tx, _) = broadcast::channel(128); diff --git a/src/http.rs b/src/http.rs index b577199..df07c20 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,7 +13,7 @@ use tokio::net::TcpListener; use crate::{ admin::handle_purge_did, relay::{request_crawl::handle_request_crawl, subscribe::handle_subscription}, - RelayServer, + AppState, }; pub type HttpBody = BoxBody; @@ -26,7 +26,7 @@ pub fn body_full>(chunk: T) -> HttpBody { pub type ServerResponse = Response>; -async fn serve(server: Arc, req: Request) -> Result { +async fn serve(app: Arc, req: Request) -> Result { let path = req.uri().path(); tracing::debug!("{}", path); @@ -38,13 +38,13 @@ async fn serve(server: Arc, req: Request) -> Result { - handle_subscription(server, req).await + handle_subscription(app, req).await } (&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => { - handle_request_crawl(server, req).await + handle_request_crawl(app, req).await } - (&Method::POST, "/api/admin/purge-did") => handle_purge_did(server, req).await, + (&Method::POST, "/api/admin/purge-did") => handle_purge_did(app, req).await, _ => Ok(Response::builder() .status(StatusCode::NOT_FOUND) @@ -53,7 +53,7 @@ async fn serve(server: Arc, req: Request) -> Result, addr: SocketAddr) -> Result<()> { +pub async fn listen(app: Arc, addr: SocketAddr) -> Result<()> { tracing::info!("Listening on: http://{addr}/ ..."); let listener = TcpListener::bind(addr).await?; @@ -61,7 +61,7 @@ pub async fn listen(server: Arc, addr: SocketAddr) -> Result<()> { loop { let (stream, _client_addr) = listener.accept().await?; let io = TokioIo::new(stream); - let server = Arc::clone(&server); + let server = Arc::clone(&app); tokio::task::spawn(async move { if let Err(err) = hyper::server::conn::http1::Builder::new() .serve_connection(io, service_fn(move |req| serve(Arc::clone(&server), req))) diff --git a/src/main.rs b/src/main.rs index 2dae133..1ec0996 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use cerulea_relay::{ http::{self}, relay::index::index_servers, sequencer::start_sequencer, - RelayServer, + AppState, }; #[derive(Parser, Debug)] @@ -52,7 +52,7 @@ async fn main() -> Result<()> { let (event_tx, event_rx) = mpsc::channel(128); - let mut server = RelayServer::new(db, event_tx); + let mut server = AppState::new(db, event_tx); if let Some(plc_directory) = args.plc_resolver { server.plc_resolver = Cow::Owned(plc_directory); } diff --git a/src/relay/index.rs b/src/relay/index.rs index 7a4ba5d..d8e3d70 100644 --- a/src/relay/index.rs +++ b/src/relay/index.rs @@ -18,7 +18,7 @@ use crate::{ tls::open_tls_stream, user::{fetch_user, lookup_user}, wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, - RelayServer, + AppState, }; async fn create_ws_client( @@ -53,19 +53,19 @@ async fn create_ws_client( } struct DataServerSubscription { - server: Arc, + app: Arc, host: String, event_tx: mpsc::Sender, last_seq: Option, } impl DataServerSubscription { - fn new(server: Arc, host: String) -> Self { + fn new(app: Arc, host: String) -> Self { Self { host, - event_tx: server.event_tx.clone(), + event_tx: app.event_tx.clone(), last_seq: None, - server, + app, } } @@ -83,7 +83,7 @@ impl DataServerSubscription { } self.last_seq = Some(event.seq); - let mut user = lookup_user(&self.server, &event.repo).await?; + let mut user = lookup_user(&self.app, &event.repo).await?; let pds = user.pds.as_deref().unwrap_or_default(); if pds != self.host { @@ -94,7 +94,7 @@ impl DataServerSubscription { ); // re-fetch user (without cache) - user = fetch_user(&self.server, &event.repo).await?; + user = fetch_user(&self.app, &event.repo).await?; let fresh_pds = user.pds.as_deref().unwrap_or_default(); if fresh_pds != self.host { bail!( @@ -139,7 +139,7 @@ impl DataServerSubscription { } self.last_seq = Some(event.seq); - let user = fetch_user(&self.server, &event.did).await?; + let user = fetch_user(&self.app, &event.did).await?; if user.handle.as_deref() != Some(event.handle.as_str()) { tracing::warn!( seq = %event.seq, @@ -167,7 +167,7 @@ impl DataServerSubscription { self.last_seq = Some(event.seq); if let Some(handle) = event.handle.as_ref() { - let user = fetch_user(&self.server, &event.did).await?; + let user = fetch_user(&self.app, &event.did).await?; if user.handle.as_deref() != Some(handle.as_str()) { tracing::warn!( seq = %event.seq, @@ -195,7 +195,7 @@ impl DataServerSubscription { } self.last_seq = Some(event.seq); - let user = fetch_user(&self.server, &event.did).await?; + let user = fetch_user(&self.app, &event.did).await?; let pds = user.pds.as_deref().unwrap_or_default(); if pds != self.host { bail!( @@ -250,7 +250,7 @@ impl DataServerSubscription { } self.last_seq = Some(event.seq); - let user = lookup_user(&self.server, &event.did).await?; + let user = lookup_user(&self.app, &event.did).await?; let pds = user.pds.as_deref().unwrap_or_default(); if pds != self.host { bail!( @@ -329,7 +329,7 @@ impl DataServerSubscription { } fn load_cursor(&mut self) -> Result<()> { - if let Some(saved_cursor) = self.server.db_index_cursors.get(&self.host)? { + if let Some(saved_cursor) = self.app.db_index_cursors.get(&self.host)? { let mut cur_buf = [0u8; 8]; let len = 8.min(saved_cursor.len()); cur_buf[..len].copy_from_slice(&saved_cursor[..len]); @@ -341,7 +341,7 @@ impl DataServerSubscription { fn save_cursor(&self) -> Result<()> { if let Some(cur) = self.last_seq { - self.server + self.app .db_index_cursors .insert(&self.host, &i64::to_be_bytes(cur))?; } @@ -387,15 +387,15 @@ async fn get_repo_count(host: &str) -> Result { } #[tracing::instrument(skip_all, fields(host = %host))] -async fn host_subscription(server: Arc, host: String) -> Result<()> { +async fn host_subscription(app: Arc, host: String) -> Result<()> { tracing::debug!("establishing connection"); if get_repo_count(&host).await? >= 1000 { bail!("too many repos! ditching from cerulea relay") } - let _ = server.add_good_host(host.clone()).await; - let mut subscription = DataServerSubscription::new(server, host); + let _ = app.add_good_host(host.clone()).await; + let mut subscription = DataServerSubscription::new(app, host); subscription.load_cursor()?; tracing::debug!(seq = ?subscription.last_seq, "starting subscription"); @@ -456,8 +456,8 @@ async fn host_subscription(server: Arc, host: String) -> Result<()> Ok(()) } -pub async fn index_server(server: Arc, host: String) -> Result<()> { - if server.is_banned_host(&host)? { +pub async fn index_server(app: Arc, host: String) -> Result<()> { + if app.is_banned_host(&host)? { bail!("refusing to start indexer for banned host '{}'", &host); } @@ -465,7 +465,7 @@ pub async fn index_server(server: Arc, host: String) -> Result<()> // instead of just having a BTreeSet :) { - let mut active_indexers = server.active_indexers.lock().await; + let mut active_indexers = app.active_indexers.lock().await; if active_indexers.contains(&host) { bail!("Indexer already running for host"); } @@ -473,30 +473,30 @@ pub async fn index_server(server: Arc, host: String) -> Result<()> active_indexers.insert(host.clone()); } - let r = host_subscription(Arc::clone(&server), host.clone()).await; + let r = host_subscription(Arc::clone(&app), host.clone()).await; { - let mut active_indexers = server.active_indexers.lock().await; + let mut active_indexers = app.active_indexers.lock().await; active_indexers.remove(&host); } r } -pub fn start_indexing_server(server: Arc, host: String) { +pub fn start_indexing_server(app: Arc, host: String) { tokio::task::spawn(async move { - if let Err(e) = index_server(server, host.clone()).await { + if let Err(e) = index_server(app, host.clone()).await { tracing::warn!(%host, "encountered error subscribing to PDS: {e:?}"); } }); } -pub fn index_servers(server: Arc, hosts: &[String]) { +pub fn index_servers(app: Arc, hosts: &[String]) { // in future we will spider out but right now i just want da stuff from my PDS for host in hosts.iter() { let host = host.to_string(); - let server = Arc::clone(&server); + let server = Arc::clone(&app); start_indexing_server(server, host); } } diff --git a/src/relay/request_crawl.rs b/src/relay/request_crawl.rs index be894a4..1c4d44b 100644 --- a/src/relay/request_crawl.rs +++ b/src/relay/request_crawl.rs @@ -9,11 +9,11 @@ use hyper::{body::Incoming, Request, Response}; use crate::{ http::{body_full, ServerResponse}, relay::index::start_indexing_server, - RelayServer, + AppState, }; pub async fn handle_request_crawl( - server: Arc, + app: Arc, req: Request, ) -> Result { let body = req.collect().await?.aggregate(); @@ -31,7 +31,7 @@ pub async fn handle_request_crawl( }; let hostname = input.data.hostname; - start_indexing_server(server, hostname); + start_indexing_server(app, hostname); Ok(Response::builder() .status(200) diff --git a/src/relay/subscribe.rs b/src/relay/subscribe.rs index d9f0aef..0c60a08 100644 --- a/src/relay/subscribe.rs +++ b/src/relay/subscribe.rs @@ -1,4 +1,4 @@ -use crate::RelayServer; +use crate::AppState; use std::sync::Arc; @@ -120,7 +120,7 @@ async fn rebroadcast_block<'f>(block_rx: &mut Receiver) -> Operation<'f> } async fn run_subscription( - server: Arc, + app: Arc, req: Request, ws: WebSocket>, ) { @@ -139,7 +139,7 @@ async fn run_subscription( let cursor_bytes = cursor.to_be_bytes(); let mut count = 0; - for event in server.db_history.range(cursor_bytes..) { + for event in app.db_history.range(cursor_bytes..) { let (_seq, event) = match event { Ok(ev) => ev, Err(e) => { @@ -163,7 +163,7 @@ async fn run_subscription( } // live tailing: - let mut raw_block_rx = server.raw_block_tx.subscribe(); + let mut raw_block_rx = app.raw_block_tx.subscribe(); while sub.running { let op = tokio::select! { biased; @@ -177,7 +177,7 @@ async fn run_subscription( } pub async fn handle_subscription( - server: Arc, + app: Arc, mut req: Request, ) -> Result { if !is_upgrade_request(&req) { @@ -196,7 +196,7 @@ pub async fn handle_subscription( } }; - run_subscription(server, req, ws).await; + run_subscription(app, req, ws).await; }); let (head, _) = res.into_parts(); diff --git a/src/sequencer.rs b/src/sequencer.rs index cdf3d7b..8f00892 100644 --- a/src/sequencer.rs +++ b/src/sequencer.rs @@ -6,14 +6,14 @@ use tokio::sync::mpsc; use crate::{ wire_proto::{StreamEvent, StreamEventPayload}, - RelayServer, + AppState, }; async fn run_sequencer( - server: Arc, + app: Arc, mut event_rx: mpsc::Receiver, ) -> Result<()> { - let mut curr_seq = server + let mut curr_seq = app .db .get(b"history_last_seq")? .map(|v| { @@ -28,8 +28,7 @@ async fn run_sequencer( while let Some((header, payload)) = event_rx.recv().await { curr_seq += 1; - server - .db + app.db .insert(b"history_last_seq", &u128::to_le_bytes(curr_seq))?; /* if matches!( @@ -69,18 +68,16 @@ async fn run_sequencer( } let data = Bytes::from(cursor.into_inner()); - let _ = server.raw_block_tx.send(data.clone()); - server - .db_history - .insert(u128::to_be_bytes(curr_seq), &*data)?; + let _ = app.raw_block_tx.send(data.clone()); + app.db_history.insert(u128::to_be_bytes(curr_seq), &*data)?; } Ok(()) } -pub fn start_sequencer(server: Arc, event_rx: mpsc::Receiver) { +pub fn start_sequencer(app: Arc, event_rx: mpsc::Receiver) { tokio::task::spawn(async move { - if let Err(e) = run_sequencer(server, event_rx).await { + if let Err(e) = run_sequencer(app, event_rx).await { tracing::error!("sequencer error: {e:?}"); } }); diff --git a/src/user.rs b/src/user.rs index d877d98..17c17a0 100644 --- a/src/user.rs +++ b/src/user.rs @@ -11,7 +11,7 @@ use tokio::net::TcpStream; use crate::{ http::{body_empty, HttpBody}, tls::open_tls_stream, - RelayServer, + AppState, }; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -76,9 +76,9 @@ fn create_user_from_did_doc(did_doc: DidDocument) -> Result { Ok(user) } -pub async fn fetch_did_doc(server: &RelayServer, did: &str) -> Result { +pub async fn fetch_did_doc(app: &AppState, did: &str) -> Result { if did.starts_with("did:plc:") { - let domain: &str = &server.plc_resolver; + let domain: &str = &app.plc_resolver; let tcp_stream = TcpStream::connect((domain, 443)).await?; let domain_tls: ServerName<'_> = ServerName::try_from(domain.to_string())?; @@ -148,26 +148,26 @@ pub async fn fetch_did_doc(server: &RelayServer, did: &str) -> Result Result { +pub async fn fetch_user(app: &AppState, did: &str) -> Result { tracing::debug!(%did, "fetching user"); - let did_doc = fetch_did_doc(server, did).await?; + let did_doc = fetch_did_doc(app, did).await?; let user = create_user_from_did_doc(did_doc)?; - store_user(server, &user)?; + store_user(app, &user)?; Ok(user) } -pub async fn lookup_user(server: &RelayServer, did: &str) -> Result { - if let Some(cached_user) = server.db_users.get(did)? { +pub async fn lookup_user(app: &AppState, did: &str) -> Result { + if let Some(cached_user) = app.db_users.get(did)? { let cached_user = serde_ipld_dagcbor::from_slice::(&cached_user)?; return Ok(cached_user); } - fetch_user(server, did).await + fetch_user(app, did).await } -pub fn store_user(server: &RelayServer, user: &User) -> Result<()> { +pub fn store_user(app: &AppState, user: &User) -> Result<()> { let data = serde_ipld_dagcbor::to_vec(&user)?; - server.db_users.insert(&user.did, data)?; + app.db_users.insert(&user.did, data)?; Ok(()) }