implement requestCrawl, enforce repo limit

main
Charlotte Som 2024-11-27 01:29:16 +02:00
parent 5b297e0565
commit 9ee21c23d8
4 changed files with 116 additions and 8 deletions
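
For reference, com.atproto.sync.requestCrawl is the XRPC endpoint a PDS calls to ask a relay to start consuming its event stream; per the atproto lexicon its input is a JSON object with a single hostname field. A sketch of the kind of request the new route accepts (both hostnames are placeholders):

    POST /xrpc/com.atproto.sync.requestCrawl HTTP/1.1
    Host: relay.example.com
    Content-Type: application/json

    { "hostname": "pds.example.com" }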

View File

@@ -1,4 +1,6 @@
-use crate::{relay_subscription::handle_subscription, RelayServer};
+use crate::{
+    relay_subscription::handle_subscription, request_crawl::handle_request_crawl, RelayServer,
+};
 use std::{net::SocketAddr, sync::Arc};
@@ -37,6 +39,10 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve
             handle_subscription(server, req).await
         }
+        (&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => {
+            handle_request_crawl(server, req).await
+        }
         _ => Ok(Response::builder()
             .status(StatusCode::NOT_FOUND)
             .header("Content-Type", "text/plain")

View File

@@ -1,10 +1,11 @@
 use std::{io::Cursor, sync::Arc, time::Duration};
 
-use anyhow::{bail, Result};
-use atrium_api::com::atproto::sync::subscribe_repos;
-use bytes::Bytes;
+use anyhow::{bail, Context, Result};
+use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
+use bytes::{Buf, Bytes};
 use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
-use hyper::{header, upgrade::Upgraded, Request};
+use http_body_util::BodyExt;
+use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use ipld_core::ipld::Ipld;
 use serde_ipld_dagcbor::DecodeError;
@@ -12,7 +13,7 @@ use tokio::{net::TcpStream, sync::mpsc};
 use tokio_rustls::rustls::pki_types::ServerName;
 
 use crate::{
-    http::body_empty,
+    http::{body_empty, HttpBody},
     tls::open_tls_stream,
     user::{fetch_user, lookup_user},
     wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
@@ -322,10 +323,50 @@ impl DataServerSubscription {
     }
 }
 
+async fn get_repo_count(host: &str) -> Result<usize> {
+    let tcp_stream = TcpStream::connect((host, 443)).await?;
+    let host_tls = ServerName::try_from(host.to_string())?;
+    let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
+    let io = TokioIo::new(tls_stream);
+
+    let req = Request::builder()
+        .method("GET")
+        .uri(format!(
+            "https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
+        ))
+        .header("Host", host.to_string())
+        .body(body_empty())?;
+
+    let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
+    tokio::task::spawn(async move {
+        if let Err(err) = conn.await {
+            println!("Connection failed: {:?}", err);
+        }
+    });
+
+    let res = sender
+        .send_request(req)
+        .await
+        .context("Failed to send repo count request")?;
+
+    if res.status() != StatusCode::OK {
+        bail!("server returned non-200 status for listRepos");
+    }
+
+    let body = res.collect().await?.aggregate();
+    let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
+        .context("Failed to parse listRepos response as JSON")?;
+
+    Ok(output.repos.len())
+}
+
 #[tracing::instrument(skip_all, fields(host = %host))]
 async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
     tracing::debug!("establishing connection");
 
+    if get_repo_count(&host).await? >= 1000 {
+        bail!("too many repos! ditching from cerulea relay")
+    }
+
     let mut subscription = DataServerSubscription::new(server, host);
     // TODO: load seq from db ?
@@ -386,6 +427,26 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()>
     Ok(())
 }
 
+pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        if active_indexers.contains(&host) {
+            bail!("Indexer already running for host {}", &host);
+        }
+        active_indexers.insert(host.clone());
+    }
+
+    let r = host_subscription(Arc::clone(&server), host.clone()).await;
+
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        active_indexers.remove(&host);
+    }
+
+    r
+}
+
 pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
     // in future we will spider out but right now i just want da stuff from my PDS
@@ -393,7 +454,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
         let host = host.to_string();
         let server = Arc::clone(&server);
         tokio::task::spawn(async move {
-            if let Err(e) = host_subscription(server, host).await {
+            if let Err(e) = index_server(server, host).await {
                 tracing::warn!("encountered error subscribing to PDS: {e:?}");
             }
         });
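
A note on the two additions above. get_repo_count fetches a single page of com.atproto.sync.listRepos, whose limit parameter is capped at 1000, so the >= 1000 check in host_subscription fires exactly when the first page comes back full; an exact count would have to follow the response cursor. A rough sketch of that, assuming a hypothetical fetch_page helper that performs the same HTTPS GET as get_repo_count but with an optional cursor query parameter:

    async fn count_all_repos(host: &str) -> Result<usize> {
        let mut total = 0usize;
        let mut cursor: Option<String> = None;
        loop {
            // fetch_page is hypothetical: the listRepos request from
            // get_repo_count above, plus an optional `cursor` parameter.
            let page: list_repos::Output = fetch_page(host, cursor.as_deref()).await?;
            total += page.repos.len();
            match page.data.cursor {
                Some(next) => cursor = Some(next), // more pages remain
                None => return Ok(total),          // last page reached
            }
        }
    }

Also worth noting in index_server: the extra braces scope the active_indexers lock guards, so the Mutex is released before the long-running host_subscription await and re-acquired only for the cleanup, rather than being held across the whole subscription.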

View File

@@ -1,5 +1,7 @@
+use std::collections::HashSet;
+
 use bytes::Bytes;
-use tokio::sync::{broadcast, mpsc};
+use tokio::sync::{broadcast, mpsc, Mutex};
 use wire_proto::StreamEvent;
 
 pub struct RelayServer {
@@ -7,6 +9,8 @@ pub struct RelayServer {
     pub db_history: sled::Tree,
     pub db_users: sled::Tree,
 
+    pub active_indexers: Mutex<HashSet<String>>,
+
     pub event_tx: mpsc::Sender<StreamEvent>,
     pub raw_block_tx: broadcast::Sender<Bytes>,
 }
@@ -18,6 +22,8 @@ impl RelayServer {
            event_tx,
            raw_block_tx,
 
+            active_indexers: Default::default(),
+
             db_history: db
                 .open_tree("history")
                 .expect("failed to open history tree"),
@@ -30,6 +36,7 @@ impl RelayServer {
 pub mod http;
 pub mod indexer;
 pub mod relay_subscription;
+pub mod request_crawl;
 pub mod sequencer;
 pub mod tls;
 pub mod user;

src/request_crawl.rs (new file, +34)

View File

@@ -0,0 +1,34 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use atrium_api::com::atproto::sync::request_crawl;
+use bytes::Buf;
+use http_body_util::BodyExt;
+use hyper::{body::Incoming, Request, Response};
+
+use crate::{
+    http::{body_empty, body_full, ServerResponse},
+    indexer::index_server,
+    RelayServer,
+};
+
+pub async fn handle_request_crawl(
+    server: Arc<RelayServer>,
+    req: Request<Incoming>,
+) -> Result<ServerResponse> {
+    let body = req.collect().await?.aggregate();
+    let input = match serde_json::from_reader::<_, request_crawl::Input>(body.reader()) {
+        Ok(input) => input,
+        Err(_) => {
+            // TODO: surely we can build out an XRPC abstraction or something
+            return Ok(Response::builder().status(400).body(body_full(
+                r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#,
+            ))?);
+        }
+    };
+
+    let hostname = input.data.hostname;
+    index_server(server, hostname).await?;
+
+    Ok(Response::builder().status(200).body(body_empty())?)
+}
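
For reference (not part of the commit), the body this handler parses is the requestCrawl input object from atrium_api; a minimal sketch showing that a hand-written body deserializes through the same type, with a placeholder hostname:

    use atrium_api::com::atproto::sync::request_crawl;

    fn parse_example() -> anyhow::Result<()> {
        let body = br#"{ "hostname": "pds.example.com" }"#;
        let input: request_crawl::Input = serde_json::from_slice(body)?;
        assert_eq!(input.data.hostname, "pds.example.com");
        Ok(())
    }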