191 lines
6.2 KiB
Rust
191 lines
6.2 KiB
Rust
use std::{io::Cursor, sync::Arc};
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
use atrium_api::com::atproto::sync::subscribe_repos;
|
|
use bytes::Bytes;
|
|
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
|
|
use hyper::{header, upgrade::Upgraded, Request, Uri};
|
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
|
use ipld_core::ipld::Ipld;
|
|
use serde_ipld_dagcbor::DecodeError;
|
|
use tokio::{
|
|
net::TcpStream,
|
|
sync::{
|
|
broadcast::{self},
|
|
mpsc,
|
|
},
|
|
};
|
|
use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector};
|
|
|
|
use crate::{
|
|
http::body_empty,
|
|
tls::open_tls_stream,
|
|
user::lookup_user,
|
|
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
|
|
RelayServer,
|
|
};
|
|
|
|
async fn create_ws_client(
|
|
domain: &str,
|
|
port: u16,
|
|
path: &str,
|
|
) -> Result<FragmentCollector<TokioIo<Upgraded>>> {
|
|
let tcp_stream = TcpStream::connect((domain, port)).await?;
|
|
|
|
let domain_tls = ServerName::try_from(domain.to_string())?;
|
|
let tls_stream = open_tls_stream(tcp_stream, domain_tls).await?;
|
|
|
|
let req = Request::builder()
|
|
.method("GET")
|
|
.uri(format!("wss://{}:{}{}", &domain, port, path))
|
|
.header("Host", domain.to_string())
|
|
.header(header::UPGRADE, "websocket")
|
|
.header(header::CONNECTION, "upgrade")
|
|
.header(
|
|
"Sec-WebSocket-Key",
|
|
fastwebsockets::handshake::generate_key(),
|
|
)
|
|
.header("Sec-WebSocket-Version", "13")
|
|
.body(body_empty())?;
|
|
|
|
let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream)
|
|
.await
|
|
.unwrap();
|
|
ws.set_auto_pong(true);
|
|
ws.set_auto_close(true);
|
|
|
|
Ok(FragmentCollector::new(ws))
|
|
}
|
|
|
|
struct DataServerSubscription {
|
|
server: Arc<RelayServer>,
|
|
host: String,
|
|
raw_block_tx: broadcast::Sender<Bytes>,
|
|
event_tx: mpsc::Sender<StreamEvent>,
|
|
last_seq: Option<i128>,
|
|
}
|
|
|
|
impl DataServerSubscription {
|
|
fn new(server: Arc<RelayServer>, host: String) -> Self {
|
|
Self {
|
|
host,
|
|
raw_block_tx: server.raw_block_tx.clone(),
|
|
event_tx: server.event_tx.clone(),
|
|
last_seq: None,
|
|
server,
|
|
}
|
|
}
|
|
|
|
async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
|
|
// TODO: validate if this message is valid to come from this host
|
|
|
|
let buf: &[u8] = &frame;
|
|
let mut cursor = Cursor::new(buf);
|
|
let (header_buf, payload_buf) =
|
|
match serde_ipld_dagcbor::from_reader::<Ipld, _>(&mut cursor) {
|
|
Err(DecodeError::TrailingData) => buf.split_at(cursor.position() as usize),
|
|
_ => bail!("invalid frame type"),
|
|
};
|
|
let header = serde_ipld_dagcbor::from_slice::<StreamEventHeader>(header_buf)?;
|
|
|
|
let payload = match header.t.as_deref() {
|
|
Some("#commit") => {
|
|
let payload =
|
|
serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
|
|
|
|
let user = lookup_user(&self.server, &payload.repo).await?;
|
|
let Some(pds) = user.pds else {
|
|
bail!("user has no associated pds? {}", user.did);
|
|
};
|
|
let uri: Uri = pds.parse()?;
|
|
if uri.authority().map(|a| a.host()) != Some(&self.host) {
|
|
bail!(
|
|
"commit from non-authoritative pds (got {} expected {})",
|
|
self.host,
|
|
pds
|
|
);
|
|
}
|
|
|
|
StreamEventPayload::Commit(payload)
|
|
}
|
|
Some(t) => {
|
|
tracing::warn!("dropped unknown message type '{}'", t);
|
|
return Ok(());
|
|
}
|
|
None => {
|
|
return Ok(());
|
|
// skip ig
|
|
}
|
|
};
|
|
|
|
self.event_tx.send((header, payload)).await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result<()> {
|
|
tracing::debug!(%host, "establishing connection");
|
|
|
|
let mut subscription = DataServerSubscription::new(server, host);
|
|
|
|
'reconnect: loop {
|
|
let mut ws = create_ws_client(
|
|
&subscription.host,
|
|
443,
|
|
"/xrpc/com.atproto.sync.subscribeRepos",
|
|
)
|
|
.await?;
|
|
|
|
tracing::debug!(host = %subscription.host, "listening");
|
|
|
|
loop {
|
|
match ws.read_frame().await {
|
|
Ok(frame) if frame.opcode == OpCode::Binary => {
|
|
let bytes = match frame.payload {
|
|
Payload::BorrowedMut(slice) => Bytes::from(&*slice),
|
|
Payload::Borrowed(slice) => Bytes::from(slice),
|
|
Payload::Owned(vec) => Bytes::from(vec),
|
|
Payload::Bytes(bytes_mut) => Bytes::from(bytes_mut),
|
|
};
|
|
|
|
if let Err(e) = subscription.handle_event(bytes).await {
|
|
tracing::error!("error handling event (skipping): {e:?}");
|
|
}
|
|
}
|
|
Ok(frame) => {
|
|
tracing::warn!("unexpected frame type {:?}", frame.opcode);
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(host = %subscription.host, "{e:?}");
|
|
// TODO: should we try reconnect in every situation?
|
|
if let WebSocketError::UnexpectedEOF = e {
|
|
tracing::debug!(host = %subscription.host, "reconnecting");
|
|
// TODO: should we sleep at all here
|
|
continue 'reconnect;
|
|
} else {
|
|
break 'reconnect;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::debug!(host = %subscription.host, "disconnected");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
|
|
// in future we will spider out but right now i just want da stuff from my PDS
|
|
|
|
for host in hosts.iter() {
|
|
let host = host.to_string();
|
|
let server = Arc::clone(&server);
|
|
tokio::task::spawn(async move {
|
|
if let Err(e) = subscribe_to_host(server, host).await {
|
|
tracing::warn!("encountered error subscribing to PDS: {e:?}");
|
|
}
|
|
});
|
|
}
|
|
}
|