better reconnection better tracing

main
Charlotte Som 2024-11-25 20:35:26 +02:00
parent aefb726109
commit 642abf6aa5
2 changed files with 73 additions and 43 deletions

View File

@ -1,6 +1,6 @@
use std::{io::Cursor, sync::Arc};
use std::{io::Cursor, sync::Arc, time::Duration};
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{bail, Result};
use atrium_api::com::atproto::sync::subscribe_repos;
use bytes::Bytes;
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
@ -8,14 +8,8 @@ use hyper::{header, upgrade::Upgraded, Request, Uri};
use hyper_util::rt::{TokioExecutor, TokioIo};
use ipld_core::ipld::Ipld;
use serde_ipld_dagcbor::DecodeError;
use tokio::{
net::TcpStream,
sync::{
broadcast::{self},
mpsc,
},
};
use tokio_rustls::{rustls::pki_types::ServerName, TlsConnector};
use tokio::{net::TcpStream, sync::mpsc};
use tokio_rustls::rustls::pki_types::ServerName;
use crate::{
http::body_empty,
@ -60,22 +54,46 @@ async fn create_ws_client(
struct DataServerSubscription {
server: Arc<RelayServer>,
host: String,
raw_block_tx: broadcast::Sender<Bytes>,
event_tx: mpsc::Sender<StreamEvent>,
last_seq: Option<i128>,
last_seq: Option<i64>,
}
impl DataServerSubscription {
fn new(server: Arc<RelayServer>, host: String) -> Self {
Self {
host,
raw_block_tx: server.raw_block_tx.clone(),
event_tx: server.event_tx.clone(),
last_seq: None,
server,
}
}
async fn handle_commit(
&mut self,
payload: subscribe_repos::Commit,
) -> Result<Option<StreamEventPayload>> {
let user = lookup_user(&self.server, &payload.repo).await?;
let Some(pds) = user.pds else {
bail!("user has no associated pds? {:?}", user);
};
let uri: Uri = pds.parse()?;
if uri.authority().map(|a| a.host()) != Some(&self.host) {
bail!(
"commit from non-authoritative pds (got {} expected {})",
self.host,
pds
);
}
if user.takedown {
tracing::debug!(did = %user.did, seq = %payload.seq, "dropping commit event from taken-down user");
return Ok(None);
}
self.last_seq = Some(payload.seq);
Ok(Some(StreamEventPayload::Commit(payload)))
}
async fn handle_event(&mut self, frame: Bytes) -> Result<()> {
// TODO: validate if this message is valid to come from this host
@ -92,52 +110,60 @@ impl DataServerSubscription {
Some("#commit") => {
let payload =
serde_ipld_dagcbor::from_slice::<subscribe_repos::Commit>(payload_buf)?;
self.handle_commit(payload).await?
}
let user = lookup_user(&self.server, &payload.repo).await?;
let Some(pds) = user.pds else {
bail!("user has no associated pds? {}", user.did);
};
let uri: Uri = pds.parse()?;
if uri.authority().map(|a| a.host()) != Some(&self.host) {
bail!(
"commit from non-authoritative pds (got {} expected {})",
self.host,
pds
);
Some("#handle") => {
// TODO
None
}
Some("#info") => {
let payload = serde_ipld_dagcbor::from_slice::<subscribe_repos::Info>(payload_buf)?;
if payload.name == "OutdatedCursor" {
tracing::warn!(message = ?payload.message, "outdated cursor");
}
StreamEventPayload::Commit(payload)
None
}
Some(t) => {
tracing::warn!("dropped unknown message type '{}'", t);
return Ok(());
}
None => {
return Ok(());
// skip ig
None
}
None => None,
};
self.event_tx.send((header, payload)).await?;
if let Some(payload) = payload {
self.event_tx.send((header, payload)).await?;
}
Ok(())
}
}
pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result<()> {
tracing::debug!(%host, "establishing connection");
#[tracing::instrument(skip_all, fields(host = %host))]
async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
tracing::debug!("establishing connection");
let mut subscription = DataServerSubscription::new(server, host);
// TODO: load seq from db ?
'reconnect: loop {
let mut ws = create_ws_client(
&subscription.host,
443,
"/xrpc/com.atproto.sync.subscribeRepos",
&format!(
"/xrpc/com.atproto.sync.subscribeRepos{}",
subscription
.last_seq
.map(|c| format!("?cursor={c}"))
.unwrap_or_default()
),
)
.await?;
tracing::debug!(host = %subscription.host, "listening");
tracing::debug!(seq = ?subscription.last_seq, "listening");
loop {
match ws.read_frame().await {
@ -153,15 +179,19 @@ pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result
tracing::error!("error handling event (skipping): {e:?}");
}
}
Ok(frame) if frame.opcode == OpCode::Close => {
tracing::debug!("got close frame. reconnecting in 10s");
tokio::time::sleep(Duration::from_secs(10)).await;
continue 'reconnect;
}
Ok(frame) => {
tracing::warn!("unexpected frame type {:?}", frame.opcode);
}
Err(e) => {
tracing::error!(host = %subscription.host, "{e:?}");
// TODO: should we try reconnect in every situation?
tracing::error!("{e:?}");
// TODO: should we try to reconnect in every situation?
if let WebSocketError::UnexpectedEOF = e {
tracing::debug!(host = %subscription.host, "reconnecting");
// TODO: should we sleep at all here
tracing::debug!("got unexpected EOF. reconnecting immediately");
continue 'reconnect;
} else {
break 'reconnect;
@ -171,7 +201,7 @@ pub async fn subscribe_to_host(server: Arc<RelayServer>, host: String) -> Result
}
}
tracing::debug!(host = %subscription.host, "disconnected");
tracing::debug!("disconnected");
Ok(())
}
@ -183,7 +213,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
let host = host.to_string();
let server = Arc::clone(&server);
tokio::task::spawn(async move {
if let Err(e) = subscribe_to_host(server, host).await {
if let Err(e) = host_subscription(server, host).await {
tracing::warn!("encountered error subscribing to PDS: {e:?}");
}
});

View File

@ -14,7 +14,7 @@ use crate::{
RelayServer,
};
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct User {
pub did: String,
pub pds: Option<String>,