relay-legacy/src/indexer.rs

191 lines
6.2 KiB
Rust

use std::{io::Cursor, sync::Arc};
use anyhow::{anyhow, bail, Context, Result};
use atrium_api::com::atproto::sync::subscribe_repos;
use bytes::Bytes;
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
use hyper::{header, upgrade::Upgraded, Request, Uri};
use hyper_util::rt::{TokioExecutor, TokioIo};
use ipld_core::ipld::Ipld;
use serde_ipld_dagcbor::DecodeError;
use tokio::{
net::TcpStream,
sync::{
broadcast::{self},
mpsc,
},
};
use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector};
use crate::{
http::body_empty,
tls::open_tls_stream,
user::lookup_user,
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
RelayServer,
};
/// Opens a WebSocket client connection to `domain:port` over TLS and upgrades on `path`.
///
/// The steps are: TCP connect, TLS setup through `open_tls_stream`, then the
/// WebSocket client handshake using a hand-built upgrade request. The resulting
/// socket has automatic pong and automatic close handling switched on and is
/// returned wrapped in a [`FragmentCollector`].
///
/// # Errors
///
/// Returns an error if the TCP connection fails, `domain` is not a valid TLS
/// server name, the TLS stream cannot be opened, the upgrade request cannot be
/// built, or the WebSocket handshake itself fails.
async fn create_ws_client(
    domain: &str,
    port: u16,
    path: &str,
) -> Result<FragmentCollector<TokioIo<Upgraded>>> {
    let tcp_stream = TcpStream::connect((domain, port)).await?;
    let domain_tls = ServerName::try_from(domain.to_string())?;
    let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?;

    let req = Request::builder()
        .method("GET")
        .uri(format!("wss://{}:{}{}", &domain, port, path))
        .header("Host", domain)
        .header(header::UPGRADE, "websocket")
        .header(header::CONNECTION, "upgrade")
        .header(
            "Sec-WebSocket-Key",
            fastwebsockets::handshake::generate_key(),
        )
        .header("Sec-WebSocket-Version", "13")
        .body(body_empty())?;

    // A failed handshake (remote refused the upgrade, connection dropped, ...) is an
    // ordinary runtime condition: report it to the caller instead of panicking the task.
    let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream)
        .await
        .with_context(|| format!("websocket handshake with {domain}:{port}{path} failed"))?;

    ws.set_auto_pong(true);
    ws.set_auto_close(true);
    Ok(FragmentCollector::new(ws))
}
/// State for a single subscription to one remote host's event stream.
struct DataServerSubscription {
/// Shared relay state; passed to `lookup_user` when validating commits.
server: Arc<RelayServer>,
/// Host this subscription connects to; commits are only accepted when the
/// committing user's PDS authority matches this host.
host: String,
/// Clone of the relay's `raw_block_tx` sender (not used by the code visible here).
raw_block_tx: broadcast::Sender<Bytes>,
/// Clone of the relay's `event_tx` sender; decoded events are forwarded through it.
event_tx: mpsc::Sender<StreamEvent>,
/// Initialised to `None` and never updated in the visible code.
/// NOTE(review): presumably the last sequence number seen, for resuming — confirm.
last_seq: Option<i128>,
}
impl DataServerSubscription {
fn new(server: Arc<RelayServer>, host: String) -> Self {
Self {
host,
raw_block_tx: server.raw_block_tx.clone(),
event_tx: server.event_tx.clone(),
last_seq: None,
server,
}
}
async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
// TODO: validate if this message is valid to come from this host
let buf: &[u8] = &frame;
let mut cursor = Cursor::new(buf);
let (header_buf, payload_buf) =
match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize),
_ => bail!("invalid frame type"),
};
let header = serde_ipld_dagcbor::from_slice::<StreamEventHeader>(header_buf)?;
let payload = match header.t.as_deref() {
Some("#commit") => {
let payload =
serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
let user = lookup_user(&self.server, &payload.repo).await?;
let Some(pds) = user.pds else {
bail!("user has no associated pds? {}", user.did);
};
let uri: Uri = pds.parse()?;
if uri.authority().map(|a| a.host()) != Some(&self.host) {
bail!(
"commit from non-authoritative pds (got {} expected {})",
self.host,
pds
);
}
StreamEventPayload::Commit(payload)
}
Some(t) => {
tracing::warn!("dropped unknown message type '{}'", t);
return Ok(());
}
None => {
return Ok(());
// skip ig
}
};
self.event_tx.send((header, payload)).await?;
Ok(())
}
}
pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result<()> {
tracing::debug!(%host, "establishing connection");
let mut subscription = DataServerSubscription::new(server, host);
'reconnect: loop {
let mut ws = create_ws_client(
&subscription.host,
443,
"/xrpc/com.atproto.sync.subscribeRepos",
)
.await?;
tracing::debug!(host = %subscription.host, "listening");
loop {
match ws.read_frame().await {
Ok(frame) if frame.opcode == OpCode::Binary => {
let bytes = match frame.payload {
Payload::BorrowedMut(slice) => Bytes::from(&*slice),
Payload::Borrowed(slice) => Bytes::from(slice),
Payload::Owned(vec) => Bytes::from(vec),
Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut),
};
if let Err(e) = subscription.handle_event(bytes).await {
tracing::error!("error handling event (skipping): {e:?}");
}
}
Ok(frame) => {
tracing::warn!("unexpected frame type {:?}", frame.opcode);
}
Err(e) => {
tracing::error!(host = %subscription.host, "{e:?}");
// TODO: should we try reconnect in every situation?
if let WebSocketError::UnexpectedEOF = e {
tracing::debug!(host = %subscription.host, "reconnecting");
// TODO: should we sleep at all here
continue 'reconnect;
} else {
break 'reconnect;
}
}
}
}
}
tracing::debug!(host = %subscription.host, "disconnected");
Ok(())
}
/// Starts one background subscription task for every entry in `hosts`.
///
/// Each task runs `subscribe_to_host` with its own handle to `server`; the task
/// handles are not kept, and a subscription that ends with an error is only logged.
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
    // in future we will spider out but right now i just want da stuff from my PDS
    hosts.iter().cloned().for_each(|pds_host| {
        let relay = Arc::clone(&server);
        tokio::task::spawn(async move {
            let outcome = subscribe_to_host(relay, pds_host).await;
            if let Err(e) = outcome {
                tracing::warn!("encountered error subscribing to PDS: {e:?}");
            }
        });
    });
}