From 9ee21c23d868eda3e58213618696c105286bc811 Mon Sep 17 00:00:00 2001
From: Charlotte Som
Date: Wed, 27 Nov 2024 01:29:16 +0200
Subject: [PATCH] implement requestCrawl, enforce repo limit

---
 src/http.rs          |  8 ++++-
 src/indexer.rs       | 73 ++++++++++++++++++++++++++++++++++++++++----
 src/lib.rs           |  9 +++++-
 src/request_crawl.rs | 34 +++++++++++++++++++++
 4 files changed, 116 insertions(+), 8 deletions(-)
 create mode 100644 src/request_crawl.rs

diff --git a/src/http.rs b/src/http.rs
index a78e4a5..2b08051 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,4 +1,6 @@
-use crate::{relay_subscription::handle_subscription, RelayServer};
+use crate::{
+    relay_subscription::handle_subscription, request_crawl::handle_request_crawl, RelayServer,
+};
 
 use std::{net::SocketAddr, sync::Arc};
 
@@ -37,6 +39,10 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<ServerResponse> {
             handle_subscription(server, req).await
         }
 
+        (&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => {
+            handle_request_crawl(server, req).await
+        }
+
         _ => Ok(Response::builder()
             .status(StatusCode::NOT_FOUND)
             .header("Content-Type", "text/plain")
diff --git a/src/indexer.rs b/src/indexer.rs
index 9083848..e643734 100644
--- a/src/indexer.rs
+++ b/src/indexer.rs
@@ -1,10 +1,11 @@
 use std::{io::Cursor, sync::Arc, time::Duration};
 
-use anyhow::{bail, Result};
-use atrium_api::com::atproto::sync::subscribe_repos;
-use bytes::Bytes;
+use anyhow::{bail, Context, Result};
+use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
+use bytes::{Buf, Bytes};
 use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
-use hyper::{header, upgrade::Upgraded, Request};
+use http_body_util::BodyExt;
+use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use ipld_core::ipld::Ipld;
 use serde_ipld_dagcbor::DecodeError;
@@ -12,7 +13,7 @@ use tokio::{net::TcpStream, sync::mpsc};
 use tokio_rustls::rustls::pki_types::ServerName;
 
 use crate::{
-    http::body_empty,
+    http::{body_empty, HttpBody},
     tls::open_tls_stream,
     user::{fetch_user, lookup_user},
     wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
@@ -322,10 +323,50 @@ impl DataServerSubscription {
     }
 }
 
+async fn get_repo_count(host: &str) -> Result<usize> {
+    let tcp_stream = TcpStream::connect((host, 443)).await?;
+    let host_tls = ServerName::try_from(host.to_string())?;
+    let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
+    let io = TokioIo::new(tls_stream);
+
+    let req = Request::builder()
+        .method("GET")
+        .uri(format!(
+            "https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
+        ))
+        .header("Host", host.to_string())
+        .body(body_empty())?;
+
+    let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
+    tokio::task::spawn(async move {
+        if let Err(err) = conn.await {
+            println!("Connection failed: {:?}", err);
+        }
+    });
+
+    let res = sender
+        .send_request(req)
+        .await
+        .context("Failed to send repo count request")?;
+    if res.status() != StatusCode::OK {
+        bail!("server returned non-200 status for listRepos");
+    }
+
+    let body = res.collect().await?.aggregate();
+    let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
+        .context("Failed to parse listRepos response as JSON")?;
+
+    Ok(output.repos.len())
+}
+
 #[tracing::instrument(skip_all, fields(host = %host))]
 async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
     tracing::debug!("establishing connection");
 
+    if get_repo_count(&host).await? >= 1000 {
+        bail!("too many repos! ditching from cerulea relay")
+    }
+
     let mut subscription = DataServerSubscription::new(server, host);
 
     // TODO: load seq from db ?
@@ -386,6 +427,26 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
     Ok(())
 }
 
+pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        if active_indexers.contains(&host) {
+            bail!("Indexer already running for host {}", &host);
+        }
+
+        active_indexers.insert(host.clone());
+    }
+
+    let r = host_subscription(Arc::clone(&server), host.clone()).await;
+
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        active_indexers.remove(&host);
+    }
+
+    r
+}
+
 pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
     // in future we will spider out but right now i just want da stuff from my PDS
 
@@ -393,7 +454,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
         let host = host.to_string();
         let server = Arc::clone(&server);
         tokio::task::spawn(async move {
-            if let Err(e) = host_subscription(server, host).await {
+            if let Err(e) = index_server(server, host).await {
                 tracing::warn!("encountered error subscribing to PDS: {e:?}");
             }
         });
diff --git a/src/lib.rs b/src/lib.rs
index f9e7915..9155b95 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,7 @@
+use std::collections::HashSet;
+
 use bytes::Bytes;
-use tokio::sync::{broadcast, mpsc};
+use tokio::sync::{broadcast, mpsc, Mutex};
 use wire_proto::StreamEvent;
 
 pub struct RelayServer {
@@ -7,6 +9,8 @@ pub struct RelayServer {
     pub db_history: sled::Tree,
     pub db_users: sled::Tree,
 
+    pub active_indexers: Mutex<HashSet<String>>,
+
     pub event_tx: mpsc::Sender<StreamEvent>,
     pub raw_block_tx: broadcast::Sender<Bytes>,
 }
@@ -18,6 +22,8 @@ impl RelayServer {
             event_tx,
             raw_block_tx,
 
+            active_indexers: Default::default(),
+
             db_history: db
                 .open_tree("history")
                 .expect("failed to open history tree"),
@@ -30,6 +36,7 @@ impl RelayServer {
 pub mod http;
 pub mod indexer;
 pub mod relay_subscription;
+pub mod request_crawl;
 pub mod sequencer;
 pub mod tls;
 pub mod user;
diff --git a/src/request_crawl.rs b/src/request_crawl.rs
new file mode 100644
index 0000000..6c16505
--- /dev/null
+++ b/src/request_crawl.rs
@@ -0,0 +1,34 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use atrium_api::com::atproto::sync::request_crawl;
+use bytes::Buf;
+use http_body_util::BodyExt;
+use hyper::{body::Incoming, Request, Response};
+
+use crate::{
+    http::{body_empty, body_full, ServerResponse},
+    indexer::index_server,
+    RelayServer,
+};
+
+pub async fn handle_request_crawl(
+    server: Arc<RelayServer>,
+    req: Request<Incoming>,
+) -> Result<ServerResponse> {
+    let body = req.collect().await?.aggregate();
+    let input = match serde_json::from_reader::<_, request_crawl::Input>(body.reader()) {
+        Ok(input) => input,
+        Err(_) => {
+            // TODO: surely we can build out an XRPC abstraction or something
+            return Ok(Response::builder().status(400).body(body_full(
+                r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#,
+            ))?);
+        }
+    };
+
+    let hostname = input.data.hostname;
+    index_server(server, hostname).await?;
+
+    Ok(Response::builder().status(200).body(body_empty())?)
+}