implement requestCrawl, enforce repo limit
This commit is contained in:
		
							parent
							
								
									5b297e0565
								
							
						
					
					
						commit
						9ee21c23d8
					
				
					 4 changed files with 116 additions and 8 deletions
				
			
		|  | @ -1,4 +1,6 @@ | ||||||
| use crate::{relay_subscription::handle_subscription, RelayServer}; | use crate::{ | ||||||
|  |     relay_subscription::handle_subscription, request_crawl::handle_request_crawl, RelayServer, | ||||||
|  | }; | ||||||
| 
 | 
 | ||||||
| use std::{net::SocketAddr, sync::Arc}; | use std::{net::SocketAddr, sync::Arc}; | ||||||
| 
 | 
 | ||||||
|  | @ -37,6 +39,10 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve | ||||||
|             handle_subscription(server, req).await |             handle_subscription(server, req).await | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         (&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => { | ||||||
|  |             handle_request_crawl(server, req).await | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         _ => Ok(Response::builder() |         _ => Ok(Response::builder() | ||||||
|             .status(StatusCode::NOT_FOUND) |             .status(StatusCode::NOT_FOUND) | ||||||
|             .header("Content-Type", "text/plain") |             .header("Content-Type", "text/plain") | ||||||
|  |  | ||||||
|  | @ -1,10 +1,11 @@ | ||||||
| use std::{io::Cursor, sync::Arc, time::Duration}; | use std::{io::Cursor, sync::Arc, time::Duration}; | ||||||
| 
 | 
 | ||||||
| use anyhow::{bail, Result}; | use anyhow::{bail, Context, Result}; | ||||||
| use atrium_api::com::atproto::sync::subscribe_repos; | use atrium_api::com::atproto::sync::{list_repos, subscribe_repos}; | ||||||
| use bytes::Bytes; | use bytes::{Buf, Bytes}; | ||||||
| use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; | use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError}; | ||||||
| use hyper::{header, upgrade::Upgraded, Request}; | use http_body_util::BodyExt; | ||||||
|  | use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode}; | ||||||
| use hyper_util::rt::{TokioExecutor, TokioIo}; | use hyper_util::rt::{TokioExecutor, TokioIo}; | ||||||
| use ipld_core::ipld::Ipld; | use ipld_core::ipld::Ipld; | ||||||
| use serde_ipld_dagcbor::DecodeError; | use serde_ipld_dagcbor::DecodeError; | ||||||
|  | @ -12,7 +13,7 @@ use tokio::{net::TcpStream, sync::mpsc}; | ||||||
| use tokio_rustls::rustls::pki_types::ServerName; | use tokio_rustls::rustls::pki_types::ServerName; | ||||||
| 
 | 
 | ||||||
| use crate::{ | use crate::{ | ||||||
|     http::body_empty, |     http::{body_empty, HttpBody}, | ||||||
|     tls::open_tls_stream, |     tls::open_tls_stream, | ||||||
|     user::{fetch_user, lookup_user}, |     user::{fetch_user, lookup_user}, | ||||||
|     wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, |     wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload}, | ||||||
|  | @ -322,10 +323,50 @@ impl DataServerSubscription { | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | async fn get_repo_count(host: &str) -> Result<usize> { | ||||||
|  |     let tcp_stream = TcpStream::connect((host, 443)).await?; | ||||||
|  |     let host_tls = ServerName::try_from(host.to_string())?; | ||||||
|  |     let tls_stream = open_tls_stream(tcp_stream, host_tls).await?; | ||||||
|  |     let io = TokioIo::new(tls_stream); | ||||||
|  | 
 | ||||||
|  |     let req = Request::builder() | ||||||
|  |         .method("GET") | ||||||
|  |         .uri(format!( | ||||||
|  |             "https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000" | ||||||
|  |         )) | ||||||
|  |         .header("Host", host.to_string()) | ||||||
|  |         .body(body_empty())?; | ||||||
|  | 
 | ||||||
|  |     let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?; | ||||||
|  |     tokio::task::spawn(async move { | ||||||
|  |         if let Err(err) = conn.await { | ||||||
|  |             println!("Connection failed: {:?}", err); | ||||||
|  |         } | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     let res = sender | ||||||
|  |         .send_request(req) | ||||||
|  |         .await | ||||||
|  |         .context("Failed to send repo count request")?; | ||||||
|  |     if res.status() != StatusCode::OK { | ||||||
|  |         bail!("server returned non-200 status for listRepos"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let body = res.collect().await?.aggregate(); | ||||||
|  |     let output = serde_json::from_reader::<_, list_repos::Output>(body.reader()) | ||||||
|  |         .context("Failed to parse listRepos response as JSON")?; | ||||||
|  | 
 | ||||||
|  |     Ok(output.repos.len()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| #[tracing::instrument(skip_all, fields(host = %host))] | #[tracing::instrument(skip_all, fields(host = %host))] | ||||||
| async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> { | async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> { | ||||||
|     tracing::debug!("establishing connection"); |     tracing::debug!("establishing connection"); | ||||||
| 
 | 
 | ||||||
|  |     if get_repo_count(&host).await? >= 1000 { | ||||||
|  |         bail!("too many repos! ditching from cerulea relay") | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     let mut subscription = DataServerSubscription::new(server, host); |     let mut subscription = DataServerSubscription::new(server, host); | ||||||
| 
 | 
 | ||||||
|     // TODO: load seq from db ?
 |     // TODO: load seq from db ?
 | ||||||
|  | @ -386,6 +427,26 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> | ||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> { | ||||||
|  |     { | ||||||
|  |         let mut active_indexers = server.active_indexers.lock().await; | ||||||
|  |         if active_indexers.contains(&host) { | ||||||
|  |             bail!("Indexer already running for host {}", &host); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         active_indexers.insert(host.clone()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let r = host_subscription(Arc::clone(&server), host.clone()).await; | ||||||
|  | 
 | ||||||
|  |     { | ||||||
|  |         let mut active_indexers = server.active_indexers.lock().await; | ||||||
|  |         active_indexers.remove(&host); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     r | ||||||
|  | } | ||||||
|  | 
 | ||||||
| pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) { | pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) { | ||||||
|     // in future we will spider out but right now i just want da stuff from my PDS
 |     // in future we will spider out but right now i just want da stuff from my PDS
 | ||||||
| 
 | 
 | ||||||
|  | @ -393,7 +454,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) { | ||||||
|         let host = host.to_string(); |         let host = host.to_string(); | ||||||
|         let server = Arc::clone(&server); |         let server = Arc::clone(&server); | ||||||
|         tokio::task::spawn(async move { |         tokio::task::spawn(async move { | ||||||
|             if let Err(e) = host_subscription(server, host).await { |             if let Err(e) = index_server(server, host).await { | ||||||
|                 tracing::warn!("encountered error subscribing to PDS: {e:?}"); |                 tracing::warn!("encountered error subscribing to PDS: {e:?}"); | ||||||
|             } |             } | ||||||
|         }); |         }); | ||||||
|  |  | ||||||
|  | @ -1,5 +1,7 @@ | ||||||
|  | use std::collections::HashSet; | ||||||
|  | 
 | ||||||
| use bytes::Bytes; | use bytes::Bytes; | ||||||
| use tokio::sync::{broadcast, mpsc}; | use tokio::sync::{broadcast, mpsc, Mutex}; | ||||||
| use wire_proto::StreamEvent; | use wire_proto::StreamEvent; | ||||||
| 
 | 
 | ||||||
| pub struct RelayServer { | pub struct RelayServer { | ||||||
|  | @ -7,6 +9,8 @@ pub struct RelayServer { | ||||||
|     pub db_history: sled::Tree, |     pub db_history: sled::Tree, | ||||||
|     pub db_users: sled::Tree, |     pub db_users: sled::Tree, | ||||||
| 
 | 
 | ||||||
|  |     pub active_indexers: Mutex<HashSet<String>>, | ||||||
|  | 
 | ||||||
|     pub event_tx: mpsc::Sender<StreamEvent>, |     pub event_tx: mpsc::Sender<StreamEvent>, | ||||||
|     pub raw_block_tx: broadcast::Sender<Bytes>, |     pub raw_block_tx: broadcast::Sender<Bytes>, | ||||||
| } | } | ||||||
|  | @ -18,6 +22,8 @@ impl RelayServer { | ||||||
|             event_tx, |             event_tx, | ||||||
|             raw_block_tx, |             raw_block_tx, | ||||||
| 
 | 
 | ||||||
|  |             active_indexers: Default::default(), | ||||||
|  | 
 | ||||||
|             db_history: db |             db_history: db | ||||||
|                 .open_tree("history") |                 .open_tree("history") | ||||||
|                 .expect("failed to open history tree"), |                 .expect("failed to open history tree"), | ||||||
|  | @ -30,6 +36,7 @@ impl RelayServer { | ||||||
| pub mod http; | pub mod http; | ||||||
| pub mod indexer; | pub mod indexer; | ||||||
| pub mod relay_subscription; | pub mod relay_subscription; | ||||||
|  | pub mod request_crawl; | ||||||
| pub mod sequencer; | pub mod sequencer; | ||||||
| pub mod tls; | pub mod tls; | ||||||
| pub mod user; | pub mod user; | ||||||
|  |  | ||||||
							
								
								
									
										34
									
								
								src/request_crawl.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								src/request_crawl.rs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,34 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
|  | use anyhow::Result; | ||||||
|  | use atrium_api::com::atproto::sync::request_crawl; | ||||||
|  | use bytes::Buf; | ||||||
|  | use http_body_util::BodyExt; | ||||||
|  | use hyper::{body::Incoming, Request, Response}; | ||||||
|  | 
 | ||||||
|  | use crate::{ | ||||||
|  |     http::{body_empty, body_full, ServerResponse}, | ||||||
|  |     indexer::index_server, | ||||||
|  |     RelayServer, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | pub async fn handle_request_crawl( | ||||||
|  |     server: Arc<RelayServer>, | ||||||
|  |     req: Request<Incoming>, | ||||||
|  | ) -> Result<ServerResponse> { | ||||||
|  |     let body = req.collect().await?.aggregate(); | ||||||
|  |     let input = match serde_json::from_reader::<_, request_crawl::Input>(body.reader()) { | ||||||
|  |         Ok(input) => input, | ||||||
|  |         Err(_) => { | ||||||
|  |             // TODO: surely we can build out an XRPC abstraction or something
 | ||||||
|  |             return Ok(Response::builder().status(400).body(body_full( | ||||||
|  |                 r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#, | ||||||
|  |             ))?); | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let hostname = input.data.hostname; | ||||||
|  |     index_server(server, hostname).await?; | ||||||
|  | 
 | ||||||
|  |     Ok(Response::builder().status(200).body(body_empty())?) | ||||||
|  | } | ||||||
		Loading…
	
		Reference in a new issue