diff --git a/src/indexer.rs b/src/indexer.rs index dd5ac41..6c21628 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -1,6 +1,6 @@ -use std::{io::Cursor, sync::Arc}; +use std::{io::Cursor, sync::Arc, time::Duration}; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{bail, Result}; use atrium_api::com::atproto::sync::subscribe_repos; use bytes::Bytes; use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; @@ -8,14 +8,8 @@ use hyper::{header, upgrade::Upgraded, Request, Uri}; use hyper_util::rt::{TokioExecutor, TokioIo}; use ipld_core::ipld::Ipld; use serde_ipld_dagcbor::DecodeError; -use tokio::{ - net::TcpStream, - sync::{ - broadcast::{self}, - mpsc, - }, -}; -use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector}; +use tokio::{net::TcpStream, sync::mpsc}; +use tokio_rustls::rustls::pki_types::ServerName; use crate::{ http::body_empty, @@ -60,22 +54,46 @@ async fn create_ws_client( struct DataServerSubscription { server: Arc, host: String, - raw_block_tx: broadcast::Sender, event_tx: mpsc::Sender, - last_seq: Option, + last_seq: Option, } impl DataServerSubscription { fn new(server: Arc, host: String) -> Self { Self { host, - raw_block_tx: server.raw_block_tx.clone(), event_tx: server.event_tx.clone(), last_seq: None, server, } } + async fn handle_commit( + &mut self, + payload: subscribe_repos::Commit, + ) -> Result> { + let user = lookup_user(&self.server, &payload.repo).await?; + let Some(pds) = user.pds else { + bail!("user has no associated pds? {:?}", user); + }; + let uri: Uri = pds.parse()?; + if uri.authority().map(|a| a.host()) != Some(&self.host) { + bail!( + "commit from non-authoritative pds (got {} expected {})", + self.host, + pds + ); + } + + if user.takedown { + tracing::debug!(did = %user.did, seq = %payload.seq, "dropping commit event from taken-down user"); + return Ok(None); + } + + self.last_seq = Some(payload.seq); + Ok(Some(StreamEventPayload::Commit(payload))) + } + async fn handle_event(&mut self, frame: Bytes) -> Result<()> { // TODO: validate if this message is valid to come from this host @@ -92,52 +110,60 @@ impl DataServerSubscription { Some("#commit") => { let payload = serde_ipld_dagcbor::from_slice::(payload_buf)?; + self.handle_commit(payload).await? + } - let user = lookup_user(&self.server, &payload.repo).await?; - let Some(pds) = user.pds else { - bail!("user has no associated pds? {}", user.did); - }; - let uri: Uri = pds.parse()?; - if uri.authority().map(|a| a.host()) != Some(&self.host) { - bail!( - "commit from non-authoritative pds (got {} expected {})", - self.host, - pds - ); + Some("#handle") => { + // TODO + None + } + + Some("#info") => { + let payload = serde_ipld_dagcbor::from_slice::(payload_buf)?; + if payload.name == "OutdatedCursor" { + tracing::warn!(message = ?payload.message, "outdated cursor"); } - StreamEventPayload::Commit(payload) + None } + Some(t) => { tracing::warn!("dropped unknown message type '{}'", t); - return Ok(()); - } - None => { - return Ok(()); - // skip ig + None } + None => None, }; - self.event_tx.send((header, payload)).await?; + if let Some(payload) = payload { + self.event_tx.send((header, payload)).await?; + } Ok(()) } } -pub async fn subscribe_to_host(server: Arc, host: String) -> Result<()> { - tracing::debug!(%host, "establishing connection"); +#[tracing::instrument(skip_all, fields(host = %host))] +async fn host_subscription(server: Arc, host: String) -> Result<()> { + tracing::debug!("establishing connection"); let mut subscription = DataServerSubscription::new(server, host); + // TODO: load seq from db ? + 'reconnect: loop { let mut ws = create_ws_client( &subscription.host, 443, - "/xrpc/com.atproto.sync.subscribeRepos", + &format!( + "/xrpc/com.atproto.sync.subscribeRepos{}", + subscription + .last_seq + .map(|c| format!("?cursor={c}")) + .unwrap_or_default() + ), ) .await?; - - tracing::debug!(host = %subscription.host, "listening"); + tracing::debug!(seq = ?subscription.last_seq, "listening"); loop { match ws.read_frame().await { @@ -153,15 +179,19 @@ pub async fn subscribe_to_host(server: Arc, host: String) -> Result tracing::error!("error handling event (skipping): {e:?}"); } } + Ok(frame) if frame.opcode == OpCode::Close => { + tracing::debug!("got close frame. reconnecting in 10s"); + tokio::time::sleep(Duration::from_secs(10)).await; + continue 'reconnect; + } Ok(frame) => { tracing::warn!("unexpected frame type {:?}", frame.opcode); } Err(e) => { - tracing::error!(host = %subscription.host, "{e:?}"); - // TODO: should we try reconnect in every situation? + tracing::error!("{e:?}"); + // TODO: should we try to reconnect in every situation? if let WebSocketError::UnexpectedEOF = e { - tracing::debug!(host = %subscription.host, "reconnecting"); - // TODO: should we sleep at all here + tracing::debug!("got unexpected EOF. reconnecting immediately"); continue 'reconnect; } else { break 'reconnect; @@ -171,7 +201,7 @@ pub async fn subscribe_to_host(server: Arc, host: String) -> Result } } - tracing::debug!(host = %subscription.host, "disconnected"); + tracing::debug!("disconnected"); Ok(()) } @@ -183,7 +213,7 @@ pub fn index_servers(server: Arc, hosts: &[String]) { let host = host.to_string(); let server = Arc::clone(&server); tokio::task::spawn(async move { - if let Err(e) = subscribe_to_host(server, host).await { + if let Err(e) = host_subscription(server, host).await { tracing::warn!("encountered error subscribing to PDS: {e:?}"); } }); diff --git a/src/user.rs b/src/user.rs index 4eea0c7..4aa455d 100644 --- a/src/user.rs +++ b/src/user.rs @@ -14,7 +14,7 @@ use crate::{ RelayServer, }; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct User { pub did: String, pub pds: Option,