'RelayServer' => 'AppState'

Charlotte Som 2025-01-15 05:56:24 +00:00
parent 3204162282
commit d4efcecad4
9 changed files with 69 additions and 72 deletions

View file

@@ -6,12 +6,12 @@ use hyper::{body::Incoming, Request};
use ipld_core::ipld::Ipld;
use serde_ipld_dagcbor::DecodeError;
-use crate::{http::ServerResponse, wire_proto::StreamEventHeader, RelayServer};
+use crate::{http::ServerResponse, wire_proto::StreamEventHeader, AppState};
-pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> {
+pub fn purge_did(app: &AppState, did: &str) -> Result<()> {
// drop commits
-for event in server.db_history.iter() {
+for event in app.db_history.iter() {
let (seq, event) = event?;
let mut cursor = Cursor::new(&event);
let (header_buf, payload_buf) =
@@ -37,7 +37,7 @@ pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> {
if let Some(event_did) = event_did {
if event_did.as_str() == did {
-let _ = server.db_history.remove(seq);
+let _ = app.db_history.remove(seq);
}
}
}
@@ -48,7 +48,7 @@ pub fn purge_did(server: &RelayServer, did: &str) -> Result<()> {
// TODO: ban host
pub async fn handle_purge_did(
-_server: Arc<RelayServer>,
+_app: Arc<AppState>,
_req: Request<Incoming>,
) -> Result<ServerResponse> {
// TODO:

View file

@@ -6,7 +6,7 @@ use tokio::sync::{broadcast, mpsc, Mutex};
use crate::wire_proto::StreamEvent;
-pub struct RelayServer {
+pub struct AppState {
pub db: sled::Db,
pub db_history: sled::Tree,
pub db_users: sled::Tree,
@@ -21,7 +21,7 @@ pub struct RelayServer {
pub raw_block_tx: broadcast::Sender<Bytes>,
}
-impl RelayServer {
+impl AppState {
pub fn new(db: sled::Db, event_tx: mpsc::Sender<StreamEvent>) -> Self {
let (raw_block_tx, _) = broadcast::channel(128);

View file

@@ -13,7 +13,7 @@ use tokio::net::TcpListener;
use crate::{
admin::handle_purge_did,
relay::{request_crawl::handle_request_crawl, subscribe::handle_subscription},
-RelayServer,
+AppState,
};
pub type HttpBody = BoxBody<Bytes, hyper::Error>;
@@ -26,7 +26,7 @@ pub fn body_full<T: Into<Bytes>>(chunk: T) -> HttpBody {
pub type ServerResponse = Response<BoxBody<Bytes, hyper::Error>>;
-async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<ServerResponse> {
+async fn serve(app: Arc<AppState>, req: Request<Incoming>) -> Result<ServerResponse> {
let path = req.uri().path();
tracing::debug!("{}", path);
@@ -38,13 +38,13 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve
.body(body_full("cerulea relay running..."))?),
(&Method::GET, "/xrpc/com.atproto.sync.subscribeRepos") => {
-handle_subscription(server, req).await
+handle_subscription(app, req).await
}
(&Method::POST, "/xrpc/com.atproto.sync.requestCrawl") => {
-handle_request_crawl(server, req).await
+handle_request_crawl(app, req).await
}
-(&Method::POST, "/api/admin/purge-did") => handle_purge_did(server, req).await,
+(&Method::POST, "/api/admin/purge-did") => handle_purge_did(app, req).await,
_ => Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
@@ -53,7 +53,7 @@ async fn serve(server: Arc<RelayServer>, req: Request<Incoming>) -> Result<Serve
}
}
-pub async fn listen(server: Arc<RelayServer>, addr: SocketAddr) -> Result<()> {
+pub async fn listen(app: Arc<AppState>, addr: SocketAddr) -> Result<()> {
tracing::info!("Listening on: http://{addr}/ ...");
let listener = TcpListener::bind(addr).await?;
@@ -61,7 +61,7 @@ pub async fn listen(server: Arc<RelayServer>, addr: SocketAddr) -> Result<()> {
loop {
let (stream, _client_addr) = listener.accept().await?;
let io = TokioIo::new(stream);
-let server = Arc::clone(&server);
+let server = Arc::clone(&app);
tokio::task::spawn(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service_fn(move |req| serve(Arc::clone(&server), req)))

View file

@@ -8,7 +8,7 @@ use cerulea_relay::{
http::{self},
relay::index::index_servers,
sequencer::start_sequencer,
-RelayServer,
+AppState,
};
#[derive(Parser, Debug)]
@@ -52,7 +52,7 @@ async fn main() -> Result<()> {
let (event_tx, event_rx) = mpsc::channel(128);
-let mut server = RelayServer::new(db, event_tx);
+let mut server = AppState::new(db, event_tx);
if let Some(plc_directory) = args.plc_resolver {
server.plc_resolver = Cow::Owned(plc_directory);
}

View file

@@ -18,7 +18,7 @@ use crate::{
tls::open_tls_stream,
user::{fetch_user, lookup_user},
wire_proto::{StreamEvent, StreamEventHeader, StreamEventPayload},
-RelayServer,
+AppState,
};
async fn create_ws_client(
@@ -53,19 +53,19 @@ async fn create_ws_client(
}
struct DataServerSubscription {
-server: Arc<RelayServer>,
+app: Arc<AppState>,
host: String,
event_tx: mpsc::Sender<StreamEvent>,
last_seq: Option<i64>,
}
impl DataServerSubscription {
-fn new(server: Arc<RelayServer>, host: String) -> Self {
+fn new(app: Arc<AppState>, host: String) -> Self {
Self {
host,
-event_tx: server.event_tx.clone(),
+event_tx: app.event_tx.clone(),
last_seq: None,
-server,
+app,
}
}
@@ -83,7 +83,7 @@ impl DataServerSubscription {
}
self.last_seq = Some(event.seq);
-let mut user = lookup_user(&self.server, &event.repo).await?;
+let mut user = lookup_user(&self.app, &event.repo).await?;
let pds = user.pds.as_deref().unwrap_or_default();
if pds != self.host {
@@ -94,7 +94,7 @@
);
// re-fetch user (without cache)
-user = fetch_user(&self.server, &event.repo).await?;
+user = fetch_user(&self.app, &event.repo).await?;
let fresh_pds = user.pds.as_deref().unwrap_or_default();
if fresh_pds != self.host {
bail!(
@@ -139,7 +139,7 @@ impl DataServerSubscription {
}
self.last_seq = Some(event.seq);
-let user = fetch_user(&self.server, &event.did).await?;
+let user = fetch_user(&self.app, &event.did).await?;
if user.handle.as_deref() != Some(event.handle.as_str()) {
tracing::warn!(
seq = %event.seq,
@@ -167,7 +167,7 @@ impl DataServerSubscription {
self.last_seq = Some(event.seq);
if let Some(handle) = event.handle.as_ref() {
-let user = fetch_user(&self.server, &event.did).await?;
+let user = fetch_user(&self.app, &event.did).await?;
if user.handle.as_deref() != Some(handle.as_str()) {
tracing::warn!(
seq = %event.seq,
@@ -195,7 +195,7 @@ impl DataServerSubscription {
}
self.last_seq = Some(event.seq);
-let user = fetch_user(&self.server, &event.did).await?;
+let user = fetch_user(&self.app, &event.did).await?;
let pds = user.pds.as_deref().unwrap_or_default();
if pds != self.host {
bail!(
@@ -250,7 +250,7 @@ impl DataServerSubscription {
}
self.last_seq = Some(event.seq);
-let user = lookup_user(&self.server, &event.did).await?;
+let user = lookup_user(&self.app, &event.did).await?;
let pds = user.pds.as_deref().unwrap_or_default();
if pds != self.host {
bail!(
@@ -329,7 +329,7 @@ impl DataServerSubscription {
}
fn load_cursor(&mut self) -> Result<()> {
-if let Some(saved_cursor) = self.server.db_index_cursors.get(&self.host)? {
+if let Some(saved_cursor) = self.app.db_index_cursors.get(&self.host)? {
let mut cur_buf = [0u8; 8];
let len = 8.min(saved_cursor.len());
cur_buf[..len].copy_from_slice(&saved_cursor[..len]);
@@ -341,7 +341,7 @@ impl DataServerSubscription {
fn save_cursor(&self) -> Result<()> {
if let Some(cur) = self.last_seq {
-self.server
+self.app
.db_index_cursors
.insert(&self.host, &i64::to_be_bytes(cur))?;
}
@@ -387,15 +387,15 @@ async fn get_repo_count(host: &str) -> Result<usize> {
}
#[tracing::instrument(skip_all, fields(host = %host))]
-async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()> {
+async fn host_subscription(app: Arc<AppState>, host: String) -> Result<()> {
tracing::debug!("establishing connection");
if get_repo_count(&host).await? >= 1000 {
bail!("too many repos! ditching from cerulea relay")
}
-let _ = server.add_good_host(host.clone()).await;
-let mut subscription = DataServerSubscription::new(server, host);
+let _ = app.add_good_host(host.clone()).await;
+let mut subscription = DataServerSubscription::new(app, host);
subscription.load_cursor()?;
tracing::debug!(seq = ?subscription.last_seq, "starting subscription");
@@ -456,8 +456,8 @@ async fn host_subscription(server: Arc<RelayServer>, host: String) -> Result<()>
Ok(())
}
-pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()> {
-if server.is_banned_host(&host)? {
+pub async fn index_server(app: Arc<AppState>, host: String) -> Result<()> {
+if app.is_banned_host(&host)? {
bail!("refusing to start indexer for banned host '{}'", &host);
}
@@ -465,7 +465,7 @@ pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()>
// instead of just having a BTreeSet<String> :)
{
-let mut active_indexers = server.active_indexers.lock().await;
+let mut active_indexers = app.active_indexers.lock().await;
if active_indexers.contains(&host) {
bail!("Indexer already running for host");
}
@@ -473,30 +473,30 @@ pub async fn index_server(server: Arc<RelayServer>, host: String) -> Result<()>
active_indexers.insert(host.clone());
}
-let r = host_subscription(Arc::clone(&server), host.clone()).await;
+let r = host_subscription(Arc::clone(&app), host.clone()).await;
{
-let mut active_indexers = server.active_indexers.lock().await;
+let mut active_indexers = app.active_indexers.lock().await;
active_indexers.remove(&host);
}
r
}
-pub fn start_indexing_server(server: Arc<RelayServer>, host: String) {
+pub fn start_indexing_server(app: Arc<AppState>, host: String) {
tokio::task::spawn(async move {
-if let Err(e) = index_server(server, host.clone()).await {
+if let Err(e) = index_server(app, host.clone()).await {
tracing::warn!(%host, "encountered error subscribing to PDS: {e:?}");
}
});
}
-pub fn index_servers(server: Arc<RelayServer>, hosts: &[String]) {
+pub fn index_servers(app: Arc<AppState>, hosts: &[String]) {
// in future we will spider out but right now i just want da stuff from my PDS
for host in hosts.iter() {
let host = host.to_string();
-let server = Arc::clone(&server);
+let server = Arc::clone(&app);
start_indexing_server(server, host);
}
}

View file

@@ -9,11 +9,11 @@ use hyper::{body::Incoming, Request, Response};
use crate::{
http::{body_full, ServerResponse},
relay::index::start_indexing_server,
-RelayServer,
+AppState,
};
pub async fn handle_request_crawl(
-server: Arc<RelayServer>,
+app: Arc<AppState>,
req: Request<Incoming>,
) -> Result<ServerResponse> {
let body = req.collect().await?.aggregate();
@@ -31,7 +31,7 @@ pub async fn handle_request_crawl(
};
let hostname = input.data.hostname;
-start_indexing_server(server, hostname);
+start_indexing_server(app, hostname);
Ok(Response::builder()
.status(200)

View file

@@ -1,4 +1,4 @@
-use crate::RelayServer;
+use crate::AppState;
use std::sync::Arc;
@@ -120,7 +120,7 @@ async fn rebroadcast_block<'f>(block_rx: &mut Receiver<Bytes>) -> Operation<'f>
}
async fn run_subscription(
-server: Arc<RelayServer>,
+app: Arc<AppState>,
req: Request<Incoming>,
ws: WebSocket<TokioIo<Upgraded>>,
) {
@@ -139,7 +139,7 @@ async fn run_subscription(
let cursor_bytes = cursor.to_be_bytes();
let mut count = 0;
-for event in server.db_history.range(cursor_bytes..) {
+for event in app.db_history.range(cursor_bytes..) {
let (_seq, event) = match event {
Ok(ev) => ev,
Err(e) => {
@@ -163,7 +163,7 @@
}
// live tailing:
-let mut raw_block_rx = server.raw_block_tx.subscribe();
+let mut raw_block_rx = app.raw_block_tx.subscribe();
while sub.running {
let op = tokio::select! {
biased;
@@ -177,7 +177,7 @@
}
pub async fn handle_subscription(
-server: Arc<RelayServer>,
+app: Arc<AppState>,
mut req: Request<Incoming>,
) -> Result<ServerResponse> {
if !is_upgrade_request(&req) {
@@ -196,7 +196,7 @@ pub async fn handle_subscription(
}
};
-run_subscription(server, req, ws).await;
+run_subscription(app, req, ws).await;
});
let (head, _) = res.into_parts();

View file

@@ -6,14 +6,14 @@ use tokio::sync::mpsc;
use crate::{
wire_proto::{StreamEvent, StreamEventPayload},
-RelayServer,
+AppState,
};
async fn run_sequencer(
-server: Arc<RelayServer>,
+app: Arc<AppState>,
mut event_rx: mpsc::Receiver<StreamEvent>,
) -> Result<()> {
-let mut curr_seq = server
+let mut curr_seq = app
.db
.get(b"history_last_seq")?
.map(|v| {
@@ -28,8 +28,7 @@ async fn run_sequencer(
while let Some((header, payload)) = event_rx.recv().await {
curr_seq += 1;
-server
-.db
+app.db
.insert(b"history_last_seq", &u128::to_le_bytes(curr_seq))?;
/* if matches!(
@@ -69,18 +68,16 @@ async fn run_sequencer(
}
let data = Bytes::from(cursor.into_inner());
-let _ = server.raw_block_tx.send(data.clone());
-server
-.db_history
-.insert(u128::to_be_bytes(curr_seq), &*data)?;
+let _ = app.raw_block_tx.send(data.clone());
+app.db_history.insert(u128::to_be_bytes(curr_seq), &*data)?;
}
Ok(())
}
-pub fn start_sequencer(server: Arc<RelayServer>, event_rx: mpsc::Receiver<StreamEvent>) {
+pub fn start_sequencer(app: Arc<AppState>, event_rx: mpsc::Receiver<StreamEvent>) {
tokio::task::spawn(async move {
-if let Err(e) = run_sequencer(server, event_rx).await {
+if let Err(e) = run_sequencer(app, event_rx).await {
tracing::error!("sequencer error: {e:?}");
}
});

View file

@@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use crate::{
http::{body_empty, HttpBody},
tls::open_tls_stream,
-RelayServer,
+AppState,
};
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -76,9 +76,9 @@ fn create_user_from_did_doc(did_doc: DidDocument) -> Result<User> {
Ok(user)
}
-pub async fn fetch_did_doc(server: &RelayServer, did: &str) -> Result<DidDocument> {
+pub async fn fetch_did_doc(app: &AppState, did: &str) -> Result<DidDocument> {
if did.starts_with("did:plc:") {
-let domain: &str = &server.plc_resolver;
+let domain: &str = &app.plc_resolver;
let tcp_stream = TcpStream::connect((domain, 443)).await?;
let domain_tls: ServerName<'_> = ServerName::try_from(domain.to_string())?;
@@ -148,26 +148,26 @@ pub async fn fetch_did_doc(server: &RelayServer, did: &str) -> Result<DidDocumen
}
}
-pub async fn fetch_user(server: &RelayServer, did: &str) -> Result<User> {
+pub async fn fetch_user(app: &AppState, did: &str) -> Result<User> {
tracing::debug!(%did, "fetching user");
-let did_doc = fetch_did_doc(server, did).await?;
+let did_doc = fetch_did_doc(app, did).await?;
let user = create_user_from_did_doc(did_doc)?;
-store_user(server, &user)?;
+store_user(app, &user)?;
Ok(user)
}
-pub async fn lookup_user(server: &RelayServer, did: &str) -> Result<User> {
-if let Some(cached_user) = server.db_users.get(did)? {
+pub async fn lookup_user(app: &AppState, did: &str) -> Result<User> {
+if let Some(cached_user) = app.db_users.get(did)? {
let cached_user = serde_ipld_dagcbor::from_slice::<User>(&cached_user)?;
return Ok(cached_user);
}
-fetch_user(server, did).await
+fetch_user(app, did).await
}
-pub fn store_user(server: &RelayServer, user: &User) -> Result<()> {
+pub fn store_user(app: &AppState, user: &User) -> Result<()> {
let data = serde_ipld_dagcbor::to_vec(&user)?;
-server.db_users.insert(&user.did, data)?;
+app.db_users.insert(&user.did, data)?;
Ok(())
}