// relay/src/indexer.rs — per-PDS firehose subscription and indexing.
use std::{io::Cursor, sync::Arc, time::Duration};
use anyhow::{bail, Context, Result};
use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
use bytes::{Buf, Bytes};
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
use http_body_util::BodyExt;
use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
use hyper_util::rt::{TokioExecutor, TokioIo};
use ipld_core::ipld::Ipld;
use serde_ipld_dagcbor::DecodeError;
use tokio::{net::TcpStream, sync::mpsc};
use tokio_rustls::rustls::pki_types::ServerName;
use crate::{
http::{body_empty, HttpBody},
tls::open_tls_stream,
user::{fetch_user, lookup_user},
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
RelayServer,
};
/// Open a TLS-wrapped websocket to `wss://{domain}:{port}{path}` and return a
/// fragment-collecting client (auto-pong and auto-close enabled).
async fn create_ws_client(
    domain: &str,
    port: u16,
    path: &str,
) -> Result<FragmentCollector<TokioIo<Upgraded>>> {
    // TCP connect, then wrap in TLS using the domain as the SNI server name.
    let tcp = TcpStream::connect((domain, port)).await?;
    let server_name = ServerName::try_from(domain.to_string())?;
    let tls = open_tls_stream(tcp, server_name).await?;

    // Standard HTTP/1.1 websocket upgrade request.
    let request = Request::builder()
        .method("GET")
        .uri(format!("wss://{}:{}{}", &domain, port, path))
        .header("Host", domain.to_string())
        .header(header::UPGRADE, "websocket")
        .header(header::CONNECTION, "upgrade")
        .header(
            "Sec-WebSocket-Key",
            fastwebsockets::handshake::generate_key(),
        )
        .header("Sec-WebSocket-Version", "13")
        .body(body_empty())?;

    let (mut websocket, _response) =
        fastwebsockets::handshake::client(&TokioExecutor::new(), request, tls).await?;
    websocket.set_auto_pong(true);
    websocket.set_auto_close(true);
    Ok(FragmentCollector::new(websocket))
}
/// State for one live `subscribeRepos` stream from a single data server (PDS).
struct DataServerSubscription {
    server: Arc<RelayServer>,
    // Hostname of the PDS this subscription is reading from.
    host: String,
    // Sender into the relay's outbound event channel (cloned from the server).
    event_tx: mpsc::Sender<StreamEvent>,
    // Highest sequence number seen so far; `None` until the first event.
    last_seq: Option<i64>,
}
impl DataServerSubscription {
    fn new(server: Arc<RelayServer>, host: String) -> Self {
        Self {
            host,
            event_tx: server.event_tx.clone(),
            last_seq: None,
            server,
        }
    }

    /// Reject events whose sequence number moves backwards relative to the
    /// last one we processed, then record `seq` as the newest sequence.
    ///
    /// Every event handler calls this first; previously this check was
    /// copy-pasted into all six handlers.
    fn track_seq(&mut self, seq: i64) -> Result<()> {
        let last_seq = self.last_seq.unwrap_or_default();
        if seq < last_seq {
            bail!(
                "got event out of order from stream (seq = {}, prev = {})",
                seq,
                last_seq
            )
        }
        self.last_seq = Some(seq);
        Ok(())
    }

    /// Handle a `#commit` event: verify the event came from the repo's
    /// authoritative PDS, and drop commits that are taken down or rebases.
    async fn handle_commit(
        &mut self,
        event: subscribe_repos::Commit,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        let mut user = lookup_user(&self.server, &event.repo).await?;
        let pds = user.pds.as_deref().unwrap_or_default();
        if pds != self.host {
            tracing::warn!(
                "received event from different pds than expected (got {} expected {})",
                self.host,
                pds
            );
            // re-fetch user (without cache) in case our cached PDS is stale
            user = fetch_user(&self.server, &event.repo).await?;
            let fresh_pds = user.pds.as_deref().unwrap_or_default();
            if fresh_pds != self.host {
                bail!(
                    "commit from non-authoritative pds (got {} expected {})",
                    self.host,
                    fresh_pds
                );
            }
        }
        // TODO: lookup did in takedown db tree
        let takedown = false;
        if takedown {
            tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event from taken-down user");
            return Ok(None);
        }
        if event.rebase {
            tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event with rebase flag");
            return Ok(None);
        }
        Ok(Some(StreamEventPayload::Commit(event)))
    }

