use std::{io::Cursor, sync::Arc}; use anyhow::{anyhow, bail, Context, Result}; use atrium_api::com::atproto::sync::subscribe_repos; use bytes::Bytes; use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; use hyper::{header, upgrade::Upgraded, Request, Uri}; use hyper_util::rt::{TokioExecutor, TokioIo}; use ipld_core::ipld::Ipld; use serde_ipld_dagcbor::DecodeError; use tokio::{ net::TcpStream, sync::{ broadcast::{self}, mpsc, }, }; use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector}; use crate::{ http::body_empty, tls::open_tls_stream, user::lookup_user, wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, RelayServer, }; async fn create_ws_client( domain: &str, port: u16, path: &str, ) -> Result>> { let tcp_stream = TcpStream::connect((domain, port)).await?; let domain_tls = ServerName::try_from(domain.to_string())?; let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?; let req = Request::builder() .method("GET") .uri(format!("wss://{}:{}{}", &domain, port, path)) .header("Host", domain.to_string()) .header(header::UPGRADE, "websocket") .header(header::CONNECTION, "upgrade") .header( "Sec-WebSocket-Key", fastwebsockets::handshake::generate_key(), ) .header("Sec-WebSocket-Version", "13") .body(body_empty())?; let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream) .await .unwrap(); ws.set_auto_pong(true); ws.set_auto_close(true); Ok(FragmentCollector::new(ws)) } struct DataServerSubscription { server: Arc, host: String, raw_block_tx: broadcast::Sender, event_tx: mpsc::Sender, last_seq: Option, } impl DataServerSubscription { fn new(server: Arc, host: String) -> Self { Self { host, raw_block_tx: server.raw_block_tx.clone(), event_tx: server.event_tx.clone(), last_seq: None, server, } } async fn handle_event(&mut self, frame: Bytes) -> Result<()> { // TODO: validate if this message is valid to come from this host let buf: &[u8] = &frame; let mut cursor = Cursor::new(buf); let (header_buf, payload_buf) = match serde_ipld_dagcbor::from_reader::(&mut cursor) { Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize), _ => bail!("invalid frame type"), }; let header = serde_ipld_dagcbor::from_slice::(header_buf)?; let payload = match header.t.as_deref() { Some("#commit") => { let payload = serde_ipld_dagcbor::from_slice::(payload_buf)?; let user = lookup_user(&self.server, &payload.repo).await?; let Some(pds) = user.pds else { bail!("user has no associated pds? {}", user.did); }; let uri: Uri = pds.parse()?; if uri.authority().map(|a| a.host()) != Some(&self.host) { bail!( "commit from non-authoritative pds (got {} expected {})", self.host, pds ); } StreamEventPayload::Commit(payload) } Some(t) => { tracing::warn!("dropped unknown message type '{}'", t); return Ok(()); } None => { return Ok(()); // skip ig } }; self.event_tx.send((header, payload)).await?; Ok(()) } } pub async fn subscribe_to_host(server: Arc, host: String) -> Result<()> { tracing::debug!(%host, "establishing connection"); let mut subscription = DataServerSubscription::new(server, host); 'reconnect: loop { let mut ws = create_ws_client( &subscription.host, 443, "/xrpc/com.atproto.sync.subscribeRepos", ) .await?; tracing::debug!(host = %subscription.host, "listening"); loop { match ws.read_frame().await { Ok(frame) if frame.opcode == OpCode::Binary => { let bytes = match frame.payload { Payload::BorrowedMut(slice) => Bytes::from(&*slice), Payload::Borrowed(slice) => Bytes::from(slice), Payload::Owned(vec) => Bytes::from(vec), Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut), }; if let Err(e) = subscription.handle_event(bytes).await { tracing::error!("error handling event (skipping): {e:?}"); } } Ok(frame) => { tracing::warn!("unexpected frame type {:?}", frame.opcode); } Err(e) => { tracing::error!(host = %subscription.host, "{e:?}"); // TODO: should we try reconnect in every situation? if let WebSocketError::UnexpectedEOF = e { tracing::debug!(host = %subscription.host, "reconnecting"); // TODO: should we sleep at all here continue 'reconnect; } else { break 'reconnect; } } } } } tracing::debug!(host = %subscription.host, "disconnected"); Ok(()) } pub fn index_servers(server: Arc, hosts: &[String]) { // in future we will spider out but right now i just want da stuff from my PDS for host in hosts.iter() { let host = host.to_string(); let server = Arc::clone(&server); tokio::task::spawn(async move { if let Err(e) = subscribe_to_host(server, host).await { tracing::warn!("encountered error subscribing to PDS: {e:?}"); } }); } }