diff --git a/src/database.rs b/src/database.rs index b32f539..5a896a8 100644 --- a/src/database.rs +++ b/src/database.rs @@ -23,7 +23,7 @@ use rocket::{ futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}, outcome::{try_outcome, IntoOutcome}, request::{FromRequest, Request}, - State, + Shutdown, State, }; use ruma::{DeviceId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; @@ -199,7 +199,7 @@ impl Database { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result>> { + pub async fn load_or_create(config: &Config) -> Result>> { Self::check_sled_or_sqlite_db(&config)?; let builder = Engine::open(&config)?; @@ -425,6 +425,17 @@ impl Database { Ok(db) } + #[cfg(feature = "conduit_bin")] + pub async fn start_on_shutdown_tasks(db: Arc>, shutdown: Shutdown) { + tokio::spawn(async move { + shutdown.await; + + log::info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); + + db.read().await.globals.rotate.fire(); + }); + } + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); diff --git a/src/main.rs b/src/main.rs index e0d2e3d..324a3ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -220,11 +220,17 @@ async fn main() { config.warn_deprecated(); - let db = Database::load_or_create(config) + let db = Database::load_or_create(&config) .await .expect("config is valid"); - let rocket = setup_rocket(raw_config, db); + let rocket = setup_rocket(raw_config, Arc::clone(&db)) + .ignite() + .await + .unwrap(); + + Database::start_on_shutdown_tasks(db, rocket.shutdown()).await; + rocket.launch().await.unwrap(); }