implement requestCrawl, enforce repo limit
parent
5b297e0565
commit
9ee21c23d8
|
@ -1,4 +1,6 @@
|
|||
use crate::{relay_subscription::handle_subscription, RelayServer};
|
||||
use crate::{
|
||||
relay_subscription::handle_subscription, request_crawl::handle_request_crawl, RelayServer,
|
||||
};
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
|
@ -37,6 +39,10 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve
|
|||
handle_subscription(server, req).await
|
||||
}
|
||||
|
||||
(&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => {
|
||||
handle_request_crawl(server, req).await
|
||||
}
|
||||
|
||||
_ => Ok(Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.header("Content-Type", "text/plain")
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use std::{io::Cursor, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use atrium_api::com::atproto::sync::subscribe_repos;
|
||||
use bytes::Bytes;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use atrium_api::com::atproto::sync::{list_repos, subscribe_repos};
|
||||
use bytes::{Buf, Bytes};
|
||||
use fastwebsockets::{FragmentCollector, OpCode, Payload, WebSocketError};
|
||||
use hyper::{header, upgrade::Upgraded, Request};
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{client::conn::http1, header, upgrade::Upgraded, Request, StatusCode};
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use ipld_core::ipld::Ipld;
|
||||
use serde_ipld_dagcbor::DecodeError;
|
||||
|
@ -12,7 +13,7 @@ use tokio::{net::TcpStream, sync::mpsc};
|
|||
use tokio_rustls::rustls::pki_types::ServerName;
|
||||
|
||||
use crate::{
|
||||
http::body_empty,
|
||||
http::{body_empty, HttpBody},
|
||||
tls::open_tls_stream,
|
||||
user::{fetch_user, lookup_user},
|
||||
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
|
||||
|
@ -322,10 +323,50 @@ impl DataServerSubscription {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_repo_count(host: &str) -> Result<usize> {
|
||||
let tcp_stream = TcpStream::connect((host, 443)).await?;
|
||||
let host_tls = ServerName::try_from(host.to_string())?;
|
||||
let tls_stream = open_tls_stream(tcp_stream, host_tls).await?;
|
||||
let io = TokioIo::new(tls_stream);
|
||||
|
||||
let req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(format!(
|
||||
"https://{host}/xrpc/com.atproto.sync.listRepos?limit=1000"
|
||||
))
|
||||
.header("Host", host.to_string())
|
||||
.body(body_empty())?;
|
||||
|
||||
let (mut sender, conn) = http1::handshake::<_, HttpBody>(io).await?;
|
||||
tokio::task::spawn(async move {
|
||||
if let Err(err) = conn.await {
|
||||
println!("Connection failed: {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
let res = sender
|
||||
.send_request(req)
|
||||
.await
|
||||
.context("Failed to send repo count request")?;
|
||||
if res.status() != StatusCode::OK {
|
||||
bail!("server returned non-200 status for listRepos");
|
||||
}
|
||||
|
||||
let body = res.collect().await?.aggregate();
|
||||
let output = serde_json::from_reader::<_, list_repos::Output>(body.reader())
|
||||
.context("Failed to parse listRepos response as JSON")?;
|
||||
|
||||
Ok(output.repos.len())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(host = %host))]
|
||||
async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
|
||||
tracing::debug!("establishing connection");
|
||||
|
||||
if get_repo_count(&host).await? >= 1000 {
|
||||
bail!("too many repos! ditching from cerulea relay")
|
||||
}
|
||||
|
||||
let mut subscription = DataServerSubscription::new(server, host);
|
||||
|
||||
// TODO: load seq from db ?
|
||||
|
@ -386,6 +427,26 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()>
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
|
||||
{
|
||||
let mut active_indexers = server.active_indexers.lock().await;
|
||||
if active_indexers.contains(&host) {
|
||||
bail!("Indexer already running for host {}", &host);
|
||||
}
|
||||
|
||||
active_indexers.insert(host.clone());
|
||||
}
|
||||
|
||||
let r = host_subscription(Arc::clone(&server), host.clone()).await;
|
||||
|
||||
{
|
||||
let mut active_indexers = server.active_indexers.lock().await;
|
||||
active_indexers.remove(&host);
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
|
||||
pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
|
||||
// in future we will spider out but right now i just want da stuff from my PDS
|
||||
|
||||
|
@ -393,7 +454,7 @@ pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
|
|||
let host = host.to_string();
|
||||
let server = Arc::clone(&server);
|
||||
tokio::task::spawn(async move {
|
||||
if let Err(e) = host_subscription(server, host).await {
|
||||
if let Err(e) = index_server(server, host).await {
|
||||
tracing::warn!("encountered error subscribing to PDS: {e:?}");
|
||||
}
|
||||
});
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tokio::sync::{broadcast, mpsc, Mutex};
|
||||
use wire_proto::StreamEvent;
|
||||
|
||||
pub struct RelayServer {
|
||||
|
@ -7,6 +9,8 @@ pub struct RelayServer {
|
|||
pub db_history: sled::Tree,
|
||||
pub db_users: sled::Tree,
|
||||
|
||||
pub active_indexers: Mutex<HashSet<String>>,
|
||||
|
||||
pub event_tx: mpsc::Sender<StreamEvent>,
|
||||
pub raw_block_tx: broadcast::Sender<Bytes>,
|
||||
}
|
||||
|
@ -18,6 +22,8 @@ impl RelayServer {
|
|||
event_tx,
|
||||
raw_block_tx,
|
||||
|
||||
active_indexers: Default::default(),
|
||||
|
||||
db_history: db
|
||||
.open_tree("history")
|
||||
.expect("failed to open history tree"),
|
||||
|
@ -30,6 +36,7 @@ impl RelayServer {
|
|||
pub mod http;
|
||||
pub mod indexer;
|
||||
pub mod relay_subscription;
|
||||
pub mod request_crawl;
|
||||
pub mod sequencer;
|
||||
pub mod tls;
|
||||
pub mod user;
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use atrium_api::com::atproto::sync::request_crawl;
|
||||
use bytes::Buf;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{body::Incoming, Request, Response};
|
||||
|
||||
use crate::{
|
||||
http::{body_empty, body_full, ServerResponse},
|
||||
indexer::index_server,
|
||||
RelayServer,
|
||||
};
|
||||
|
||||
pub async fn handle_request_crawl(
|
||||
server: Arc<RelayServer>,
|
||||
req: Request<Incoming>,
|
||||
) -> Result<ServerResponse> {
|
||||
let body = req.collect().await?.aggregate();
|
||||
let input = match serde_json::from_reader::<_, request_crawl::Input>(body.reader()) {
|
||||
Ok(input) => input,
|
||||
Err(_) => {
|
||||
// TODO: surely we can build out an XRPC abstraction or something
|
||||
return Ok(Response::builder().status(400).body(body_full(
|
||||
r#"{ "error": "InvalidRequest", "message": "Failed to parse request body" }"#,
|
||||
))?);
|
||||
}
|
||||
};
|
||||
|
||||
let hostname = input.data.hostname;
|
||||
index_server(server, hostname).await?;
|
||||
|
||||
Ok(Response::builder().status(200).body(body_empty())?)
|
||||
}
|
Loading…
Reference in New Issue