    /// Handle a (deprecated) `#handle` event. The event is forwarded even if
    /// the asserted handle does not match what we resolve; we only warn.
    async fn handle_handle(
        &mut self,
        event: subscribe_repos::Handle,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        let user = fetch_user(&self.server, &event.did).await?;
        if user.handle.as_deref() != Some(event.handle.as_str()) {
            tracing::warn!(
                seq = %event.seq,
                expected = ?event.handle.as_str(),
                got = ?user.handle,
                "handle update did not update handle to asserted value"
            );
        }
        Ok(Some(StreamEventPayload::Handle(event)))
    }

    /// Handle an `#identity` event: if a handle is asserted, re-fetch the user
    /// and warn when resolution disagrees; forward the event regardless.
    async fn handle_identity(
        &mut self,
        event: subscribe_repos::Identity,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        if let Some(handle) = event.handle.as_ref() {
            let user = fetch_user(&self.server, &event.did).await?;
            if user.handle.as_deref() != Some(handle.as_str()) {
                tracing::warn!(
                    seq = %event.seq,
                    expected = ?handle.as_str(),
                    got = ?user.handle,
                    "identity update did not update handle to asserted value"
                );
            }
        }
        Ok(Some(StreamEventPayload::Identity(event)))
    }

    /// Handle an `#account` event; bail if the sending PDS is not the user's
    /// authoritative host. Takedown rewriting is still a TODO.
    async fn handle_account(
        &mut self,
        mut event: subscribe_repos::Account,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        let user = fetch_user(&self.server, &event.did).await?;
        let pds = user.pds.as_deref().unwrap_or_default();
        if pds != self.host {
            bail!(
                "account event from non-authoritative pds (got {} expected {})",
                pds,
                &self.host
            )
        }
        // TODO: handle takedowns
        let takedown = false;
        if takedown {
            event.status = Some("takendown".into());
            event.active = false;
        }
        // TODO: mark user status ?
        Ok(Some(StreamEventPayload::Account(event)))
    }

    /// Handle a (deprecated) `#migrate` event; forwarded as-is.
    async fn handle_migrate(
        &mut self,
        event: subscribe_repos::Migrate,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        // let _user = fetch_user(&self.server, &event.did).await?;
        Ok(Some(StreamEventPayload::Migrate(event)))
    }

    /// Handle a (deprecated) `#tombstone` event; bail if the sending PDS is
    /// not authoritative for the tombstoned DID.
    async fn handle_tombstone(
        &mut self,
        event: subscribe_repos::Tombstone,
    ) -> Result<Option<StreamEventPayload>> {
        self.track_seq(event.seq)?;
        let user = lookup_user(&self.server, &event.did).await?;
        let pds = user.pds.as_deref().unwrap_or_default();
        if pds != self.host {
            bail!(
                "unauthoritative tombstone event from {} for {}",
                &self.host,
                event.did.as_str()
            );
        }
        // TODO: mark user status as deleted ?
        Ok(Some(StreamEventPayload::Tombstone(event)))
    }

