From bd63797213cec5dbf137c047505c6606d4cd0c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 2 Aug 2021 10:13:34 +0200 Subject: [PATCH] improvement: make better use of sqlite connections --- src/client_server/account.rs | 22 +- src/client_server/alias.rs | 7 +- src/client_server/backup.rs | 18 +- src/client_server/config.rs | 8 +- src/client_server/device.rs | 6 +- src/client_server/directory.rs | 2 +- src/client_server/keys.rs | 8 +- src/client_server/media.rs | 2 +- src/client_server/membership.rs | 327 ++++++++++++----------- src/client_server/message.rs | 2 +- src/client_server/presence.rs | 2 +- src/client_server/profile.rs | 18 +- src/client_server/push.rs | 14 +- src/client_server/read_marker.rs | 4 +- src/client_server/redact.rs | 2 +- src/client_server/room.rs | 4 +- src/client_server/session.rs | 6 +- src/client_server/state.rs | 4 +- src/client_server/sync.rs | 10 +- src/client_server/tag.rs | 4 +- src/client_server/to_device.rs | 2 +- src/database.rs | 48 +--- src/database/abstraction.rs | 6 +- src/database/abstraction/heed.rs | 2 +- src/database/abstraction/sqlite.rs | 400 ++++++++++------------------- src/database/appservice.rs | 23 +- src/database/globals.rs | 6 +- src/database/media.rs | 10 +- src/database/rooms.rs | 4 +- src/ruma_wrapper.rs | 4 +- src/server_server.rs | 15 +- 31 files changed, 422 insertions(+), 568 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index c00cc87..ca8b7b1 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -504,7 +504,7 @@ pub async fn register_route( info!("{} registered on this server", user_id); - db.flush().await?; + db.flush()?; Ok(register::Response { access_token: Some(token), @@ -580,7 +580,7 @@ pub async fn change_password_route( } } - db.flush().await?; + db.flush()?; Ok(change_password::Response {}.into()) } @@ -656,11 +656,17 @@ pub async fn deactivate_route( } // Leave all joined rooms and reject all invitations - for room_id in db.rooms.rooms_joined(&sender_user).chain( - db.rooms - .rooms_invited(&sender_user) - .map(|t| t.map(|(r, _)| r)), - ) { + let all_rooms = db + .rooms + .rooms_joined(&sender_user) + .chain( + db.rooms + .rooms_invited(&sender_user) + .map(|t| t.map(|(r, _)| r)), + ) + .collect::>(); + + for room_id in all_rooms { let room_id = room_id?; let event = member::MemberEventContent { membership: member::MembershipState::Leave, @@ -701,7 +707,7 @@ pub async fn deactivate_route( info!("{} deactivated their account", sender_user); - db.flush().await?; + db.flush()?; Ok(deactivate::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index f5d9f64..143e607 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -31,7 +31,7 @@ pub async fn create_alias_route( db.rooms .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - db.flush().await?; + db.flush()?; Ok(create_alias::Response::new().into()) } @@ -47,7 +47,7 @@ pub async fn delete_alias_route( ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(delete_alias::Response::new().into()) } @@ -85,8 +85,7 @@ pub async fn get_alias_helper( match db.rooms.id_from_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - let iter = db.appservice.iter_all()?; - for (_id, registration) in iter.filter_map(|r| r.ok()) { + for (_id, registration) in db.appservice.all()? { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index 6d540cb..06f9818 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -26,7 +26,7 @@ pub async fn create_backup_route( .key_backups .create_backup(&sender_user, &body.algorithm, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(create_backup::Response { version }.into()) } @@ -44,7 +44,7 @@ pub async fn update_backup_route( db.key_backups .update_backup(&sender_user, &body.version, &body.algorithm, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(update_backup::Response {}.into()) } @@ -117,7 +117,7 @@ pub async fn delete_backup_route( db.key_backups.delete_backup(&sender_user, &body.version)?; - db.flush().await?; + db.flush()?; Ok(delete_backup::Response {}.into()) } @@ -147,7 +147,7 @@ pub async fn add_backup_keys_route( } } - db.flush().await?; + db.flush()?; Ok(add_backup_keys::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -179,7 +179,7 @@ pub async fn add_backup_key_sessions_route( )? } - db.flush().await?; + db.flush()?; Ok(add_backup_key_sessions::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -209,7 +209,7 @@ pub async fn add_backup_key_session_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(add_backup_key_session::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -288,7 +288,7 @@ pub async fn delete_backup_keys_route( db.key_backups .delete_all_keys(&sender_user, &body.version)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_keys::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -311,7 +311,7 @@ pub async fn delete_backup_key_sessions_route( db.key_backups .delete_room_keys(&sender_user, &body.version, &body.room_id)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_key_sessions::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -334,7 +334,7 @@ pub async fn delete_backup_key_session_route( db.key_backups .delete_room_key(&sender_user, &body.version, &body.room_id, &body.session_id)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_key_session::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), diff --git a/src/client_server/config.rs b/src/client_server/config.rs index b9826bf..b692749 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -43,7 +43,7 @@ pub async fn set_global_account_data_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_global_account_data::Response {}.into()) } @@ -78,7 +78,7 @@ pub async fn set_room_account_data_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_room_account_data::Response {}.into()) } @@ -98,7 +98,7 @@ pub async fn get_global_account_data_route( .account_data .get::>(None, sender_user, body.event_type.clone().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - db.flush().await?; + db.flush()?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -129,7 +129,7 @@ pub async fn get_room_account_data_route( body.event_type.clone().into(), )? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - db.flush().await?; + db.flush()?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 085d034..5210467 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -71,7 +71,7 @@ pub async fn update_device_route( db.users .update_device_metadata(&sender_user, &body.device_id, &device)?; - db.flush().await?; + db.flush()?; Ok(update_device::Response {}.into()) } @@ -123,7 +123,7 @@ pub async fn delete_device_route( db.users.remove_device(&sender_user, &body.device_id)?; - db.flush().await?; + db.flush()?; Ok(delete_device::Response {}.into()) } @@ -177,7 +177,7 @@ pub async fn delete_devices_route( db.users.remove_device(&sender_user, &device_id)? } - db.flush().await?; + db.flush()?; Ok(delete_devices::Response {}.into()) } diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index f1ec4b8..7cab1a7 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -100,7 +100,7 @@ pub async fn set_room_visibility_route( } } - db.flush().await?; + db.flush()?; Ok(set_room_visibility::Response {}.into()) } diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 418e41a..8db7688 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -64,7 +64,7 @@ pub async fn upload_keys_route( } } - db.flush().await?; + db.flush()?; Ok(upload_keys::Response { one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, @@ -105,7 +105,7 @@ pub async fn claim_keys_route( ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db).await?; - db.flush().await?; + db.flush()?; Ok(response.into()) } @@ -166,7 +166,7 @@ pub async fn upload_signing_keys_route( )?; } - db.flush().await?; + db.flush()?; Ok(upload_signing_keys::Response {}.into()) } @@ -227,7 +227,7 @@ pub async fn upload_signatures_route( } } - db.flush().await?; + db.flush()?; Ok(upload_signatures::Response {}.into()) } diff --git a/src/client_server/media.rs b/src/client_server/media.rs index eaaf939..2bd189a 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -52,7 +52,7 @@ pub async fn create_content_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(create_content::Response { content_uri: mxc.try_into().expect("Invalid mxc:// URI"), diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index ea7fdab..895ad27 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -74,7 +74,7 @@ pub async fn join_room_by_id_route( ) .await; - db.flush().await?; + db.flush()?; ret } @@ -125,7 +125,7 @@ pub async fn join_room_by_id_or_alias_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(join_room_by_id_or_alias::Response { room_id: join_room_response.0.room_id, @@ -146,7 +146,7 @@ pub async fn leave_room_route( db.rooms.leave_room(sender_user, &body.room_id, &db).await?; - db.flush().await?; + db.flush()?; Ok(leave_room::Response::new().into()) } @@ -164,7 +164,7 @@ pub async fn invite_user_route( if let invite_user::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; - db.flush().await?; + db.flush()?; Ok(invite_user::Response {}.into()) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -229,7 +229,7 @@ pub async fn kick_user_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(kick_user::Response::new().into()) } @@ -301,7 +301,7 @@ pub async fn ban_user_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(ban_user::Response::new().into()) } @@ -363,7 +363,7 @@ pub async fn unban_user_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(unban_user::Response::new().into()) } @@ -381,7 +381,7 @@ pub async fn forget_room_route( db.rooms.forget(&body.room_id, &sender_user)?; - db.flush().await?; + db.flush()?; Ok(forget_room::Response::new().into()) } @@ -712,7 +712,7 @@ async fn join_room_by_id_helper( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(join_room_by_id::Response::new(room_id.clone()).into()) } @@ -788,155 +788,165 @@ pub async fn invite_helper<'a>( db: &Database, is_direct: bool, ) -> Result<()> { - let mutex = Arc::clone( - db.globals - .roomid_mutex - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - if user_id.server_name() != db.globals.server_name() { - let prev_events = db - .rooms - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect::>(); - - let create_event = db - .rooms - .room_state_get(room_id, &EventType::RoomCreate, "")?; - - let create_event_content = create_event - .as_ref() - .map(|create_event| { - serde_json::from_value::>(create_event.content.clone()) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PowerLevels event in db.")) - }) - .transpose()?; - - let create_prev_event = if prev_events.len() == 1 - && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) - { - create_event - } else { - None - }; - - // If there was no create event yet, assume we are creating a version 6 room right now - let room_version_id = create_event_content - .map_or(RoomVersionId::Version6, |create_event| { - create_event.room_version - }); - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - let content = serde_json::to_value(MemberEventContent { - avatar_url: None, - displayname: None, - is_direct: Some(is_direct), - membership: MembershipState::Invite, - third_party_invite: None, - blurhash: None, - }) - .expect("member event is valid value"); - - let state_key = user_id.to_string(); - let kind = EventType::RoomMember; - - let auth_events = - db.rooms - .get_auth_events(room_id, &kind, &sender_user, Some(&state_key), &content)?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = BTreeMap::new(); - - if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + let (room_version_id, pdu_json, invite_room_state) = { + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), ); - } + let mutex_lock = mutex.lock().await; - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater"), - room_id: room_id.clone(), - sender: sender_user.clone(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind, - content, - state_key: Some(state_key), - prev_events, - depth, - auth_events: auth_events + let prev_events = db + .rooms + .get_pdu_leaves(room_id)? + .into_iter() + .take(20) + .collect::>(); + + let create_event = db + .rooms + .room_state_get(room_id, &EventType::RoomCreate, "")?; + + let create_event_content = create_event + .as_ref() + .map(|create_event| { + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PowerLevels event in db.")) + }) + .transpose()?; + + let create_prev_event = if prev_events.len() == 1 + && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) + { + create_event + } else { + None + }; + + // If there was no create event yet, assume we are creating a version 6 room right now + let room_version_id = create_event_content + .map_or(RoomVersionId::Version6, |create_event| { + create_event.room_version + }); + let room_version = + RoomVersion::new(&room_version_id).expect("room version is supported"); + + let content = serde_json::to_value(MemberEventContent { + avatar_url: None, + displayname: None, + is_direct: Some(is_direct), + membership: MembershipState::Invite, + third_party_invite: None, + blurhash: None, + }) + .expect("member event is valid value"); + + let state_key = user_id.to_string(); + let kind = EventType::RoomMember; + + let auth_events = db.rooms.get_auth_events( + room_id, + &kind, + &sender_user, + Some(&state_key), + &content, + )?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts: None, - unsigned, - hashes: ruma::events::pdu::EventHash { - sha256: "aaa".to_owned(), - }, - signatures: BTreeMap::new(), + .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = BTreeMap::new(); + + if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + + let pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater"), + room_id: room_id.clone(), + sender: sender_user.clone(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"), + kind, + content, + state_key: Some(state_key), + prev_events, + depth, + auth_events: auth_events + .iter() + .map(|(_, pdu)| pdu.event_id.clone()) + .collect(), + redacts: None, + unsigned, + hashes: ruma::events::pdu::EventHash { + sha256: "aaa".to_owned(), + }, + signatures: BTreeMap::new(), + }; + + let auth_check = state_res::auth_check( + &room_version, + &Arc::new(pdu.clone()), + create_prev_event, + &auth_events, + None, // TODO: third_party_invite + ) + .map_err(|e| { + error!("{:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Event is not authorized.", + )); + } + + // Hash and sign + let mut pdu_json = + utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(db.globals.server_name()) + .expect("server name is a valid CanonicalJsonValue"), + ); + + ruma::signatures::hash_and_sign_event( + db.globals.server_name().as_str(), + db.globals.keypair(), + &mut pdu_json, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; + + drop(mutex_lock); + + (room_version_id, pdu_json, invite_room_state) }; - let auth_check = state_res::auth_check( - &room_version, - &Arc::new(pdu.clone()), - create_prev_event, - &auth_events, - None, // TODO: third_party_invite - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(db.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - drop(mutex_lock); - - let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; let response = db .sending .send_federation_request( @@ -1008,6 +1018,17 @@ pub async fn invite_helper<'a>( return Ok(()); } + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + + let mutex_lock = mutex.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -1030,5 +1051,7 @@ pub async fn invite_helper<'a>( &mutex_lock, )?; + drop(mutex_lock); + Ok(()) } diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 3d8218c..f77ca89 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -87,7 +87,7 @@ pub async fn send_message_event_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(send_message_event::Response::new(event_id).into()) } diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index ca78a88..7312cb3 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -41,7 +41,7 @@ pub async fn set_presence_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_presence::Response {}.into()) } diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 693254f..648afea 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -32,9 +32,10 @@ pub async fn set_displayname_route( .set_displayname(&sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms - for (pdu_builder, room_id) in db - .rooms - .rooms_joined(&sender_user) + let all_rooms_joined = db.rooms.rooms_joined(&sender_user).collect::>(); + + for (pdu_builder, room_id) in all_rooms_joined + .into_iter() .filter_map(|r| r.ok()) .map(|room_id| { Ok::<_, Error>(( @@ -109,7 +110,7 @@ pub async fn set_displayname_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_display_name::Response {}.into()) } @@ -165,9 +166,10 @@ pub async fn set_avatar_url_route( db.users.set_blurhash(&sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms - for (pdu_builder, room_id) in db - .rooms - .rooms_joined(&sender_user) + let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::>(); + + for (pdu_builder, room_id) in all_joined_rooms + .into_iter() .filter_map(|r| r.ok()) .map(|room_id| { Ok::<_, Error>(( @@ -242,7 +244,7 @@ pub async fn set_avatar_url_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_avatar_url::Response {}.into()) } diff --git a/src/client_server/push.rs b/src/client_server/push.rs index 867b452..9489f07 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -192,7 +192,7 @@ pub async fn set_pushrule_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule::Response {}.into()) } @@ -248,7 +248,7 @@ pub async fn get_pushrule_actions_route( _ => None, }; - db.flush().await?; + db.flush()?; Ok(get_pushrule_actions::Response { actions: actions.unwrap_or_default(), @@ -325,7 +325,7 @@ pub async fn set_pushrule_actions_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule_actions::Response {}.into()) } @@ -386,7 +386,7 @@ pub async fn get_pushrule_enabled_route( _ => false, }; - db.flush().await?; + db.flush()?; Ok(get_pushrule_enabled::Response { enabled }.into()) } @@ -465,7 +465,7 @@ pub async fn set_pushrule_enabled_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule_enabled::Response {}.into()) } @@ -534,7 +534,7 @@ pub async fn delete_pushrule_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(delete_pushrule::Response {}.into()) } @@ -570,7 +570,7 @@ pub async fn set_pushers_route( db.pusher.set_pusher(sender_user, pusher)?; - db.flush().await?; + db.flush()?; Ok(set_pusher::Response::default().into()) } diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index f5e2924..85b0bf6 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -75,7 +75,7 @@ pub async fn set_read_marker_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_read_marker::Response {}.into()) } @@ -128,7 +128,7 @@ pub async fn create_receipt_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(create_receipt::Response {}.into()) } diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index 2e4c651..63d3d4a 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -49,7 +49,7 @@ pub async fn redact_event_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; Ok(redact_event::Response { event_id }.into()) } diff --git a/src/client_server/room.rs b/src/client_server/room.rs index d5188e8..1b14a93 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -301,7 +301,7 @@ pub async fn create_room_route( info!("{} created a room", sender_user); - db.flush().await?; + db.flush()?; Ok(create_room::Response::new(room_id).into()) } @@ -561,7 +561,7 @@ pub async fn upgrade_room_route( drop(mutex_lock); - db.flush().await?; + db.flush()?; // Return the replacement room id Ok(upgrade_room::Response { replacement_room }.into()) diff --git a/src/client_server/session.rs b/src/client_server/session.rs index f8452e0..d4d3c03 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -143,7 +143,7 @@ pub async fn login_route( info!("{} logged in", user_id); - db.flush().await?; + db.flush()?; Ok(login::Response { user_id, @@ -175,7 +175,7 @@ pub async fn logout_route( db.users.remove_device(&sender_user, sender_device)?; - db.flush().await?; + db.flush()?; Ok(logout::Response::new().into()) } @@ -204,7 +204,7 @@ pub async fn logout_all_route( db.users.remove_device(&sender_user, &device_id)?; } - db.flush().await?; + db.flush()?; Ok(logout_all::Response::new().into()) } diff --git a/src/client_server/state.rs b/src/client_server/state.rs index e0e5d29..5afac03 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -43,7 +43,7 @@ pub async fn send_state_event_for_key_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(send_state_event::Response { event_id }.into()) } @@ -69,7 +69,7 @@ pub async fn send_state_event_for_empty_key_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(send_state_event::Response { event_id }.into()) } diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 541045e..b09a212 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -186,7 +186,8 @@ async fn sync_helper( .filter_map(|r| r.ok()), ); - for room_id in db.rooms.rooms_joined(&sender_user) { + let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::>(); + for room_id in all_joined_rooms { let room_id = room_id?; // Get and drop the lock to wait for remaining operations to finish @@ -198,6 +199,7 @@ async fn sync_helper( .entry(room_id.clone()) .or_default(), ); + let mutex_lock = mutex.lock().await; drop(mutex_lock); @@ -658,7 +660,8 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - for result in db.rooms.rooms_left(&sender_user) { + let all_left_rooms = db.rooms.rooms_left(&sender_user).collect::>(); + for result in all_left_rooms { let (room_id, left_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish @@ -697,7 +700,8 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - for result in db.rooms.rooms_invited(&sender_user) { + let all_invited_rooms = db.rooms.rooms_invited(&sender_user).collect::>(); + for result in all_invited_rooms { let (room_id, invite_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 223d122..5582bcd 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -40,7 +40,7 @@ pub async fn update_tag_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(create_tag::Response {}.into()) } @@ -74,7 +74,7 @@ pub async fn delete_tag_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(delete_tag::Response {}.into()) } diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index d3f7d25..69147c9 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -95,7 +95,7 @@ pub async fn send_event_to_device_route( db.transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - db.flush().await?; + db.flush()?; Ok(send_event_to_device::Response {}.into()) } diff --git a/src/database.rs b/src/database.rs index baf66b5..5b47302 100644 --- a/src/database.rs +++ b/src/database.rs @@ -45,14 +45,8 @@ pub struct Config { database_path: String, #[serde(default = "default_db_cache_capacity_mb")] db_cache_capacity_mb: f64, - #[serde(default = "default_sqlite_read_pool_size")] - sqlite_read_pool_size: usize, #[serde(default = "default_sqlite_wal_clean_second_interval")] sqlite_wal_clean_second_interval: u32, - #[serde(default = "default_sqlite_spillover_reap_fraction")] - sqlite_spillover_reap_fraction: f64, - #[serde(default = "default_sqlite_spillover_reap_interval_secs")] - sqlite_spillover_reap_interval_secs: u32, #[serde(default = "default_max_request_size")] max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] @@ -111,22 +105,10 @@ fn default_db_cache_capacity_mb() -> f64 { 200.0 } -fn default_sqlite_read_pool_size() -> usize { - num_cpus::get().max(1) -} - fn default_sqlite_wal_clean_second_interval() -> u32 { 15 * 60 // every 15 minutes } -fn default_sqlite_spillover_reap_fraction() -> f64 { - 0.5 -} - -fn default_sqlite_spillover_reap_interval_secs() -> u32 { - 60 -} - fn default_max_request_size() -> u32 { 20 * 1024 * 1024 // Default to 20 MB } @@ -458,7 +440,6 @@ impl Database { #[cfg(feature = "sqlite")] { Self::start_wal_clean_task(Arc::clone(&db), &config).await; - Self::start_spillover_reap_task(builder, &config).await; } Ok(db) @@ -568,7 +549,7 @@ impl Database { } #[tracing::instrument(skip(self))] - pub async fn flush(&self) -> Result<()> { + pub fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); let res = self._db.flush(); @@ -584,33 +565,6 @@ impl Database { self._db.flush_wal() } - #[cfg(feature = "sqlite")] - #[tracing::instrument(skip(engine, config))] - pub async fn start_spillover_reap_task(engine: Arc, config: &Config) { - let fraction = config.sqlite_spillover_reap_fraction.clamp(0.01, 1.0); - let interval_secs = config.sqlite_spillover_reap_interval_secs as u64; - - let weak = Arc::downgrade(&engine); - - tokio::spawn(async move { - use tokio::time::interval; - - use std::{sync::Weak, time::Duration}; - - let mut i = interval(Duration::from_secs(interval_secs)); - - loop { - i.tick().await; - - if let Some(arc) = Weak::upgrade(&weak) { - arc.reap_spillover_by_fraction(fraction); - } else { - break; - } - } - }); - } - #[cfg(feature = "sqlite")] #[tracing::instrument(skip(db, config))] pub async fn start_wal_clean_task(db: Arc>, config: &Config) { diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 8ccac78..d0fa780 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -28,20 +28,20 @@ pub trait Tree: Send + Sync { fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + Send + 'a>; + ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + Send + 'a>; + ) -> Box, Vec)> + 'a>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; diff --git a/src/database/abstraction/heed.rs b/src/database/abstraction/heed.rs index 0421b14..e767e22 100644 --- a/src/database/abstraction/heed.rs +++ b/src/database/abstraction/heed.rs @@ -81,7 +81,7 @@ impl EngineTree { let (s, r) = bounded::(100); let engine = Arc::clone(&self.engine); - let lock = self.engine.iter_pool.lock().unwrap(); + let lock = self.engine.iter_pool.lock().await; if lock.active_count() < lock.max_count() { lock.execute(move || { iter_from_thread_work(tree, &engine.env.read_txn().unwrap(), from, backwards, &s); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index bbf7508..d2ecb3a 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,133 +1,61 @@ use super::{DatabaseEngine, Tree}; use crate::{database::Config, Result}; -use crossbeam::channel::{ - bounded, unbounded, Receiver as ChannelReceiver, Sender as ChannelSender, TryRecvError, -}; use parking_lot::{Mutex, MutexGuard, RwLock}; -use rusqlite::{Connection, DatabaseName::Main, OptionalExtension, Params}; +use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ + cell::RefCell, collections::HashMap, future::Future, - ops::Deref, path::{Path, PathBuf}, pin::Pin, sync::Arc, time::{Duration, Instant}, }; -use threadpool::ThreadPool; use tokio::sync::oneshot::Sender; use tracing::{debug, warn}; -struct Pool { - writer: Mutex, - readers: Vec>, - spills: ConnectionRecycler, - spill_tracker: Arc<()>, - path: PathBuf, -} - pub const MILLI: Duration = Duration::from_millis(1); -enum HoldingConn<'a> { - FromGuard(MutexGuard<'a, Connection>), - FromRecycled(RecycledConn, Arc<()>), +thread_local! { + static READ_CONNECTION: RefCell> = RefCell::new(None); } -impl<'a> Deref for HoldingConn<'a> { - type Target = Connection; +struct PreparedStatementIterator<'a> { + pub iterator: Box + 'a>, + pub statement_ref: NonAliasingBox>, +} - fn deref(&self) -> &Self::Target { - match self { - HoldingConn::FromGuard(guard) => guard.deref(), - HoldingConn::FromRecycled(conn, _) => conn.deref(), - } +impl Iterator for PreparedStatementIterator<'_> { + type Item = TupleOfBytes; + + fn next(&mut self) -> Option { + self.iterator.next() } } -struct ConnectionRecycler(ChannelSender, ChannelReceiver); - -impl ConnectionRecycler { - fn new() -> Self { - let (s, r) = unbounded(); - Self(s, r) - } - - fn recycle(&self, conn: Connection) -> RecycledConn { - let sender = self.0.clone(); - - RecycledConn(Some(conn), sender) - } - - fn try_take(&self) -> Option { - match self.1.try_recv() { - Ok(conn) => Some(conn), - Err(TryRecvError::Empty) => None, - // as this is pretty impossible, a panic is warranted if it ever occurs - Err(TryRecvError::Disconnected) => panic!("Receiving channel was disconnected. A a sender is owned by the current struct, this should never happen(!!!)") - } - } -} - -struct RecycledConn( - Option, // To allow moving out of the struct when `Drop` is called. - ChannelSender, -); - -impl Deref for RecycledConn { - type Target = Connection; - - fn deref(&self) -> &Self::Target { - self.0 - .as_ref() - .expect("RecycledConn does not have a connection in Option<>") - } -} - -impl Drop for RecycledConn { +struct NonAliasingBox(*mut T); +impl Drop for NonAliasingBox { fn drop(&mut self) { - if let Some(conn) = self.0.take() { - debug!("Recycled connection"); - if let Err(e) = self.1.send(conn) { - warn!("Recycling a connection led to the following error: {:?}", e) - } - } + unsafe { Box::from_raw(self.0) }; } } -impl Pool { - fn new>(path: P, num_readers: usize, total_cache_size_mb: f64) -> Result { - // calculates cache-size per permanent connection - // 1. convert MB to KiB - // 2. divide by permanent connections - // 3. round down to nearest integer - let cache_size: u32 = ((total_cache_size_mb * 1024.0) / (num_readers + 1) as f64) as u32; +pub struct Engine { + writer: Mutex, - let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); + path: PathBuf, + cache_size_per_thread: u32, +} - let mut readers = Vec::new(); - - for _ in 0..num_readers { - readers.push(Mutex::new(Self::prepare_conn(&path, Some(cache_size))?)) - } - - Ok(Self { - writer, - readers, - spills: ConnectionRecycler::new(), - spill_tracker: Arc::new(()), - path: path.as_ref().to_path_buf(), - }) - } - - fn prepare_conn>(path: P, cache_size: Option) -> Result { - let conn = Connection::open(path)?; +impl Engine { + fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { + let conn = Connection::open(&path)?; + conn.pragma_update(Some(Main), "page_size", &32768)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; - - if let Some(cache_kib) = cache_size { - conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_kib)))?; - } + conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; Ok(conn) } @@ -136,68 +64,52 @@ impl Pool { self.writer.lock() } - fn read_lock(&self) -> HoldingConn<'_> { - // First try to get a connection from the permanent pool - for r in &self.readers { - if let Some(reader) = r.try_lock() { - return HoldingConn::FromGuard(reader); + fn read_lock(&self) -> &'static Connection { + READ_CONNECTION.with(|cell| { + let connection = &mut cell.borrow_mut(); + + if (*connection).is_none() { + let c = Box::leak(Box::new( + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), + )); + **connection = Some(c); } - } - debug!("read_lock: All permanent readers locked, obtaining spillover reader..."); - - // We didn't get a connection from the permanent pool, so we'll dumpster-dive for recycled connections. - // Either we have a connection or we dont, if we don't, we make a new one. - let conn = match self.spills.try_take() { - Some(conn) => conn, - None => { - debug!("read_lock: No recycled connections left, creating new one..."); - Self::prepare_conn(&self.path, None).unwrap() - } - }; - - // Clone the spill Arc to mark how many spilled connections actually exist. - let spill_arc = Arc::clone(&self.spill_tracker); - - // Get a sense of how many connections exist now. - let now_count = Arc::strong_count(&spill_arc) - 1 /* because one is held by the pool */; - - // If the spillover readers are more than the number of total readers, there might be a problem. - if now_count > self.readers.len() { - warn!( - "Database is under high load. Consider increasing sqlite_read_pool_size ({} spillover readers exist)", - now_count - ); - } - - // Return the recyclable connection. - HoldingConn::FromRecycled(self.spills.recycle(conn), spill_arc) + connection.unwrap() + }) } -} -pub struct Engine { - pool: Pool, - iter_pool: Mutex, + pub fn flush_wal(self: &Arc) -> Result<()> { + self.write_lock() + .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; + Ok(()) + } } impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { - let pool = Pool::new( - Path::new(&config.database_path).join("conduit.db"), - config.sqlite_read_pool_size, - config.db_cache_capacity_mb, - )?; + let path = Path::new(&config.database_path).join("conduit.db"); + + // calculates cache-size per permanent connection + // 1. convert MB to KiB + // 2. divide by permanent connections + // 3. round down to nearest integer + let cache_size_per_thread: u32 = + ((config.db_cache_capacity_mb * 1024.0) / (num_cpus::get().max(1) + 1) as f64) as u32; + + let writer = Mutex::new(Self::prepare_conn(&path, cache_size_per_thread)?); let arc = Arc::new(Engine { - pool, - iter_pool: Mutex::new(ThreadPool::new(10)), + writer, + path, + cache_size_per_thread, }); Ok(arc) } fn open_tree(self: &Arc, name: &str) -> Result> { - self.pool.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; + self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), @@ -212,31 +124,6 @@ impl DatabaseEngine for Engine { } } -impl Engine { - pub fn flush_wal(self: &Arc) -> Result<()> { - self.pool.write_lock().pragma_update(Some(Main), "wal_checkpoint", &"RESTART")?; - Ok(()) - } - - // Reaps (at most) (.len() * `fraction`) (rounded down, min 1) connections. - pub fn reap_spillover_by_fraction(&self, fraction: f64) { - let mut reaped = 0; - - let spill_amount = self.pool.spills.1.len() as f64; - let fraction = fraction.clamp(0.01, 1.0); - - let amount = (spill_amount * fraction).max(1.0) as u32; - - for _ in 0..amount { - if self.pool.spills.try_take().is_some() { - reaped += 1; - } - } - - debug!("Reaped {} connections", reaped); - } -} - pub struct SqliteTable { engine: Arc, name: String, @@ -258,7 +145,7 @@ impl SqliteTable { fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { guard.execute( format!( - "INSERT INTO {} (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name ) .as_str(), @@ -266,70 +153,17 @@ impl SqliteTable { )?; Ok(()) } - - #[tracing::instrument(skip(self, sql, param))] - fn iter_from_thread( - &self, - sql: String, - param: Option>, - ) -> Box + Send + Sync> { - let (s, r) = bounded::(5); - - let engine = Arc::clone(&self.engine); - - let lock = self.engine.iter_pool.lock(); - if lock.active_count() < lock.max_count() { - lock.execute(move || { - if let Some(param) = param { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); - } else { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); - } - }); - } else { - std::thread::spawn(move || { - if let Some(param) = param { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); - } else { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); - } - }); - } - - Box::new(r.into_iter()) - } -} - -fn iter_from_thread_work

( - guard: &HoldingConn<'_>, - s: &ChannelSender<(Vec, Vec)>, - sql: &str, - params: P, -) where - P: Params, -{ - for bob in guard - .prepare(sql) - .unwrap() - .query_map(params, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(|r| r.unwrap()) - { - if s.send(bob).is_err() { - return; - } - } } impl Tree for SqliteTable { #[tracing::instrument(skip(self, key))] fn get(&self, key: &[u8]) -> Result>> { - self.get_with_guard(&self.engine.pool.read_lock(), key) + self.get_with_guard(&self.engine.read_lock(), key) } #[tracing::instrument(skip(self, key, value))] fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -337,7 +171,7 @@ impl Tree for SqliteTable { let elapsed = start.elapsed(); if elapsed > MILLI { - debug!("insert: took {:012?} : {}", elapsed, &self.name); + warn!("insert took {:?} : {}", elapsed, &self.name); } drop(guard); @@ -369,7 +203,7 @@ impl Tree for SqliteTable { #[tracing::instrument(skip(self, key))] fn remove(&self, key: &[u8]) -> Result<()> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -389,9 +223,28 @@ impl Tree for SqliteTable { } #[tracing::instrument(skip(self))] - fn iter<'a>(&'a self) -> Box + Send + 'a> { - let name = self.name.clone(); - self.iter_from_thread(format!("SELECT key, value FROM {}", name), None) + fn iter<'a>(&'a self) -> Box + 'a> { + let guard = self.engine.read_lock(); + + let statement = Box::leak(Box::new( + guard + .prepare(&format!("SELECT key, value FROM {}", &self.name)) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } #[tracing::instrument(skip(self, from, backwards))] @@ -399,31 +252,61 @@ impl Tree for SqliteTable { &'a self, from: &[u8], backwards: bool, - ) -> Box + Send + 'a> { - let name = self.name.clone(); + ) -> Box + 'a> { + let guard = self.engine.read_lock(); let from = from.to_vec(); // TODO change interface? + if backwards { - self.iter_from_thread( - format!( - "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", - name - ), - Some(from), - ) + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } else { - self.iter_from_thread( - format!( - "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", - name - ), - Some(from), - ) + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } } #[tracing::instrument(skip(self, key))] fn increment(&self, key: &[u8]) -> Result> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -445,10 +328,7 @@ impl Tree for SqliteTable { } #[tracing::instrument(skip(self, prefix))] - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box + Send + 'a> { + fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { // let name = self.name.clone(); // self.iter_from_thread( // format!( @@ -483,25 +363,9 @@ impl Tree for SqliteTable { fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine - .pool .write_lock() .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } } - -// TODO -// struct Pool { -// writer: Mutex, -// readers: [Mutex; NUM_READERS], -// } - -// // then, to pick a reader: -// for r in &pool.readers { -// if let Ok(reader) = r.try_lock() { -// // use reader -// } -// } -// // none unlocked, pick the next reader -// pool.readers[pool.counter.fetch_add(1, Relaxed) % NUM_READERS].lock() diff --git a/src/database/appservice.rs b/src/database/appservice.rs index f39520c..7cc9137 100644 --- a/src/database/appservice.rs +++ b/src/database/appservice.rs @@ -49,22 +49,23 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + Send + '_> { + pub fn iter_ids(&self) -> Result> + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) })) } - pub fn iter_all( - &self, - ) -> Result> + '_ + Send> { - Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { - Ok(( - id.clone(), - self.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - })) + pub fn all(&self) -> Result> { + self.iter_ids()? + .filter_map(|id| id.ok()) + .map(move |id| { + Ok(( + id.clone(), + self.get_registration(&id)? + .expect("iter_ids only returns appservices that exist"), + )) + }) + .collect() } } diff --git a/src/database/globals.rs b/src/database/globals.rs index 0edb9ca..2ca8de9 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -15,7 +15,7 @@ use std::{ sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::{broadcast, watch::Receiver, Mutex, Semaphore}; +use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; @@ -45,8 +45,8 @@ pub struct Globals { pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, pub sync_receivers: RwLock), SyncHandle>>, - pub roomid_mutex: RwLock>>>, - pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub roomid_mutex: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer pub rotate: RotationHandler, } diff --git a/src/database/media.rs b/src/database/media.rs index f576ca4..a9bb42b 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -101,8 +101,8 @@ impl Media { prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail prefix.push(0xff); - let mut iter = self.mediaid_file.scan_prefix(prefix); - if let Some((key, _)) = iter.next() { + let first = self.mediaid_file.scan_prefix(prefix).next(); + if let Some((key, _)) = first { let path = globals.get_media_file(&key); let mut file = Vec::new(); File::open(path).await?.read_to_end(&mut file).await?; @@ -190,7 +190,9 @@ impl Media { original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.push(0xff); - if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { + let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next(); + let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next(); + if let Some((key, _)) = first_thumbnailprefix { // Using saved thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); @@ -225,7 +227,7 @@ impl Media { content_type, file: file.to_vec(), })) - } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { + } else if let Some((key, _)) = first_originalprefix { // Generate a thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 79bb059..c3148c2 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -2,7 +2,6 @@ mod edus; pub use edus::RoomEdus; use member::MembershipState; -use tokio::sync::MutexGuard; use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; use lru_cache::LruCache; @@ -28,6 +27,7 @@ use std::{ mem, sync::{Arc, Mutex}, }; +use tokio::sync::MutexGuard; use tracing::{debug, error, warn}; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -1496,7 +1496,7 @@ impl Rooms { db.sending.send_pdu(&server, &pdu_id)?; } - for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { + for appservice in db.appservice.all()? { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 2121439..5681194 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -75,9 +75,9 @@ where registration, )) = db .appservice - .iter_all() + .all() .unwrap() - .filter_map(|r| r.ok()) + .iter() .find(|(_id, registration)| { registration .get("as_token") diff --git a/src/server_server.rs b/src/server_server.rs index 232c5d4..09b6bfc 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -806,7 +806,7 @@ pub async fn send_transaction_message_route( } } - db.flush().await?; + db.flush()?; Ok(send_transaction_message::v1::Response { pdus: resolved_map }.into()) } @@ -1343,7 +1343,6 @@ pub fn handle_incoming_pdu<'a>( &state_at_incoming_event, &mutex_lock, ) - .await .map_err(|_| "Failed to add pdu to db.".to_owned())?, ); debug!("Appended incoming pdu."); @@ -1643,7 +1642,7 @@ pub(crate) async fn fetch_signing_keys( /// Append the incoming event setting the state snapshot to the state from the /// server that sent the event. #[tracing::instrument(skip(db, pdu, pdu_json, new_room_leaves, state, _mutex_lock))] -async fn append_incoming_pdu( +fn append_incoming_pdu( db: &Database, pdu: &PduEvent, pdu_json: CanonicalJsonObject, @@ -1663,7 +1662,7 @@ async fn append_incoming_pdu( &db, )?; - for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { + for appservice in db.appservice.all()? { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") @@ -2208,7 +2207,7 @@ pub async fn create_join_event_route( db.sending.send_pdu(&server, &pdu_id)?; } - db.flush().await?; + db.flush()?; Ok(create_join_event::v2::Response { room_state: RoomState { @@ -2327,7 +2326,7 @@ pub async fn create_invite_route( )?; } - db.flush().await?; + db.flush()?; Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), @@ -2464,7 +2463,7 @@ pub async fn get_keys_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(get_keys::v1::Response { device_keys: result.device_keys, @@ -2489,7 +2488,7 @@ pub async fn claim_keys_route( let result = claim_keys_helper(&body.one_time_keys, &db).await?; - db.flush().await?; + db.flush()?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys,