From 88ac43b18c56d6492d2003d53e9b76b6b8d05c3f Mon Sep 17 00:00:00 2001 From: Charlotte Som Date: Mon, 25 Nov 2024 23:30:13 +0200 Subject: [PATCH] rebroadcast handle, identity, account events --- src/indexer.rs | 157 +++++++++++++++++++++++++++++++++++++++------- src/sequencer.rs | 12 ++++ src/user.rs | 37 ++++++----- src/wire_proto.rs | 3 + 4 files changed, 169 insertions(+), 40 deletions(-) diff --git a/src/indexer.rs b/src/indexer.rs index 6c21628..1612408 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use atrium_api::com::atproto::sync::subscribe_repos; use bytes::Bytes; use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; -use hyper::{header, upgrade::Upgraded, Request, Uri}; +use hyper::{header, upgrade::Upgraded, Request}; use hyper_util::rt::{TokioExecutor, TokioIo}; use ipld_core::ipld::Ipld; use serde_ipld_dagcbor::DecodeError; @@ -14,7 +14,7 @@ use tokio_rustls::rustls::pki_types::ServerName; use crate::{ http::body_empty, tls::open_tls_stream, - user::lookup_user, + user::{fetch_user, lookup_user}, wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, RelayServer, }; @@ -42,9 +42,8 @@ async fn create_ws_client( .header("Sec-WebSocket-Version", "13") .body(body_empty())?; - let (mut ws, _) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream) - .await - .unwrap(); + let (mut ws, _) = + fastwebsockets::handshake::client(&TokioExecutor::new(), req, tls_stream).await?; ws.set_auto_pong(true); ws.set_auto_close(true); @@ -70,33 +69,134 @@ impl DataServerSubscription { async fn handle_commit( &mut self, - payload: subscribe_repos::Commit, + event: subscribe_repos::Commit, ) -> Result> { - let user = lookup_user(&self.server, &payload.repo).await?; - let Some(pds) = user.pds else { - bail!("user has no associated pds? {:?}", user); - }; - let uri: Uri = pds.parse()?; - if uri.authority().map(|a| a.host()) != Some(&self.host) { + let last_seq = self.last_seq.unwrap_or_default(); + if event.seq < last_seq { bail!( - "commit from non-authoritative pds (got {} expected {})", + "got event out of order from stream (seq = {}, prev = {})", + event.seq, + last_seq + ) + } + self.last_seq = Some(event.seq); + + let mut user = lookup_user(&self.server, &event.repo).await?; + + let pds = user.pds.as_deref().unwrap_or_default(); + if pds != self.host { + tracing::warn!( + "received event from different pds than expected (got {} expected {})", self.host, pds ); + + // re-fetch user (without cache) + user = fetch_user(&self.server, &event.repo).await?; + let fresh_pds = user.pds.as_deref().unwrap_or_default(); + if fresh_pds != self.host { + bail!( + "commit from non-authoritative pds (got {} expected {})", + self.host, + fresh_pds + ); + } } - if user.takedown { - tracing::debug!(did = %user.did, seq = %payload.seq, "dropping commit event from taken-down user"); + // TODO: lookup did in takedown db tree + let takedown = false; + if takedown { + tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event from taken-down user"); return Ok(None); } - self.last_seq = Some(payload.seq); - Ok(Some(StreamEventPayload::Commit(payload))) + if event.rebase { + tracing::debug!(did = %user.did, seq = %event.seq, "dropping commit event with rebase flag"); + return Ok(None); + } + + Ok(Some(StreamEventPayload::Commit(event))) + } + + async fn handle_handle( + &mut self, + event: subscribe_repos::Handle, + ) -> Result> { + let last_seq = self.last_seq.unwrap_or_default(); + if event.seq < last_seq { + bail!( + "got event out of order from stream (seq = {}, prev = {})", + event.seq, + last_seq + ) + } + self.last_seq = Some(event.seq); + + let user = fetch_user(&self.server, &event.did).await?; + if user.handle.as_deref() != Some(event.handle.as_str()) { + tracing::warn!( + seq = %event.seq, + expected = ?event.handle.as_str(), + got = ?user.handle, + "handle update did not update handle to asserted value" + ); + } + + Ok(Some(StreamEventPayload::Handle(event))) + } + + async fn handle_identity( + &mut self, + event: subscribe_repos::Identity, + ) -> Result> { + let last_seq = self.last_seq.unwrap_or_default(); + if event.seq < last_seq { + bail!( + "got event out of order from stream (seq = {}, prev = {})", + event.seq, + last_seq + ) + } + self.last_seq = Some(event.seq); + + Ok(Some(StreamEventPayload::Identity(event))) + } + + async fn handle_account( + &mut self, + mut event: subscribe_repos::Account, + ) -> Result> { + let last_seq = self.last_seq.unwrap_or_default(); + if event.seq < last_seq { + bail!( + "got event out of order from stream (seq = {}, prev = {})", + event.seq, + last_seq + ) + } + self.last_seq = Some(event.seq); + + let user = fetch_user(&self.server, &event.did).await?; + let pds = user.pds.as_deref().unwrap_or_default(); + if pds != self.host { + bail!( + "account event from non-authoritative pds (got {} expected {})", + pds, + &self.host + ) + } + + // TODO: handle takedowns + let takedown = false; + if takedown { + event.status = Some("takendown".into()); + event.active = false; + } + + Ok(Some(StreamEventPayload::Account(event))) } async fn handle_event(&mut self, frame: Bytes) -> Result<()> { - // TODO: validate if this message is valid to come from this host - let buf: &[u8] = &frame; let mut cursor = Cursor::new(buf); let (header_buf, payload_buf) = @@ -112,12 +212,22 @@ impl DataServerSubscription { serde_ipld_dagcbor::from_slice::(payload_buf)?; self.handle_commit(payload).await? } - Some("#handle") => { - // TODO - None + let payload = + serde_ipld_dagcbor::from_slice::(payload_buf)?; + self.handle_handle(payload).await? } - + Some("#identity") => { + let payload = + serde_ipld_dagcbor::from_slice::(payload_buf)?; + self.handle_identity(payload).await? + } + Some("#account") => { + let payload = + serde_ipld_dagcbor::from_slice::(payload_buf)?; + self.handle_account(payload).await? + } + // TODO: migrate, tombstone Some("#info") => { let payload = serde_ipld_dagcbor::from_slice::(payload_buf)?; if payload.name == "OutdatedCursor" { @@ -126,7 +236,6 @@ impl DataServerSubscription { None } - Some(t) => { tracing::warn!("dropped unknown message type '{}'", t); None diff --git a/src/sequencer.rs b/src/sequencer.rs index 7709d9d..609a686 100644 --- a/src/sequencer.rs +++ b/src/sequencer.rs @@ -53,6 +53,18 @@ async fn run_sequencer( payload.seq = curr_seq as i64; serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; } + StreamEventPayload::Handle(mut payload) => { + payload.seq = curr_seq as i64; + serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + } + StreamEventPayload::Identity(mut payload) => { + payload.seq = curr_seq as i64; + serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + } + StreamEventPayload::Account(mut payload) => { + payload.seq = curr_seq as i64; + serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; + } StreamEventPayload::Unknown(payload) => { serde_ipld_dagcbor::to_writer(&mut cursor, &payload)?; } diff --git a/src/user.rs b/src/user.rs index 4aa455d..775c1d9 100644 --- a/src/user.rs +++ b/src/user.rs @@ -2,7 +2,7 @@ use anyhow::{bail, Context, Result}; use atrium_api::did_doc::DidDocument; use bytes::Buf; use http_body_util::BodyExt; -use hyper::{client::conn::http1, Request, StatusCode}; +use hyper::{client::conn::http1, Request, StatusCode, Uri}; use hyper_util::rt::TokioIo; use rustls::pki_types::ServerName; use serde::{Deserialize, Serialize}; @@ -19,9 +19,7 @@ pub struct User { pub did: String, pub pds: Option, #[serde(default)] - pub takedown: bool, - #[serde(default)] - pub tombstone: bool, + pub handle: Option, } pub async fn fetch_user(server: &RelayServer, did: &str) -> Result { @@ -48,8 +46,6 @@ pub async fn fetch_user(server: &RelayServer, did: &str) -> Result { } }); - tracing::debug!("handshake"); - let res = sender .send_request(req) .await @@ -58,20 +54,29 @@ pub async fn fetch_user(server: &RelayServer, did: &str) -> Result { bail!("plc directory returned non-200 status"); } - tracing::debug!("got response"); - let body = res.collect().await?.aggregate(); let did_doc = serde_json::from_reader::<_, DidDocument>(body.reader()) .context("Failed to parse plc DID doc as JSON")?; - let user = User { - pds: did_doc.get_pds_endpoint(), - did: did_doc.id, - takedown: false, - tombstone: false, - }; + let pds_endpoint = did_doc.get_pds_endpoint(); + let pds_uri: Option = pds_endpoint.as_deref().unwrap_or_default().parse().ok(); + let pds = pds_uri + .as_ref() + .and_then(|u| u.authority()) + .map(|a| a.host()) + .map(|s| s.to_string()); - store_user(server, &user).await?; + let handle = did_doc + .also_known_as + .and_then(|v| v.into_iter().next()) + .and_then(|s| s.strip_prefix("at://").map(str::to_string)); + let did = did_doc.id; + + // TODO: check if handle resolves to did and fill none otherwise + + let user = User { pds, did, handle }; + + store_user(server, &user)?; Ok(user) } else if did.starts_with("did:web:") { @@ -90,7 +95,7 @@ pub async fn lookup_user(server: &RelayServer, did: &str) -> Result { return fetch_user(server, did).await; } -pub async fn store_user(server: &RelayServer, user: &User) -> Result<()> { +pub fn store_user(server: &RelayServer, user: &User) -> Result<()> { let data = serde_ipld_dagcbor::to_vec(&user)?; server.db_users.insert(&user.did, data)?; Ok(()) diff --git a/src/wire_proto.rs b/src/wire_proto.rs index 8eadf3e..899a942 100644 --- a/src/wire_proto.rs +++ b/src/wire_proto.rs @@ -13,6 +13,9 @@ pub struct StreamEventHeader { pub enum StreamEventPayload { Commit(subscribe_repos::Commit), + Handle(subscribe_repos::Handle), + Identity(subscribe_repos::Identity), + Account(subscribe_repos::Account), Unknown(BTreeMap), }