implement requestCrawl, enforce repo limit
parent 5b297e0565
commit 9ee21c23d8
@@ -1,4 +1,6 @@
-use crate::{relay_subscription::handle_subscription, RelayServer};
+use crate::{
+    relay_subscription::handle_subscription, request_crawl::handle_request_crawl, RelayServer,
+};
 use std::{net::SocketAddr, sync::Arc};
@@ -37,6 +39,10 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<ServerResponse>
             handle_subscription(server, req).await
         }
+        (&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => {
+            handle_request_crawl(server, req).await
+        }
         _ => Ok(Response::builder()
             .status(StatusCode::NOT_FOUND)
             .header("Content-Type", "text/plain")
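
Note: the new match arm is POST-only, matching the com.atproto.sync.requestCrawl lexicon. A sketch of poking the endpoint from a client; the port is a placeholder and the reqwest and serde_json crates are assumptions, neither is used by this commit:

    // Hypothetical client for the new endpoint (assumes reqwest + serde_json).
    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let res = reqwest::Client::new()
            // port 3000 is a placeholder; use wherever the relay actually listens
            .post("http://localhost:3000/xrpc/com.atproto.sync.requestCrawl")
            .json(&serde_json::json!({ "hostname": "pds.example.com" }))
            .send()
            .await?;
        // handle_request_crawl replies 200 with an empty body on success,
        // 400 InvalidRequest if the JSON body fails to parse
        println!("status: {}", res.status());
        Ok(())
    }
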
@@ -1,10 +1,11 @@
 use std::{io::Cursor, sync::Arc, time::Duration};
 
-use anyhow::{bail, Result};
-use atrium_api::com::atproto::sync::subscribe_repos;
-use bytes::Bytes;
+use anyhow::{bail, Context, Result};
+use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
+use bytes::{Buf, Bytes};
 use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
-use hyper::{header, upgrade::Upgraded, Request};
+use http_body_util::BodyExt;
+use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use ipld_core::ipld::Ipld;
 use serde_ipld_dagcbor::DecodeError;
@@ -12,7 +13,7 @@ use tokio::{net::TcpStream, sync::mpsc};
 use tokio_rustls::rustls::pki_types::ServerName;
 
 use crate::{
-    http::body_empty,
+    http::{body_empty, HttpBody},
     tls::open_tls_stream,
     user::{fetch_user, lookup_user},
     wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
@@ -322,10 +323,50 @@ impl DataServerSubscription {
     }
 }
 
+async fn get_repo_count(host: &str) -> Result<usize> {
+    let tcp_stream = TcpStream::connect((host, 443)).await?;
+    let host_tls = ServerName::try_from(host.to_string())?;
+    let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
+    let io = TokioIo::new(tls_stream);
+
+    let req = Request::builder()
+        .method("GET")
+        .uri(format!(
+            "https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
+        ))
+        .header("Host", host.to_string())
+        .body(body_empty())?;
+
+    let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
+    tokio::task::spawn(async move {
+        if let Err(err) = conn.await {
+            println!("Connection failed: {:?}", err);
+        }
+    });
+
+    let res = sender
+        .send_request(req)
+        .await
+        .context("Failed to send repo count request")?;
+    if res.status() != StatusCode::OK {
+        bail!("server returned non-200 status for listRepos");
+    }
+
+    let body = res.collect().await?.aggregate();
+    let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
+        .context("Failed to parse listRepos response as JSON")?;
+
+    Ok(output.repos.len())
+}
+
 #[tracing::instrument(skip_all, fields(host = %host))]
 async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
     tracing::debug!("establishing connection");
 
+    if get_repo_count(&host).await? >= 1000 {
+        bail!("too many repos! ditching from cerulea relay")
+    }
+
     let mut subscription = DataServerSubscription::new(server, host);
 
     // TODO: load seq from db ?
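
Worth noting on the limit check: listRepos returns at most `limit` records, so get_repo_count never sees more than 1000 repos, and the `>= 1000` test really means "the first page came back full, i.e. the host has at least 1000 repos". If an exact count were ever wanted, the lexicon's cursor field supports pagination. A sketch only, not in this commit, using a hand-rolled response type (field names from the com.atproto.sync.listRepos lexicon) and reqwest instead of the raw hyper plumbing above:

    // Sketch: exact repo count via cursor pagination. reqwest is assumed
    // here purely to keep the sketch short.
    #[derive(serde::Deserialize)]
    struct ListReposPage {
        repos: Vec<serde_json::Value>,
        cursor: Option<String>,
    }

    async fn fetch_json<T: serde::de::DeserializeOwned>(url: &str) -> anyhow::Result<T> {
        Ok(reqwest::get(url).await?.error_for_status()?.json::<T>().await?)
    }

    async fn count_repos_paginated(host: &str) -> anyhow::Result<usize> {
        let mut total = 0;
        let mut cursor: Option<String> = None;
        loop {
            let mut url = format!("https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000");
            if let Some(c) = &cursor {
                url.push_str(&format!("&cursor={c}"));
            }
            let page: ListReposPage = fetch_json(&url).await?;
            total += page.repos.len();
            match page.cursor {
                // keep paging while the server hands back a cursor
                Some(c) if !page.repos.is_empty() => cursor = Some(c),
                _ => break,
            }
        }
        Ok(total)
    }
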
@@ -386,6 +427,26 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()>
     Ok(())
 }
 
+pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        if active_indexers.contains(&host) {
+            bail!("Indexer already running for host {}", &host);
+        }
+
+        active_indexers.insert(host.clone());
+    }
+
+    let r = host_subscription(Arc::clone(&server), host.clone()).await;
+
+    {
+        let mut active_indexers = server.active_indexers.lock().await;
+        active_indexers.remove(&host);
+    }
+
+    r
+}
+
 pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
     // in future we will spider out but right now i just want da stuff from my PDS
 
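
index_server stores the outcome of host_subscription in `r` before re-taking the lock, so the active_indexers entry is removed on both success and error. What it does not cover is cancellation: if the spawned task is aborted mid-await, the code after the `.await` never runs and the host stays marked active. A drop-guard sketch, not part of this commit, that would also handle that case:

    // Sketch: RAII guard so the active_indexers entry is cleared even if the
    // indexing task is cancelled rather than returning. Assumes RelayServer
    // as defined in this commit; the Drop impl must run inside a tokio
    // runtime because active_indexers uses tokio's async Mutex.
    use std::sync::Arc;

    struct IndexerGuard {
        server: Arc<RelayServer>,
        host: String,
    }

    impl Drop for IndexerGuard {
        fn drop(&mut self) {
            let server = Arc::clone(&self.server);
            let host = self.host.clone();
            tokio::spawn(async move {
                server.active_indexers.lock().await.remove(&host);
            });
        }
    }
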
@@ -393,7 +454,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
         let host = host.to_string();
         let server = Arc::clone(&server);
         tokio::task::spawn(async move {
-            if let Err(e) = host_subscription(server, host).await {
+            if let Err(e) = index_server(server, host).await {
                 tracing::warn!("encountered error subscribing to PDS: {e:?}");
             }
         });
@@ -1,5 +1,7 @@
+use std::collections::HashSet;
+
 use bytes::Bytes;
-use tokio::sync::{broadcast, mpsc};
+use tokio::sync::{broadcast, mpsc, Mutex};
 use wire_proto::StreamEvent;
 
 pub struct RelayServer {
@@ -7,6 +9,8 @@ pub struct RelayServer {
     pub db_history: sled::Tree,
     pub db_users: sled::Tree,
 
+    pub active_indexers: Mutex<HashSet<String>>,
+
     pub event_tx: mpsc::Sender<StreamEvent>,
     pub raw_block_tx: broadcast::Sender<Bytes>,
 }
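
A small simplification is available for this new field: HashSet::insert already returns false when the value was present, so the contains-then-insert sequence in index_server could be a single call under the lock, as in this sketch:

    // Returns true if we claimed the host, false if an indexer already runs.
    async fn try_claim(server: &RelayServer, host: &str) -> bool {
        server.active_indexers.lock().await.insert(host.to_string())
    }
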
@@ -18,6 +22,8 @@ impl RelayServer {
             event_tx,
             raw_block_tx,
 
+            active_indexers: Default::default(),
+
             db_history: db
                 .open_tree("history")
                 .expect("failed to open history tree"),
@@ -30,6 +36,7 @@ impl RelayServer {
 pub mod http;
 pub mod indexer;
 pub mod relay_subscription;
+pub mod request_crawl;
 pub mod sequencer;
 pub mod tls;
 pub mod user;
@@ -0,0 +1,34 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use atrium_api::com::atproto::sync::request_crawl;
+use bytes::Buf;
+use http_body_util::BodyExt;
+use hyper::{body::Incoming, Request, Response};
+
+use crate::{
+    http::{body_empty, body_full, ServerResponse},
+    indexer::index_server,
+    RelayServer,
+};
+
+pub async fn handle_request_crawl(
+    server: Arc<RelayServer>,
+    req: Request<Incoming>,
+) -> Result<ServerResponse> {
+    let body = req.collect().await?.aggregate();
+    let input = match serde_json::from_reader::<_, request_crawl::Input>(body.reader()) {
+        Ok(input) => input,
+        Err(_) => {
+            // TODO: surely we can build out an XRPC abstraction or something
+            return Ok(Response::builder().status(400).body(body_full(
+                r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#,
+            ))?);
+        }
+    };
+
+    let hostname = input.data.hostname;
+    index_server(server, hostname).await?;
+
+    Ok(Response::builder().status(200).body(body_empty())?)
+}
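
On the TODO above: a generic helper could fold the parse-or-400 dance into one place. A sketch only, reusing the imports and the body_full/ServerResponse helpers this file already pulls from crate::http; here Ok(Err(resp)) means "reply with this 400 immediately":

    // Hypothetical XRPC input parser: Ok(Ok(input)) on success,
    // Ok(Err(response)) when the client sent an unparseable body.
    async fn parse_xrpc_input<T: serde::de::DeserializeOwned>(
        req: Request<Incoming>,
    ) -> Result<Result<T, ServerResponse>> {
        let body = req.collect().await?.aggregate();
        Ok(match serde_json::from_reader::<_, T>(body.reader()) {
            Ok(input) => Ok(input),
            Err(_) => Err(Response::builder().status(400).body(body_full(
                r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#,
            ))?),
        })
    }

handle_request_crawl would then shrink to a match on parse_xrpc_input::<request_crawl::Input>(req).await?, and future XRPC handlers could share the same error shape.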