diff --git a/rust-toolchain b/rust-toolchain index ba0a719..a63cb35 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.51.0 +1.52.0 diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 7f38eb1..0fc8b28 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -156,20 +156,18 @@ pub async fn register_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa.create( + &UserId::parse_with_server_name("", db.globals.server_name()) + .expect("we know this is valid"), + "".into(), + &uiaainfo, + &json, + )?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa.create( - &UserId::parse_with_server_name("", db.globals.server_name()) - .expect("we know this is valid"), - "".into(), - &uiaainfo, - &json, - )?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } } @@ -529,15 +527,13 @@ pub async fn change_password_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa + .create(&sender_user, &sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa - .create(&sender_user, &sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } db.users @@ -621,15 +617,13 @@ pub async fn deactivate_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa + .create(&sender_user, &sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa - .create(&sender_user, &sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } // Leave all joined rooms and reject all invitations diff --git a/src/client_server/device.rs b/src/client_server/device.rs index a10d788..44b9c32 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -112,15 +112,13 @@ pub async fn delete_device_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa + .create(&sender_user, &sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa - .create(&sender_user, &sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } db.users.remove_device(&sender_user, &body.device_id)?; @@ -166,15 +164,13 @@ pub async fn delete_devices_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa + .create(&sender_user, &sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa - .create(&sender_user, &sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } for device_id in &body.devices { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 621e5dd..8eee408 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -141,15 +141,13 @@ pub async fn upload_signing_keys_route( return Err(Error::Uiaa(uiaainfo)); } // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + db.uiaa + .create(&sender_user, &sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); } else { - if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa - .create(&sender_user, &sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } if let Some(master_key) = &body.master_key { diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index 14b85a6..a09d527 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -25,20 +25,22 @@ pub async fn search_users_route( avatar_url: db.users.avatar_url(&user_id).ok()?, }; - if !user + let user_id_matches = user .user_id .to_string() .to_lowercase() - .contains(&body.search_term.to_lowercase()) - && user - .display_name - .as_ref() - .filter(|name| { - name.to_lowercase() - .contains(&body.search_term.to_lowercase()) - }) - .is_none() - { + .contains(&body.search_term.to_lowercase()); + + let user_displayname_matches = user + .display_name + .as_ref() + .filter(|name| { + name.to_lowercase() + .contains(&body.search_term.to_lowercase()) + }) + .is_some(); + + if !user_id_matches && !user_displayname_matches { return None; } diff --git a/src/database.rs b/src/database.rs index 14ce4f0..b32f539 100644 --- a/src/database.rs +++ b/src/database.rs @@ -368,7 +368,7 @@ impl Database { if db.globals.database_version()? < 3 { // Move media to filesystem for (key, content) in db.media.mediaid_file.iter() { - if content.len() == 0 { + if content.is_empty() { continue; } @@ -614,8 +614,8 @@ impl<'r> FromRequest<'r> for DatabaseGuard { } } -impl Into for OwnedRwLockReadGuard { - fn into(self) -> DatabaseGuard { - DatabaseGuard(self) +impl From> for DatabaseGuard { + fn from(val: OwnedRwLockReadGuard) -> Self { + Self(val) } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 22a5559..25d236a 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -121,7 +121,7 @@ impl Pool { let spilled = Self::prepare_conn(&self.path, None).unwrap(); - return HoldingConn::FromOwned(spilled, spill_arc); + HoldingConn::FromOwned(spilled, spill_arc) } } @@ -250,16 +250,7 @@ macro_rules! iter_from_thread { impl Tree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { - let guard = self.engine.pool.read_lock(); - - // let start = Instant::now(); - - let val = self.get_with_guard(&guard, key); - - // debug!("get: took {:?}", start.elapsed()); - // debug!("get key: {:?}", &key) - - val + self.get_with_guard(&self.engine.pool.read_lock(), key) } fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { diff --git a/src/database/globals.rs b/src/database/globals.rs index 4242cf5..307ec40 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -16,7 +16,7 @@ use std::{ sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::{broadcast, Semaphore}; +use tokio::sync::{broadcast, watch::Receiver, Semaphore}; use trust_dns_resolver::TokioAsyncResolver; use super::abstraction::Tree; @@ -26,6 +26,11 @@ pub const COUNTER: &[u8] = b"c"; type WellKnownMap = HashMap, (String, String)>; type TlsNameMap = HashMap; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries +type SyncHandle = ( + Option, // since + Receiver>>, // rx +); + pub struct Globals { pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, @@ -39,15 +44,7 @@ pub struct Globals { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, - pub sync_receivers: RwLock< - BTreeMap< - (UserId, Box), - ( - Option, - tokio::sync::watch::Receiver>>, - ), // since, rx - >, - >, + pub sync_receivers: RwLock), SyncHandle>>, pub rotate: RotationHandler, } @@ -109,6 +106,12 @@ impl RotationHandler { } } +impl Default for RotationHandler { + fn default() -> Self { + Self::new() + } +} + impl Globals { pub fn load( globals: Arc, diff --git a/src/database/media.rs b/src/database/media.rs index a1fe26e..f576ca4 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -54,6 +54,7 @@ impl Media { } /// Uploads or replaces a file thumbnail. + #[allow(clippy::too_many_arguments)] pub async fn upload_thumbnail( &self, mxc: String, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 7b64c46..4d66f9f 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -533,17 +533,15 @@ impl Rooms { r }, |pduid| { - let r = Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { Error::bad_database("Invalid pduid in eventid_pduid.") - })?)); - r + })?)) }, )? .map(|pdu| { - let r = serde_json::from_slice(&pdu) + serde_json::from_slice(&pdu) .map_err(|_| Error::bad_database("Invalid PDU in db.")) - .map(Arc::new); - r + .map(Arc::new) }) .transpose()? { @@ -1112,7 +1110,7 @@ impl Rooms { } }; - new_state.insert(shortstatekey, shorteventid.into()); + new_state.insert(shortstatekey, shorteventid); let new_state_hash = self.calculate_hash( &new_state diff --git a/src/utils.rs b/src/utils.rs index b8ce303..a4dfe03 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -137,7 +137,7 @@ pub fn deserialize_from_str< where E: serde::de::Error, { - v.parse().map_err(|e| serde::de::Error::custom(e)) + v.parse().map_err(serde::de::Error::custom) } } deserializer.deserialize_str(Visitor(std::marker::PhantomData))