463 lines
15 KiB
Rust
463 lines
15 KiB
Rust
use std::{io::Cursor, sync::Arc, time::Duration};
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
|
|
use bytes::{Buf, Bytes};
|
|
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
|
|
use http_body_util::BodyExt;
|
|
use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
|
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
|
use ipld_core::ipld::Ipld;
|
|
use serde_ipld_dagcbor::DecodeError;
|
|
use tokio::{net::TcpStream, sync::mpsc};
|
|
use tokio_rustls::rustls::pki_types::ServerName;
|
|
|
|
use crate::{
|
|
http::{body_empty, HttpBody},
|
|
tls::open_tls_stream,
|
|
user::{fetch_user, lookup_user},
|
|
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
|
|
RelayServer,
|
|
};
|
|
|
|
async fn create_ws_client(
|
|
domain: &str,
|
|
port: u16,
|
|
path: &str,
|
|
) -> Result<FragmentCollector<TokioIo<Upgraded>>> {
|
|
let tcp_stream = TcpStream::connect((domain, port)).await?;
|
|
|
|
let domain_tls = ServerName::try_from(domain.to_string())?;
|
|
let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?;
|
|
|
|
let req = Request::builder()
|
|
.method("GET")
|
|
.uri(format!("wss://{}:{}{}", &domain, port, path))
|
|
.header("Host", domain.to_string())
|
|
.header(header::UPGRADE, "websocket")
|
|
.header(header::CONNECTION, "upgrade")
|
|
.header(
|
|
"Sec-WebSocket-Key",
|
|
fastwebsockets::handshake::generate_key(),
|
|
)
|
|
.header("Sec-WebSocket-Version", "13")
|
|
.body(body_empty())?;
|
|
|
|
let (mut ws, _) =
|
|
fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream).await?;
|
|
ws.set_auto_pong(true);
|
|
ws.set_auto_close(true);
|
|
|
|
Ok(FragmentCollector::new(ws))
|
|
}
|
|
|
|
/// Per-host state for one data-server (PDS) firehose subscription.
struct DataServerSubscription {
    /// Shared relay state (user lookups, host registry, event channel).
    server: Arc<RelayServer>,
    /// Hostname of the PDS this subscription reads events from.
    host: String,
    /// Outbound channel on which validated events are forwarded.
    event_tx: mpsc::Sender<StreamEvent>,
    /// Highest sequence number seen so far; `None` until the first event.
    /// Also used as the `cursor` parameter when reconnecting.
    last_seq: Option<i64>,
}
|
|
|
|
impl DataServerSubscription {
|
|
fn new(server: Arc<RelayServer>, host: String) -> Self {
|
|
Self {
|
|
host,
|
|
event_tx: server.event_tx.clone(),
|
|
last_seq: None,
|
|
server,
|
|
}
|
|
}
|
|
|
|
async fn handle_commit(
|
|
&mut self,
|
|
event: subscribe_repos::Commit,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
let mut user = lookup_user(&self.server, &event.repo).await?;
|
|
|
|
let pds = user.pds.as_deref().unwrap_or_default();
|
|
if pds != self.host {
|
|
tracing::warn!(
|
|
"received event from different pds than expected (got {} expected {})",
|
|
self.host,
|
|
pds
|
|
);
|
|
|
|
// re-fetch user (without cache)
|
|
user = fetch_user(&self.server, &event.repo).await?;
|
|
let fresh_pds = user.pds.as_deref().unwrap_or_default();
|
|
if fresh_pds != self.host {
|
|
bail!(
|
|
"commit from non-authoritative pds (got {} expected {})",
|
|
self.host,
|
|
fresh_pds
|
|
);
|
|
}
|
|
}
|
|
|
|
// TODO: lookup did in takedown db tree
|
|
let takedown = false;
|
|
if takedown {
|
|
tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event from taken-down user");
|
|
return Ok(None);
|
|
}
|
|
|
|
if event.rebase {
|
|
tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event with rebase flag");
|
|
return Ok(None);
|
|
}
|
|
|
|
Ok(Some(StreamEventPayload::Commit(event)))
|
|
}
|
|
|
|
async fn handle_handle(
|
|
&mut self,
|
|
event: subscribe_repos::Handle,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
let user = fetch_user(&self.server, &event.did).await?;
|
|
if user.handle.as_deref() != Some(event.handle.as_str()) {
|
|
tracing::warn!(
|
|
seq = %event.seq,
|
|
expected = ?event.handle.as_str(),
|
|
got = ?user.handle,
|
|
"handle update did not update handle to asserted value"
|
|
);
|
|
}
|
|
|
|
Ok(Some(StreamEventPayload::Handle(event)))
|
|
}
|
|
|
|
async fn handle_identity(
|
|
&mut self,
|
|
event: subscribe_repos::Identity,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
if let Some(handle) = event.handle.as_ref() {
|
|
let user = fetch_user(&self.server, &event.did).await?;
|
|
if user.handle.as_deref() != Some(handle.as_str()) {
|
|
tracing::warn!(
|
|
seq = %event.seq,
|
|
expected = ?handle.as_str(),
|
|
got = ?user.handle,
|
|
"identity update did not update handle to asserted value"
|
|
);
|
|
}
|
|
}
|
|
|
|
Ok(Some(StreamEventPayload::Identity(event)))
|
|
}
|
|
|
|
async fn handle_account(
|
|
&mut self,
|
|
mut event: subscribe_repos::Account,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
let user = fetch_user(&self.server, &event.did).await?;
|
|
let pds = user.pds.as_deref().unwrap_or_default();
|
|
if pds != self.host {
|
|
bail!(
|
|
"account event from non-authoritative pds (got {} expected {})",
|
|
pds,
|
|
&self.host
|
|
)
|
|
}
|
|
|
|
// TODO: handle takedowns
|
|
let takedown = false;
|
|
if takedown {
|
|
event.status = Some("takendown".into());
|
|
event.active = false;
|
|
}
|
|
|
|
// TODO: mark user status ?
|
|
|
|
Ok(Some(StreamEventPayload::Account(event)))
|
|
}
|
|
|
|
async fn handle_migrate(
|
|
&mut self,
|
|
event: subscribe_repos::Migrate,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
// let _user = fetch_user(&self.server, &event.did).await?;
|
|
|
|
Ok(Some(StreamEventPayload::Migrate(event)))
|
|
}
|
|
|
|
async fn handle_tombstone(
|
|
&mut self,
|
|
event: subscribe_repos::Tombstone,
|
|
) -> Result<Option<StreamEventPayload>> {
|
|
let last_seq = self.last_seq.unwrap_or_default();
|
|
if event.seq < last_seq {
|
|
bail!(
|
|
"got event out of order from stream (seq = {}, prev = {})",
|
|
event.seq,
|
|
last_seq
|
|
)
|
|
}
|
|
self.last_seq = Some(event.seq);
|
|
|
|
let user = lookup_user(&self.server, &event.did).await?;
|
|
let pds = user.pds.as_deref().unwrap_or_default();
|
|
if pds != self.host {
|
|
bail!(
|
|
"unauthoritative tombstone event from {} for {}",
|
|
&self.host,
|
|
event.did.as_str()
|
|
);
|
|
}
|
|
|
|
// TODO: mark user status as deleted ?
|
|
|
|
Ok(Some(StreamEventPayload::Tombstone(event)))
|
|
}
|
|
|
|
async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
|
|
let buf: &[u8] = &frame;
|
|
let mut cursor = Cursor::new(buf);
|
|
let (header_buf, payload_buf) =
|
|
match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
|
|
Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize),
|
|
_ => bail!("invalid frame type"),
|
|
};
|
|
let header = serde_ipld_dagcbor::from_slice::<StreamEventHeader>(header_buf)?;
|
|
|
|
let payload = match header.t.as_deref() {
|
|
Some("#commit") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
|
|
self.handle_commit(payload).await?
|
|
}
|
|
Some("#handle") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Handle>(payload_buf)?;
|
|
self.handle_handle(payload).await?
|
|
}
|
|
Some("#identity") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Identity>(payload_buf)?;
|
|
self.handle_identity(payload).await?
|
|
}
|
|
Some("#account") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Account>(payload_buf)?;
|
|
self.handle_account(payload).await?
|
|
}
|
|
Some("#migrate") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Migrate>(payload_buf)?;
|
|
self.handle_migrate(payload).await?
|
|
}
|
|
Some("#tombstone") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Tombstone>(payload_buf)?;
|
|
self.handle_tombstone(payload).await?
|
|
}
|
|
Some("#info") => {
|
|
let payload = serde_ipld_dagcbor::from_slice::<subscribe_repos::Info>(payload_buf)?;
|
|
if payload.name == "OutdatedCursor" {
|
|
tracing::warn!(message = ?payload.message, "outdated cursor");
|
|
}
|
|
|
|
None
|
|
}
|
|
Some(t) => {
|
|
tracing::warn!("dropped unknown message type '{}'", t);
|
|
None
|
|
}
|
|
None => None,
|
|
};
|
|
|
|
if let Some(payload) = payload {
|
|
self.event_tx.send((header, payload)).await?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
async fn get_repo_count(host: &str) -> Result<usize> {
|
|
let tcp_stream = TcpStream::connect((host, 443)).await?;
|
|
let host_tls = ServerName::try_from(host.to_string())?;
|
|
let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
|
|
let io = TokioIo::new(tls_stream);
|
|
|
|
let req = Request::builder()
|
|
.method("GET")
|
|
.uri(format!(
|
|
"https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
|
|
))
|
|
.header("Host", host.to_string())
|
|
.body(body_empty())?;
|
|
|
|
let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
|
|
tokio::task::spawn(async move {
|
|
if let Err(err) = conn.await {
|
|
println!("Connection failed: {:?}", err);
|
|
}
|
|
});
|
|
|
|
let res = sender
|
|
.send_request(req)
|
|
.await
|
|
.context("Failed to send repo count request")?;
|
|
if res.status() != StatusCode::OK {
|
|
bail!("server returned non-200 status for listRepos");
|
|
}
|
|
|
|
let body = res.collect().await?.aggregate();
|
|
let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
|
|
.context("Failed to parse listRepos response as JSON")?;
|
|
|
|
Ok(output.repos.len())
|
|
}
|
|
|
|
/// Subscribes to a PDS firehose and pumps its events into the relay,
/// reconnecting (with the last seen cursor) on close frames and unexpected
/// EOFs.
///
/// Hosts reporting >= 1000 repos (one full listRepos page) are refused.
/// Per-event handling errors are logged and skipped; connection-setup and
/// non-EOF websocket errors terminate the subscription.
#[tracing::instrument(skip_all, fields(host = %host))]
async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
    tracing::debug!("establishing connection");

    // get_repo_count fetches at most one 1000-entry page, so >= 1000 means
    // "at least a full page" — treated as too large to index.
    if get_repo_count(&host).await? >= 1000 {
        bail!("too many repos! ditching from cerulea relay")
    }

    let _ = server.add_good_host(host.clone()).await;
    let mut subscription = DataServerSubscription::new(server, host);

    // TODO: load seq from db ?

    'reconnect: loop {
        // Resume from the last seen sequence number, if any, via the
        // `cursor` query parameter.
        let mut ws = create_ws_client(
            &subscription.host,
            443,
            &format!(
                "/xrpc/com.atproto.sync.subscribeRepos{}",
                subscription
                    .last_seq
                    .map(|c| format!("?cursor={c}"))
                    .unwrap_or_default()
            ),
        )
        .await?;
        tracing::debug!(seq = ?subscription.last_seq, "listening");

        loop {
            match ws.read_frame().await {
                Ok(frame) if frame.opcode == OpCode::Binary => {
                    // Normalize every Payload variant into owned Bytes so
                    // handle_event is independent of the socket's buffers.
                    let bytes = match frame.payload {
                        Payload::BorrowedMut(slice) => Bytes::from(&*slice),
                        Payload::Borrowed(slice) => Bytes::from(slice),
                        Payload::Owned(vec) => Bytes::from(vec),
                        Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut),
                    };

                    // A bad event is logged and skipped rather than killing
                    // the whole subscription.
                    if let Err(e) = subscription.handle_event(bytes).await {
                        tracing::error!("error handling event (skipping): {e:?}");
                    }
                }
                Ok(frame) if frame.opcode == OpCode::Close => {
                    // Server-initiated close: back off briefly, then dial again.
                    tracing::debug!("got close frame. reconnecting in 10s");
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    continue 'reconnect;
                }
                Ok(frame) => {
                    tracing::warn!("unexpected frame type {:?}", frame.opcode);
                }
                Err(e) => {
                    tracing::error!("{e:?}");
                    // TODO: should we try to reconnect in every situation?
                    if let WebSocketError::UnexpectedEOF = e {
                        tracing::debug!("got unexpected EOF. reconnecting immediately");
                        continue 'reconnect;
                    } else {
                        break 'reconnect;
                    }
                }
            }
        }
    }

    tracing::debug!("disconnected");

    Ok(())
}
|
|
|
|
pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
|
|
{
|
|
let mut active_indexers = server.active_indexers.lock().await;
|
|
if active_indexers.contains(&host) {
|
|
bail!("Indexer already running for host {}", &host);
|
|
}
|
|
|
|
active_indexers.insert(host.clone());
|
|
}
|
|
|
|
let r = host_subscription(Arc::clone(&server), host.clone()).await;
|
|
|
|
{
|
|
let mut active_indexers = server.active_indexers.lock().await;
|
|
active_indexers.remove(&host);
|
|
}
|
|
|
|
r
|
|
}
|
|
|
|
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
|
|
// in future we will spider out but right now i just want da stuff from my PDS
|
|
|
|
for host in hosts.iter() {
|
|
let host = host.to_string();
|
|
let server = Arc::clone(&server);
|
|
tokio::task::spawn(async move {
|
|
if let Err(e) = index_server(server, host.clone()).await {
|
|
tracing::warn!(%host, "encountered error subscribing to PDS: {e:?}");
|
|
}
|
|
});
|
|
}
|
|
}
|