    /// Decode one binary websocket frame (DAG-CBOR header + payload) and
    /// dispatch it to the matching typed handler. Events the handlers accept
    /// are forwarded on `event_tx`.
    async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
        let buf: &[u8] = &frame;
        let mut cursor = Cursor::new(buf);
        // A valid frame is two concatenated CBOR values; parsing the first
        // must stop with TrailingData, and the cursor position marks the
        // header/payload boundary. Any other outcome is a malformed frame.
        let (header_buf, payload_buf) =
            match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
                Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize),
                _ => bail!("invalid frame type"),
            };
        let header = serde_ipld_dagcbor::from_slice::<StreamEventHeader>(header_buf)?;
        let payload = match header.t.as_deref() {
            Some("#commit") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
                self.handle_commit(payload).await?
            }
            Some("#handle") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Handle>(payload_buf)?;
                self.handle_handle(payload).await?
            }
            Some("#identity") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Identity>(payload_buf)?;
                self.handle_identity(payload).await?
            }
            Some("#account") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Account>(payload_buf)?;
                self.handle_account(payload).await?
            }
            Some("#migrate") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Migrate>(payload_buf)?;
                self.handle_migrate(payload).await?
            }
            Some("#tombstone") => {
                let payload =
                    serde_ipld_dagcbor::from_slice::<subscribe_repos::Tombstone>(payload_buf)?;
                self.handle_tombstone(payload).await?
            }
            Some("#info") => {
                let payload = serde_ipld_dagcbor::from_slice::<subscribe_repos::Info>(payload_buf)?;
                if payload.name == "OutdatedCursor" {
                    tracing::warn!(message = ?payload.message, "outdated cursor");
                }
                None
            }
            Some(t) => {
                tracing::warn!("dropped unknown message type '{}'", t);
                None
            }
            None => None,
        };
        if let Some(payload) = payload {
            self.event_tx.send((header, payload)).await?;
        }
        Ok(())
    }
}
async fn get_repo_count(host: &str) -> Result<usize> {
let tcp_stream = TcpStream::connect((host, 443)).await?;
let host_tls = ServerName::try_from(host.to_string())?;
let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
let io = TokioIo::new(tls_stream);
let req = Request::builder()
.method("GET")
.uri(format!(
"https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
))
.header("Host", host.to_string())
.body(body_empty())?;
let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});
let res = sender
.send_request(req)
.await
.context("Failed to send repo count request")?;
if res.status() != StatusCode::OK {
bail!("server returned non-200 status for listRepos");
}
let body = res.collect().await?.aggregate();
let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
.context("Failed to parse listRepos response as JSON")?;
Ok(output.repos.len())
}
#[tracing::instrument(skip_all, fields(host = %host))]
// Subscribe to one PDS's `subscribeRepos` firehose and pump its events into
// the relay, reconnecting on close/EOF. Returns when the stream ends with a
// non-recoverable websocket error.
async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
    tracing::debug!("establishing connection");
    // listRepos is queried with limit=1000, so a full page means "1000 or
    // more" repos — such hosts are rejected outright.
    if get_repo_count(&host).await? >= 1000 {
        bail!("too many repos! ditching from cerulea relay")
    }
    let _ = server.add_good_host(host.clone()).await;
    let mut subscription = DataServerSubscription::new(server, host);
    // TODO: load seq from db ?
    'reconnect: loop {
        // (Re)connect, resuming from the last seen sequence via ?cursor=
        // when we have one (empty query string on first connect).
        let mut ws = create_ws_client(
            &subscription.host,
            443,
            &format!(
                "/xrpc/com.atproto.sync.subscribeRepos{}",
                subscription
                    .last_seq
                    .map(|c| format!("?cursor={c}"))
                    .unwrap_or_default()
            ),
        )
        .await?;
        tracing::debug!(seq = ?subscription.last_seq, "listening");
        loop {
            match ws.read_frame().await {
                Ok(frame) if frame.opcode == OpCode::Binary => {
                    // Normalize every fastwebsockets payload variant into
                    // owned Bytes before handing off to the decoder.
                    // NOTE(review): Bytes::from on the borrowed variants
                    // appears to rely on slice lifetimes working out here —
                    // confirm this doesn't silently require 'static data.
                    let bytes = match frame.payload {
                        Payload::BorrowedMut(slice) => Bytes::from(&*slice),
                        Payload::Borrowed(slice) => Bytes::from(slice),
                        Payload::Owned(vec) => Bytes::from(vec),
                        Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut),
                    };
                    // A bad frame is logged and skipped; it does not kill the
                    // subscription.
                    if let Err(e) = subscription.handle_event(bytes).await {
                        tracing::error!("error handling event (skipping): {e:?}");
                    }
                }
                Ok(frame) if frame.opcode == OpCode::Close => {
                    tracing::debug!("got close frame. reconnecting in 10s");
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    continue 'reconnect;
                }
                Ok(frame) => {
                    // Text/ping/etc. frames are unexpected on this stream.
                    tracing::warn!("unexpected frame type {:?}", frame.opcode);
                }
                Err(e) => {
                    tracing::error!("{e:?}");
                    // TODO: should we try to reconnect in every situation?
                    if let WebSocketError::UnexpectedEOF = e {
                        tracing::debug!("got unexpected EOF. reconnecting immediately");
                        continue 'reconnect;
                    } else {
                        break 'reconnect;
                    }
                }
            }
        }
    }
    tracing::debug!("disconnected");
    Ok(())
}
/// Run the subscription for `host`, registering it in the server's set of
/// active indexers for the duration. Errors if an indexer is already running
/// for this host.
pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
    // Register, refusing duplicates. The lock is scoped so it is released
    // before the long-running subscription starts.
    {
        let mut indexers = server.active_indexers.lock().await;
        if indexers.contains(&host) {
            bail!("Indexer already running for host {}", &host);
        }
        indexers.insert(host.clone());
    }
    let result = host_subscription(Arc::clone(&server), host.clone()).await;
    // Deregister regardless of how the subscription ended.
    server.active_indexers.lock().await.remove(&host);
    result
}
/// Spawn one background indexing task per host; failures are logged, not
/// propagated.
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
    // in future we will spider out but right now i just want da stuff from my PDS
    for host in hosts {
        let host = host.clone();
        let server = Arc::clone(&server);
        tokio::task::spawn(async move {
            if let Err(e) = index_server(server, host.clone()).await {
                tracing::warn!(%host, "encountered error subscribing to PDS: {e:?}");
            }
        });
    }
}