improvement: make better use of sqlite connections
This commit is contained in:
		
							parent
							
								
									2c4f966d60
								
							
						
					
					
						commit
						bd63797213
					
				
					 31 changed files with 422 additions and 568 deletions
				
			
		|  | @ -504,7 +504,7 @@ pub async fn register_route( | |||
| 
 | ||||
|     info!("{} registered on this server", user_id); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(register::Response { | ||||
|         access_token: Some(token), | ||||
|  | @ -580,7 +580,7 @@ pub async fn change_password_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(change_password::Response {}.into()) | ||||
| } | ||||
|  | @ -656,11 +656,17 @@ pub async fn deactivate_route( | |||
|     } | ||||
| 
 | ||||
|     // Leave all joined rooms and reject all invitations
 | ||||
|     for room_id in db.rooms.rooms_joined(&sender_user).chain( | ||||
|     let all_rooms = db | ||||
|         .rooms | ||||
|         .rooms_joined(&sender_user) | ||||
|         .chain( | ||||
|             db.rooms | ||||
|                 .rooms_invited(&sender_user) | ||||
|                 .map(|t| t.map(|(r, _)| r)), | ||||
|     ) { | ||||
|         ) | ||||
|         .collect::<Vec<_>>(); | ||||
| 
 | ||||
|     for room_id in all_rooms { | ||||
|         let room_id = room_id?; | ||||
|         let event = member::MemberEventContent { | ||||
|             membership: member::MembershipState::Leave, | ||||
|  | @ -701,7 +707,7 @@ pub async fn deactivate_route( | |||
| 
 | ||||
|     info!("{} deactivated their account", sender_user); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(deactivate::Response { | ||||
|         id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ pub async fn create_alias_route( | |||
|     db.rooms | ||||
|         .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_alias::Response::new().into()) | ||||
| } | ||||
|  | @ -47,7 +47,7 @@ pub async fn delete_alias_route( | |||
| ) -> ConduitResult<delete_alias::Response> { | ||||
|     db.rooms.set_alias(&body.room_alias, None, &db.globals)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_alias::Response::new().into()) | ||||
| } | ||||
|  | @ -85,8 +85,7 @@ pub async fn get_alias_helper( | |||
|     match db.rooms.id_from_alias(&room_alias)? { | ||||
|         Some(r) => room_id = Some(r), | ||||
|         None => { | ||||
|             let iter = db.appservice.iter_all()?; | ||||
|             for (_id, registration) in iter.filter_map(|r| r.ok()) { | ||||
|             for (_id, registration) in db.appservice.all()? { | ||||
|                 let aliases = registration | ||||
|                     .get("namespaces") | ||||
|                     .and_then(|ns| ns.get("aliases")) | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ pub async fn create_backup_route( | |||
|         .key_backups | ||||
|         .create_backup(&sender_user, &body.algorithm, &db.globals)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_backup::Response { version }.into()) | ||||
| } | ||||
|  | @ -44,7 +44,7 @@ pub async fn update_backup_route( | |||
|     db.key_backups | ||||
|         .update_backup(&sender_user, &body.version, &body.algorithm, &db.globals)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(update_backup::Response {}.into()) | ||||
| } | ||||
|  | @ -117,7 +117,7 @@ pub async fn delete_backup_route( | |||
| 
 | ||||
|     db.key_backups.delete_backup(&sender_user, &body.version)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_backup::Response {}.into()) | ||||
| } | ||||
|  | @ -147,7 +147,7 @@ pub async fn add_backup_keys_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(add_backup_keys::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  | @ -179,7 +179,7 @@ pub async fn add_backup_key_sessions_route( | |||
|         )? | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(add_backup_key_sessions::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  | @ -209,7 +209,7 @@ pub async fn add_backup_key_session_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(add_backup_key_session::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  | @ -288,7 +288,7 @@ pub async fn delete_backup_keys_route( | |||
|     db.key_backups | ||||
|         .delete_all_keys(&sender_user, &body.version)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_backup_keys::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  | @ -311,7 +311,7 @@ pub async fn delete_backup_key_sessions_route( | |||
|     db.key_backups | ||||
|         .delete_room_keys(&sender_user, &body.version, &body.room_id)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_backup_key_sessions::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  | @ -334,7 +334,7 @@ pub async fn delete_backup_key_session_route( | |||
|     db.key_backups | ||||
|         .delete_room_key(&sender_user, &body.version, &body.room_id, &body.session_id)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_backup_key_session::Response { | ||||
|         count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), | ||||
|  |  | |||
|  | @ -43,7 +43,7 @@ pub async fn set_global_account_data_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_global_account_data::Response {}.into()) | ||||
| } | ||||
|  | @ -78,7 +78,7 @@ pub async fn set_room_account_data_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_room_account_data::Response {}.into()) | ||||
| } | ||||
|  | @ -98,7 +98,7 @@ pub async fn get_global_account_data_route( | |||
|         .account_data | ||||
|         .get::<Box<RawJsonValue>>(None, sender_user, body.event_type.clone().into())? | ||||
|         .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get()) | ||||
|         .map_err(|_| Error::bad_database("Invalid account data event in db."))? | ||||
|  | @ -129,7 +129,7 @@ pub async fn get_room_account_data_route( | |||
|             body.event_type.clone().into(), | ||||
|         )? | ||||
|         .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get()) | ||||
|         .map_err(|_| Error::bad_database("Invalid account data event in db."))? | ||||
|  |  | |||
|  | @ -71,7 +71,7 @@ pub async fn update_device_route( | |||
|     db.users | ||||
|         .update_device_metadata(&sender_user, &body.device_id, &device)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(update_device::Response {}.into()) | ||||
| } | ||||
|  | @ -123,7 +123,7 @@ pub async fn delete_device_route( | |||
| 
 | ||||
|     db.users.remove_device(&sender_user, &body.device_id)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_device::Response {}.into()) | ||||
| } | ||||
|  | @ -177,7 +177,7 @@ pub async fn delete_devices_route( | |||
|         db.users.remove_device(&sender_user, &device_id)? | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_devices::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -100,7 +100,7 @@ pub async fn set_room_visibility_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_room_visibility::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -64,7 +64,7 @@ pub async fn upload_keys_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(upload_keys::Response { | ||||
|         one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, | ||||
|  | @ -105,7 +105,7 @@ pub async fn claim_keys_route( | |||
| ) -> ConduitResult<claim_keys::Response> { | ||||
|     let response = claim_keys_helper(&body.one_time_keys, &db).await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(response.into()) | ||||
| } | ||||
|  | @ -166,7 +166,7 @@ pub async fn upload_signing_keys_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(upload_signing_keys::Response {}.into()) | ||||
| } | ||||
|  | @ -227,7 +227,7 @@ pub async fn upload_signatures_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(upload_signatures::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -52,7 +52,7 @@ pub async fn create_content_route( | |||
|         ) | ||||
|         .await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_content::Response { | ||||
|         content_uri: mxc.try_into().expect("Invalid mxc:// URI"), | ||||
|  |  | |||
|  | @ -74,7 +74,7 @@ pub async fn join_room_by_id_route( | |||
|     ) | ||||
|     .await; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     ret | ||||
| } | ||||
|  | @ -125,7 +125,7 @@ pub async fn join_room_by_id_or_alias_route( | |||
|     ) | ||||
|     .await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(join_room_by_id_or_alias::Response { | ||||
|         room_id: join_room_response.0.room_id, | ||||
|  | @ -146,7 +146,7 @@ pub async fn leave_room_route( | |||
| 
 | ||||
|     db.rooms.leave_room(sender_user, &body.room_id, &db).await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(leave_room::Response::new().into()) | ||||
| } | ||||
|  | @ -164,7 +164,7 @@ pub async fn invite_user_route( | |||
| 
 | ||||
|     if let invite_user::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { | ||||
|         invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; | ||||
|         db.flush().await?; | ||||
|         db.flush()?; | ||||
|         Ok(invite_user::Response {}.into()) | ||||
|     } else { | ||||
|         Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) | ||||
|  | @ -229,7 +229,7 @@ pub async fn kick_user_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(kick_user::Response::new().into()) | ||||
| } | ||||
|  | @ -301,7 +301,7 @@ pub async fn ban_user_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(ban_user::Response::new().into()) | ||||
| } | ||||
|  | @ -363,7 +363,7 @@ pub async fn unban_user_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(unban_user::Response::new().into()) | ||||
| } | ||||
|  | @ -381,7 +381,7 @@ pub async fn forget_room_route( | |||
| 
 | ||||
|     db.rooms.forget(&body.room_id, &sender_user)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(forget_room::Response::new().into()) | ||||
| } | ||||
|  | @ -712,7 +712,7 @@ async fn join_room_by_id_helper( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(join_room_by_id::Response::new(room_id.clone()).into()) | ||||
| } | ||||
|  | @ -788,6 +788,8 @@ pub async fn invite_helper<'a>( | |||
|     db: &Database, | ||||
|     is_direct: bool, | ||||
| ) -> Result<()> { | ||||
|     if user_id.server_name() != db.globals.server_name() { | ||||
|         let (room_version_id, pdu_json, invite_room_state) = { | ||||
|             let mutex = Arc::clone( | ||||
|                 db.globals | ||||
|                     .roomid_mutex | ||||
|  | @ -798,7 +800,6 @@ pub async fn invite_helper<'a>( | |||
|             ); | ||||
|             let mutex_lock = mutex.lock().await; | ||||
| 
 | ||||
|     if user_id.server_name() != db.globals.server_name() { | ||||
|             let prev_events = db | ||||
|                 .rooms | ||||
|                 .get_pdu_leaves(room_id)? | ||||
|  | @ -833,7 +834,8 @@ pub async fn invite_helper<'a>( | |||
|                 .map_or(RoomVersionId::Version6, |create_event| { | ||||
|                     create_event.room_version | ||||
|                 }); | ||||
|         let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); | ||||
|             let room_version = | ||||
|                 RoomVersion::new(&room_version_id).expect("room version is supported"); | ||||
| 
 | ||||
|             let content = serde_json::to_value(MemberEventContent { | ||||
|                 avatar_url: None, | ||||
|  | @ -848,9 +850,13 @@ pub async fn invite_helper<'a>( | |||
|             let state_key = user_id.to_string(); | ||||
|             let kind = EventType::RoomMember; | ||||
| 
 | ||||
|         let auth_events = | ||||
|             db.rooms | ||||
|                 .get_auth_events(room_id, &kind, &sender_user, Some(&state_key), &content)?; | ||||
|             let auth_events = db.rooms.get_auth_events( | ||||
|                 room_id, | ||||
|                 &kind, | ||||
|                 &sender_user, | ||||
|                 Some(&state_key), | ||||
|                 &content, | ||||
|             )?; | ||||
| 
 | ||||
|             // Our depth is the maximum depth of prev_events + 1
 | ||||
|             let depth = prev_events | ||||
|  | @ -934,9 +940,13 @@ pub async fn invite_helper<'a>( | |||
|             ) | ||||
|             .expect("event is valid, we just created it"); | ||||
| 
 | ||||
|             let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; | ||||
| 
 | ||||
|             drop(mutex_lock); | ||||
| 
 | ||||
|         let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; | ||||
|             (room_version_id, pdu_json, invite_room_state) | ||||
|         }; | ||||
| 
 | ||||
|         let response = db | ||||
|             .sending | ||||
|             .send_federation_request( | ||||
|  | @ -1008,6 +1018,17 @@ pub async fn invite_helper<'a>( | |||
|         return Ok(()); | ||||
|     } | ||||
| 
 | ||||
|     let mutex = Arc::clone( | ||||
|         db.globals | ||||
|             .roomid_mutex | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .entry(room_id.clone()) | ||||
|             .or_default(), | ||||
|     ); | ||||
| 
 | ||||
|     let mutex_lock = mutex.lock().await; | ||||
| 
 | ||||
|     db.rooms.build_and_append_pdu( | ||||
|         PduBuilder { | ||||
|             event_type: EventType::RoomMember, | ||||
|  | @ -1030,5 +1051,7 @@ pub async fn invite_helper<'a>( | |||
|         &mutex_lock, | ||||
|     )?; | ||||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
|  |  | |||
|  | @ -87,7 +87,7 @@ pub async fn send_message_event_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(send_message_event::Response::new(event_id).into()) | ||||
| } | ||||
|  |  | |||
|  | @ -41,7 +41,7 @@ pub async fn set_presence_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_presence::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -32,9 +32,10 @@ pub async fn set_displayname_route( | |||
|         .set_displayname(&sender_user, body.displayname.clone())?; | ||||
| 
 | ||||
|     // Send a new membership event and presence update into all joined rooms
 | ||||
|     for (pdu_builder, room_id) in db | ||||
|         .rooms | ||||
|         .rooms_joined(&sender_user) | ||||
|     let all_rooms_joined = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); | ||||
| 
 | ||||
|     for (pdu_builder, room_id) in all_rooms_joined | ||||
|         .into_iter() | ||||
|         .filter_map(|r| r.ok()) | ||||
|         .map(|room_id| { | ||||
|             Ok::<_, Error>(( | ||||
|  | @ -109,7 +110,7 @@ pub async fn set_displayname_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_display_name::Response {}.into()) | ||||
| } | ||||
|  | @ -165,9 +166,10 @@ pub async fn set_avatar_url_route( | |||
|     db.users.set_blurhash(&sender_user, body.blurhash.clone())?; | ||||
| 
 | ||||
|     // Send a new membership event and presence update into all joined rooms
 | ||||
|     for (pdu_builder, room_id) in db | ||||
|         .rooms | ||||
|         .rooms_joined(&sender_user) | ||||
|     let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); | ||||
| 
 | ||||
|     for (pdu_builder, room_id) in all_joined_rooms | ||||
|         .into_iter() | ||||
|         .filter_map(|r| r.ok()) | ||||
|         .map(|room_id| { | ||||
|             Ok::<_, Error>(( | ||||
|  | @ -242,7 +244,7 @@ pub async fn set_avatar_url_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_avatar_url::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -192,7 +192,7 @@ pub async fn set_pushrule_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_pushrule::Response {}.into()) | ||||
| } | ||||
|  | @ -248,7 +248,7 @@ pub async fn get_pushrule_actions_route( | |||
|         _ => None, | ||||
|     }; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(get_pushrule_actions::Response { | ||||
|         actions: actions.unwrap_or_default(), | ||||
|  | @ -325,7 +325,7 @@ pub async fn set_pushrule_actions_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_pushrule_actions::Response {}.into()) | ||||
| } | ||||
|  | @ -386,7 +386,7 @@ pub async fn get_pushrule_enabled_route( | |||
|         _ => false, | ||||
|     }; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(get_pushrule_enabled::Response { enabled }.into()) | ||||
| } | ||||
|  | @ -465,7 +465,7 @@ pub async fn set_pushrule_enabled_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_pushrule_enabled::Response {}.into()) | ||||
| } | ||||
|  | @ -534,7 +534,7 @@ pub async fn delete_pushrule_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_pushrule::Response {}.into()) | ||||
| } | ||||
|  | @ -570,7 +570,7 @@ pub async fn set_pushers_route( | |||
| 
 | ||||
|     db.pusher.set_pusher(sender_user, pusher)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_pusher::Response::default().into()) | ||||
| } | ||||
|  |  | |||
|  | @ -75,7 +75,7 @@ pub async fn set_read_marker_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(set_read_marker::Response {}.into()) | ||||
| } | ||||
|  | @ -128,7 +128,7 @@ pub async fn create_receipt_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_receipt::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -49,7 +49,7 @@ pub async fn redact_event_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(redact_event::Response { event_id }.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -301,7 +301,7 @@ pub async fn create_room_route( | |||
| 
 | ||||
|     info!("{} created a room", sender_user); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_room::Response::new(room_id).into()) | ||||
| } | ||||
|  | @ -561,7 +561,7 @@ pub async fn upgrade_room_route( | |||
| 
 | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     // Return the replacement room id
 | ||||
|     Ok(upgrade_room::Response { replacement_room }.into()) | ||||
|  |  | |||
|  | @ -143,7 +143,7 @@ pub async fn login_route( | |||
| 
 | ||||
|     info!("{} logged in", user_id); | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(login::Response { | ||||
|         user_id, | ||||
|  | @ -175,7 +175,7 @@ pub async fn logout_route( | |||
| 
 | ||||
|     db.users.remove_device(&sender_user, sender_device)?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(logout::Response::new().into()) | ||||
| } | ||||
|  | @ -204,7 +204,7 @@ pub async fn logout_all_route( | |||
|         db.users.remove_device(&sender_user, &device_id)?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(logout_all::Response::new().into()) | ||||
| } | ||||
|  |  | |||
|  | @ -43,7 +43,7 @@ pub async fn send_state_event_for_key_route( | |||
|     ) | ||||
|     .await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(send_state_event::Response { event_id }.into()) | ||||
| } | ||||
|  | @ -69,7 +69,7 @@ pub async fn send_state_event_for_empty_key_route( | |||
|     ) | ||||
|     .await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(send_state_event::Response { event_id }.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -186,7 +186,8 @@ async fn sync_helper( | |||
|             .filter_map(|r| r.ok()), | ||||
|     ); | ||||
| 
 | ||||
|     for room_id in db.rooms.rooms_joined(&sender_user) { | ||||
|     let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); | ||||
|     for room_id in all_joined_rooms { | ||||
|         let room_id = room_id?; | ||||
| 
 | ||||
|         // Get and drop the lock to wait for remaining operations to finish
 | ||||
|  | @ -198,6 +199,7 @@ async fn sync_helper( | |||
|                 .entry(room_id.clone()) | ||||
|                 .or_default(), | ||||
|         ); | ||||
| 
 | ||||
|         let mutex_lock = mutex.lock().await; | ||||
|         drop(mutex_lock); | ||||
| 
 | ||||
|  | @ -658,7 +660,8 @@ async fn sync_helper( | |||
|     } | ||||
| 
 | ||||
|     let mut left_rooms = BTreeMap::new(); | ||||
|     for result in db.rooms.rooms_left(&sender_user) { | ||||
|     let all_left_rooms = db.rooms.rooms_left(&sender_user).collect::<Vec<_>>(); | ||||
|     for result in all_left_rooms { | ||||
|         let (room_id, left_state_events) = result?; | ||||
| 
 | ||||
|         // Get and drop the lock to wait for remaining operations to finish
 | ||||
|  | @ -697,7 +700,8 @@ async fn sync_helper( | |||
|     } | ||||
| 
 | ||||
|     let mut invited_rooms = BTreeMap::new(); | ||||
|     for result in db.rooms.rooms_invited(&sender_user) { | ||||
|     let all_invited_rooms = db.rooms.rooms_invited(&sender_user).collect::<Vec<_>>(); | ||||
|     for result in all_invited_rooms { | ||||
|         let (room_id, invite_state_events) = result?; | ||||
| 
 | ||||
|         // Get and drop the lock to wait for remaining operations to finish
 | ||||
|  |  | |||
|  | @ -40,7 +40,7 @@ pub async fn update_tag_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_tag::Response {}.into()) | ||||
| } | ||||
|  | @ -74,7 +74,7 @@ pub async fn delete_tag_route( | |||
|         &db.globals, | ||||
|     )?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(delete_tag::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -95,7 +95,7 @@ pub async fn send_event_to_device_route( | |||
|     db.transaction_ids | ||||
|         .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(send_event_to_device::Response {}.into()) | ||||
| } | ||||
|  |  | |||
|  | @ -45,14 +45,8 @@ pub struct Config { | |||
|     database_path: String, | ||||
|     #[serde(default = "default_db_cache_capacity_mb")] | ||||
|     db_cache_capacity_mb: f64, | ||||
|     #[serde(default = "default_sqlite_read_pool_size")] | ||||
|     sqlite_read_pool_size: usize, | ||||
|     #[serde(default = "default_sqlite_wal_clean_second_interval")] | ||||
|     sqlite_wal_clean_second_interval: u32, | ||||
|     #[serde(default = "default_sqlite_spillover_reap_fraction")] | ||||
|     sqlite_spillover_reap_fraction: f64, | ||||
|     #[serde(default = "default_sqlite_spillover_reap_interval_secs")] | ||||
|     sqlite_spillover_reap_interval_secs: u32, | ||||
|     #[serde(default = "default_max_request_size")] | ||||
|     max_request_size: u32, | ||||
|     #[serde(default = "default_max_concurrent_requests")] | ||||
|  | @ -111,22 +105,10 @@ fn default_db_cache_capacity_mb() -> f64 { | |||
|     200.0 | ||||
| } | ||||
| 
 | ||||
| fn default_sqlite_read_pool_size() -> usize { | ||||
|     num_cpus::get().max(1) | ||||
| } | ||||
| 
 | ||||
| fn default_sqlite_wal_clean_second_interval() -> u32 { | ||||
|     15 * 60 // every 15 minutes
 | ||||
| } | ||||
| 
 | ||||
| fn default_sqlite_spillover_reap_fraction() -> f64 { | ||||
|     0.5 | ||||
| } | ||||
| 
 | ||||
| fn default_sqlite_spillover_reap_interval_secs() -> u32 { | ||||
|     60 | ||||
| } | ||||
| 
 | ||||
| fn default_max_request_size() -> u32 { | ||||
|     20 * 1024 * 1024 // Default to 20 MB
 | ||||
| } | ||||
|  | @ -458,7 +440,6 @@ impl Database { | |||
|         #[cfg(feature = "sqlite")] | ||||
|         { | ||||
|             Self::start_wal_clean_task(Arc::clone(&db), &config).await; | ||||
|             Self::start_spillover_reap_task(builder, &config).await; | ||||
|         } | ||||
| 
 | ||||
|         Ok(db) | ||||
|  | @ -568,7 +549,7 @@ impl Database { | |||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub async fn flush(&self) -> Result<()> { | ||||
|     pub fn flush(&self) -> Result<()> { | ||||
|         let start = std::time::Instant::now(); | ||||
| 
 | ||||
|         let res = self._db.flush(); | ||||
|  | @ -584,33 +565,6 @@ impl Database { | |||
|         self._db.flush_wal() | ||||
|     } | ||||
| 
 | ||||
|     #[cfg(feature = "sqlite")] | ||||
|     #[tracing::instrument(skip(engine, config))] | ||||
|     pub async fn start_spillover_reap_task(engine: Arc<Engine>, config: &Config) { | ||||
|         let fraction = config.sqlite_spillover_reap_fraction.clamp(0.01, 1.0); | ||||
|         let interval_secs = config.sqlite_spillover_reap_interval_secs as u64; | ||||
| 
 | ||||
|         let weak = Arc::downgrade(&engine); | ||||
| 
 | ||||
|         tokio::spawn(async move { | ||||
|             use tokio::time::interval; | ||||
| 
 | ||||
|             use std::{sync::Weak, time::Duration}; | ||||
| 
 | ||||
|             let mut i = interval(Duration::from_secs(interval_secs)); | ||||
| 
 | ||||
|             loop { | ||||
|                 i.tick().await; | ||||
| 
 | ||||
|                 if let Some(arc) = Weak::upgrade(&weak) { | ||||
|                     arc.reap_spillover_by_fraction(fraction); | ||||
|                 } else { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     #[cfg(feature = "sqlite")] | ||||
|     #[tracing::instrument(skip(db, config))] | ||||
|     pub async fn start_wal_clean_task(db: Arc<TokioRwLock<Self>>, config: &Config) { | ||||
|  |  | |||
|  | @ -28,20 +28,20 @@ pub trait Tree: Send + Sync { | |||
| 
 | ||||
|     fn remove(&self, key: &[u8]) -> Result<()>; | ||||
| 
 | ||||
|     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send + 'a>; | ||||
|     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; | ||||
| 
 | ||||
|     fn iter_from<'a>( | ||||
|         &'a self, | ||||
|         from: &[u8], | ||||
|         backwards: bool, | ||||
|     ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send + 'a>; | ||||
|     ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; | ||||
| 
 | ||||
|     fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; | ||||
| 
 | ||||
|     fn scan_prefix<'a>( | ||||
|         &'a self, | ||||
|         prefix: Vec<u8>, | ||||
|     ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + Send + 'a>; | ||||
|     ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>; | ||||
| 
 | ||||
|     fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; | ||||
| 
 | ||||
|  |  | |||
|  | @ -81,7 +81,7 @@ impl EngineTree { | |||
|         let (s, r) = bounded::<TupleOfBytes>(100); | ||||
|         let engine = Arc::clone(&self.engine); | ||||
| 
 | ||||
|         let lock = self.engine.iter_pool.lock().unwrap(); | ||||
|         let lock = self.engine.iter_pool.lock().await; | ||||
|         if lock.active_count() < lock.max_count() { | ||||
|             lock.execute(move || { | ||||
|                 iter_from_thread_work(tree, &engine.env.read_txn().unwrap(), from, backwards, &s); | ||||
|  |  | |||
|  | @ -1,133 +1,61 @@ | |||
| use super::{DatabaseEngine, Tree}; | ||||
| use crate::{database::Config, Result}; | ||||
| use crossbeam::channel::{ | ||||
|     bounded, unbounded, Receiver as ChannelReceiver, Sender as ChannelSender, TryRecvError, | ||||
| }; | ||||
| use parking_lot::{Mutex, MutexGuard, RwLock}; | ||||
| use rusqlite::{Connection, DatabaseName::Main, OptionalExtension, Params}; | ||||
| use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; | ||||
| use std::{ | ||||
|     cell::RefCell, | ||||
|     collections::HashMap, | ||||
|     future::Future, | ||||
|     ops::Deref, | ||||
|     path::{Path, PathBuf}, | ||||
|     pin::Pin, | ||||
|     sync::Arc, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use threadpool::ThreadPool; | ||||
| use tokio::sync::oneshot::Sender; | ||||
| use tracing::{debug, warn}; | ||||
| 
 | ||||
| struct Pool { | ||||
|     writer: Mutex<Connection>, | ||||
|     readers: Vec<Mutex<Connection>>, | ||||
|     spills: ConnectionRecycler, | ||||
|     spill_tracker: Arc<()>, | ||||
|     path: PathBuf, | ||||
| } | ||||
| 
 | ||||
| pub const MILLI: Duration = Duration::from_millis(1); | ||||
| 
 | ||||
| enum HoldingConn<'a> { | ||||
|     FromGuard(MutexGuard<'a, Connection>), | ||||
|     FromRecycled(RecycledConn, Arc<()>), | ||||
| thread_local! { | ||||
|     static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); | ||||
| } | ||||
| 
 | ||||
| impl<'a> Deref for HoldingConn<'a> { | ||||
|     type Target = Connection; | ||||
| 
 | ||||
|     fn deref(&self) -> &Self::Target { | ||||
|         match self { | ||||
|             HoldingConn::FromGuard(guard) => guard.deref(), | ||||
|             HoldingConn::FromRecycled(conn, _) => conn.deref(), | ||||
|         } | ||||
|     } | ||||
| struct PreparedStatementIterator<'a> { | ||||
|     pub iterator: Box<dyn Iterator<Item = TupleOfBytes> + 'a>, | ||||
|     pub statement_ref: NonAliasingBox<rusqlite::Statement<'a>>, | ||||
| } | ||||
| 
 | ||||
| struct ConnectionRecycler(ChannelSender<Connection>, ChannelReceiver<Connection>); | ||||
| impl Iterator for PreparedStatementIterator<'_> { | ||||
|     type Item = TupleOfBytes; | ||||
| 
 | ||||
| impl ConnectionRecycler { | ||||
|     fn new() -> Self { | ||||
|         let (s, r) = unbounded(); | ||||
|         Self(s, r) | ||||
|     } | ||||
| 
 | ||||
|     fn recycle(&self, conn: Connection) -> RecycledConn { | ||||
|         let sender = self.0.clone(); | ||||
| 
 | ||||
|         RecycledConn(Some(conn), sender) | ||||
|     } | ||||
| 
 | ||||
|     fn try_take(&self) -> Option<Connection> { | ||||
|         match self.1.try_recv() { | ||||
|             Ok(conn) => Some(conn), | ||||
|             Err(TryRecvError::Empty) => None, | ||||
|             // as this is pretty impossible, a panic is warranted if it ever occurs
 | ||||
|             Err(TryRecvError::Disconnected) => panic!("Receiving channel was disconnected. A a sender is owned by the current struct, this should never happen(!!!)") | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| struct RecycledConn( | ||||
|     Option<Connection>, // To allow moving out of the struct when `Drop` is called.
 | ||||
|     ChannelSender<Connection>, | ||||
| ); | ||||
| 
 | ||||
| impl Deref for RecycledConn { | ||||
|     type Target = Connection; | ||||
| 
 | ||||
|     fn deref(&self) -> &Self::Target { | ||||
|         self.0 | ||||
|             .as_ref() | ||||
|             .expect("RecycledConn does not have a connection in Option<>") | ||||
|     fn next(&mut self) -> Option<Self::Item> { | ||||
|         self.iterator.next() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Drop for RecycledConn { | ||||
| struct NonAliasingBox<T>(*mut T); | ||||
| impl<T> Drop for NonAliasingBox<T> { | ||||
|     fn drop(&mut self) { | ||||
|         if let Some(conn) = self.0.take() { | ||||
|             debug!("Recycled connection"); | ||||
|             if let Err(e) = self.1.send(conn) { | ||||
|                 warn!("Recycling a connection led to the following error: {:?}", e) | ||||
|             } | ||||
|         } | ||||
|         unsafe { Box::from_raw(self.0) }; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Pool { | ||||
|     fn new<P: AsRef<Path>>(path: P, num_readers: usize, total_cache_size_mb: f64) -> Result<Self> { | ||||
|         // calculates cache-size per permanent connection
 | ||||
|         // 1. convert MB to KiB
 | ||||
|         // 2. divide by permanent connections
 | ||||
|         // 3. round down to nearest integer
 | ||||
|         let cache_size: u32 = ((total_cache_size_mb * 1024.0) / (num_readers + 1) as f64) as u32; | ||||
| pub struct Engine { | ||||
|     writer: Mutex<Connection>, | ||||
| 
 | ||||
|         let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); | ||||
| 
 | ||||
|         let mut readers = Vec::new(); | ||||
| 
 | ||||
|         for _ in 0..num_readers { | ||||
|             readers.push(Mutex::new(Self::prepare_conn(&path, Some(cache_size))?)) | ||||
|     path: PathBuf, | ||||
|     cache_size_per_thread: u32, | ||||
| } | ||||
| 
 | ||||
|         Ok(Self { | ||||
|             writer, | ||||
|             readers, | ||||
|             spills: ConnectionRecycler::new(), | ||||
|             spill_tracker: Arc::new(()), | ||||
|             path: path.as_ref().to_path_buf(), | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     fn prepare_conn<P: AsRef<Path>>(path: P, cache_size: Option<u32>) -> Result<Connection> { | ||||
|         let conn = Connection::open(path)?; | ||||
| impl Engine { | ||||
|     fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> { | ||||
|         let conn = Connection::open(&path)?; | ||||
| 
 | ||||
|         conn.pragma_update(Some(Main), "page_size", &32768)?; | ||||
|         conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; | ||||
|         conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; | ||||
| 
 | ||||
|         if let Some(cache_kib) = cache_size { | ||||
|             conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_kib)))?; | ||||
|         } | ||||
|         conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; | ||||
|         conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; | ||||
| 
 | ||||
|         Ok(conn) | ||||
|     } | ||||
|  | @ -136,68 +64,52 @@ impl Pool { | |||
|         self.writer.lock() | ||||
|     } | ||||
| 
 | ||||
|     fn read_lock(&self) -> HoldingConn<'_> { | ||||
|         // First try to get a connection from the permanent pool
 | ||||
|         for r in &self.readers { | ||||
|             if let Some(reader) = r.try_lock() { | ||||
|                 return HoldingConn::FromGuard(reader); | ||||
|             } | ||||
|     fn read_lock(&self) -> &'static Connection { | ||||
|         READ_CONNECTION.with(|cell| { | ||||
|             let connection = &mut cell.borrow_mut(); | ||||
| 
 | ||||
|             if (*connection).is_none() { | ||||
|                 let c = Box::leak(Box::new( | ||||
|                     Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), | ||||
|                 )); | ||||
|                 **connection = Some(c); | ||||
|             } | ||||
| 
 | ||||
|         debug!("read_lock: All permanent readers locked, obtaining spillover reader..."); | ||||
| 
 | ||||
|         // We didn't get a connection from the permanent pool, so we'll dumpster-dive for recycled connections.
 | ||||
|         // Either we have a connection or we dont, if we don't, we make a new one.
 | ||||
|         let conn = match self.spills.try_take() { | ||||
|             Some(conn) => conn, | ||||
|             None => { | ||||
|                 debug!("read_lock: No recycled connections left, creating new one..."); | ||||
|                 Self::prepare_conn(&self.path, None).unwrap() | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         // Clone the spill Arc to mark how many spilled connections actually exist.
 | ||||
|         let spill_arc = Arc::clone(&self.spill_tracker); | ||||
| 
 | ||||
|         // Get a sense of how many connections exist now.
 | ||||
|         let now_count = Arc::strong_count(&spill_arc) - 1 /* because one is held by the pool */; | ||||
| 
 | ||||
|         // If the spillover readers are more than the number of total readers, there might be a problem.
 | ||||
|         if now_count > self.readers.len() { | ||||
|             warn!( | ||||
|                 "Database is under high load. Consider increasing sqlite_read_pool_size ({} spillover readers exist)", | ||||
|                 now_count | ||||
|             ); | ||||
|             connection.unwrap() | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|         // Return the recyclable connection.
 | ||||
|         HoldingConn::FromRecycled(self.spills.recycle(conn), spill_arc) | ||||
|     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { | ||||
|         self.write_lock() | ||||
|             .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub struct Engine { | ||||
|     pool: Pool, | ||||
|     iter_pool: Mutex<ThreadPool>, | ||||
| } | ||||
| 
 | ||||
| impl DatabaseEngine for Engine { | ||||
|     fn open(config: &Config) -> Result<Arc<Self>> { | ||||
|         let pool = Pool::new( | ||||
|             Path::new(&config.database_path).join("conduit.db"), | ||||
|             config.sqlite_read_pool_size, | ||||
|             config.db_cache_capacity_mb, | ||||
|         )?; | ||||
|         let path = Path::new(&config.database_path).join("conduit.db"); | ||||
| 
 | ||||
|         // calculates cache-size per permanent connection
 | ||||
|         // 1. convert MB to KiB
 | ||||
|         // 2. divide by permanent connections
 | ||||
|         // 3. round down to nearest integer
 | ||||
|         let cache_size_per_thread: u32 = | ||||
|             ((config.db_cache_capacity_mb * 1024.0) / (num_cpus::get().max(1) + 1) as f64) as u32; | ||||
| 
 | ||||
|         let writer = Mutex::new(Self::prepare_conn(&path, cache_size_per_thread)?); | ||||
| 
 | ||||
|         let arc = Arc::new(Engine { | ||||
|             pool, | ||||
|             iter_pool: Mutex::new(ThreadPool::new(10)), | ||||
|             writer, | ||||
|             path, | ||||
|             cache_size_per_thread, | ||||
|         }); | ||||
| 
 | ||||
|         Ok(arc) | ||||
|     } | ||||
| 
 | ||||
|     fn open_tree(self: &Arc<Self>, name: &str) -> Result<Arc<dyn Tree>> { | ||||
|         self.pool.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; | ||||
|         self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; | ||||
| 
 | ||||
|         Ok(Arc::new(SqliteTable { | ||||
|             engine: Arc::clone(self), | ||||
|  | @ -212,31 +124,6 @@ impl DatabaseEngine for Engine { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Engine { | ||||
|     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { | ||||
|         self.pool.write_lock().pragma_update(Some(Main), "wal_checkpoint", &"RESTART")?; | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     // Reaps (at most) (.len() * `fraction`) (rounded down, min 1) connections.
 | ||||
|     pub fn reap_spillover_by_fraction(&self, fraction: f64) { | ||||
|         let mut reaped = 0; | ||||
| 
 | ||||
|         let spill_amount = self.pool.spills.1.len() as f64; | ||||
|         let fraction = fraction.clamp(0.01, 1.0); | ||||
| 
 | ||||
|         let amount = (spill_amount * fraction).max(1.0) as u32; | ||||
| 
 | ||||
|         for _ in 0..amount { | ||||
|             if self.pool.spills.try_take().is_some() { | ||||
|                 reaped += 1; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         debug!("Reaped {} connections", reaped); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub struct SqliteTable { | ||||
|     engine: Arc<Engine>, | ||||
|     name: String, | ||||
|  | @ -258,7 +145,7 @@ impl SqliteTable { | |||
|     fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { | ||||
|         guard.execute( | ||||
|             format!( | ||||
|                 "INSERT INTO {} (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", | ||||
|                 "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", | ||||
|                 self.name | ||||
|             ) | ||||
|             .as_str(), | ||||
|  | @ -266,70 +153,17 @@ impl SqliteTable { | |||
|         )?; | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, sql, param))] | ||||
|     fn iter_from_thread( | ||||
|         &self, | ||||
|         sql: String, | ||||
|         param: Option<Vec<u8>>, | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + Send + Sync> { | ||||
|         let (s, r) = bounded::<TupleOfBytes>(5); | ||||
| 
 | ||||
|         let engine = Arc::clone(&self.engine); | ||||
| 
 | ||||
|         let lock = self.engine.iter_pool.lock(); | ||||
|         if lock.active_count() < lock.max_count() { | ||||
|             lock.execute(move || { | ||||
|                 if let Some(param) = param { | ||||
|                     iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); | ||||
|                 } else { | ||||
|                     iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); | ||||
|                 } | ||||
|             }); | ||||
|         } else { | ||||
|             std::thread::spawn(move || { | ||||
|                 if let Some(param) = param { | ||||
|                     iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); | ||||
|                 } else { | ||||
|                     iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         Box::new(r.into_iter()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn iter_from_thread_work<P>( | ||||
|     guard: &HoldingConn<'_>, | ||||
|     s: &ChannelSender<(Vec<u8>, Vec<u8>)>, | ||||
|     sql: &str, | ||||
|     params: P, | ||||
| ) where | ||||
|     P: Params, | ||||
| { | ||||
|     for bob in guard | ||||
|         .prepare(sql) | ||||
|         .unwrap() | ||||
|         .query_map(params, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|         .unwrap() | ||||
|         .map(|r| r.unwrap()) | ||||
|     { | ||||
|         if s.send(bob).is_err() { | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Tree for SqliteTable { | ||||
|     #[tracing::instrument(skip(self, key))] | ||||
|     fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { | ||||
|         self.get_with_guard(&self.engine.pool.read_lock(), key) | ||||
|         self.get_with_guard(&self.engine.read_lock(), key) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, key, value))] | ||||
|     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { | ||||
|         let guard = self.engine.pool.write_lock(); | ||||
|         let guard = self.engine.write_lock(); | ||||
| 
 | ||||
|         let start = Instant::now(); | ||||
| 
 | ||||
|  | @ -337,7 +171,7 @@ impl Tree for SqliteTable { | |||
| 
 | ||||
|         let elapsed = start.elapsed(); | ||||
|         if elapsed > MILLI { | ||||
|             debug!("insert:    took {:012?} : {}", elapsed, &self.name); | ||||
|             warn!("insert took {:?} : {}", elapsed, &self.name); | ||||
|         } | ||||
| 
 | ||||
|         drop(guard); | ||||
|  | @ -369,7 +203,7 @@ impl Tree for SqliteTable { | |||
| 
 | ||||
|     #[tracing::instrument(skip(self, key))] | ||||
|     fn remove(&self, key: &[u8]) -> Result<()> { | ||||
|         let guard = self.engine.pool.write_lock(); | ||||
|         let guard = self.engine.write_lock(); | ||||
| 
 | ||||
|         let start = Instant::now(); | ||||
| 
 | ||||
|  | @ -389,9 +223,28 @@ impl Tree for SqliteTable { | |||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + Send + 'a> { | ||||
|         let name = self.name.clone(); | ||||
|         self.iter_from_thread(format!("SELECT key, value FROM {}", name), None) | ||||
|     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         let guard = self.engine.read_lock(); | ||||
| 
 | ||||
|         let statement = Box::leak(Box::new( | ||||
|             guard | ||||
|                 .prepare(&format!("SELECT key, value FROM {}", &self.name)) | ||||
|                 .unwrap(), | ||||
|         )); | ||||
| 
 | ||||
|         let statement_ref = NonAliasingBox(statement); | ||||
| 
 | ||||
|         let iterator = Box::new( | ||||
|             statement | ||||
|                 .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|                 .unwrap() | ||||
|                 .map(|r| r.unwrap()), | ||||
|         ); | ||||
| 
 | ||||
|         Box::new(PreparedStatementIterator { | ||||
|             iterator, | ||||
|             statement_ref, | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, from, backwards))] | ||||
|  | @ -399,31 +252,61 @@ impl Tree for SqliteTable { | |||
|         &'a self, | ||||
|         from: &[u8], | ||||
|         backwards: bool, | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + Send + 'a> { | ||||
|         let name = self.name.clone(); | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         let guard = self.engine.read_lock(); | ||||
|         let from = from.to_vec(); // TODO change interface?
 | ||||
| 
 | ||||
|         if backwards { | ||||
|             self.iter_from_thread( | ||||
|                 format!( | ||||
|             let statement = Box::leak(Box::new( | ||||
|                 guard | ||||
|                     .prepare(&format!( | ||||
|                         "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", | ||||
|                     name | ||||
|                 ), | ||||
|                 Some(from), | ||||
|             ) | ||||
|                         &self.name | ||||
|                     )) | ||||
|                     .unwrap(), | ||||
|             )); | ||||
| 
 | ||||
|             let statement_ref = NonAliasingBox(statement); | ||||
| 
 | ||||
|             let iterator = Box::new( | ||||
|                 statement | ||||
|                     .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|                     .unwrap() | ||||
|                     .map(|r| r.unwrap()), | ||||
|             ); | ||||
|             Box::new(PreparedStatementIterator { | ||||
|                 iterator, | ||||
|                 statement_ref, | ||||
|             }) | ||||
|         } else { | ||||
|             self.iter_from_thread( | ||||
|                 format!( | ||||
|             let statement = Box::leak(Box::new( | ||||
|                 guard | ||||
|                     .prepare(&format!( | ||||
|                         "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", | ||||
|                     name | ||||
|                 ), | ||||
|                 Some(from), | ||||
|             ) | ||||
|                         &self.name | ||||
|                     )) | ||||
|                     .unwrap(), | ||||
|             )); | ||||
| 
 | ||||
|             let statement_ref = NonAliasingBox(statement); | ||||
| 
 | ||||
|             let iterator = Box::new( | ||||
|                 statement | ||||
|                     .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|                     .unwrap() | ||||
|                     .map(|r| r.unwrap()), | ||||
|             ); | ||||
| 
 | ||||
|             Box::new(PreparedStatementIterator { | ||||
|                 iterator, | ||||
|                 statement_ref, | ||||
|             }) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, key))] | ||||
|     fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { | ||||
|         let guard = self.engine.pool.write_lock(); | ||||
|         let guard = self.engine.write_lock(); | ||||
| 
 | ||||
|         let start = Instant::now(); | ||||
| 
 | ||||
|  | @ -445,10 +328,7 @@ impl Tree for SqliteTable { | |||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, prefix))] | ||||
|     fn scan_prefix<'a>( | ||||
|         &'a self, | ||||
|         prefix: Vec<u8>, | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + Send + 'a> { | ||||
|     fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         // let name = self.name.clone();
 | ||||
|         // self.iter_from_thread(
 | ||||
|         //     format!(
 | ||||
|  | @ -483,25 +363,9 @@ impl Tree for SqliteTable { | |||
|     fn clear(&self) -> Result<()> { | ||||
|         debug!("clear: running"); | ||||
|         self.engine | ||||
|             .pool | ||||
|             .write_lock() | ||||
|             .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; | ||||
|         debug!("clear: ran"); | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| // TODO
 | ||||
| // struct Pool<const NUM_READERS: usize> {
 | ||||
| //     writer: Mutex<Connection>,
 | ||||
| //     readers: [Mutex<Connection>; NUM_READERS],
 | ||||
| // }
 | ||||
| 
 | ||||
| // // then, to pick a reader:
 | ||||
| // for r in &pool.readers {
 | ||||
| //     if let Ok(reader) = r.try_lock() {
 | ||||
| //         // use reader
 | ||||
| //     }
 | ||||
| // }
 | ||||
| // // none unlocked, pick the next reader
 | ||||
| // pool.readers[pool.counter.fetch_add(1, Relaxed) % NUM_READERS].lock()
 | ||||
|  |  | |||
|  | @ -49,22 +49,23 @@ impl Appservice { | |||
|             ) | ||||
|     } | ||||
| 
 | ||||
|     pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + Send + '_> { | ||||
|     pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { | ||||
|         Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { | ||||
|             utils::string_from_bytes(&id) | ||||
|                 .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) | ||||
|         })) | ||||
|     } | ||||
| 
 | ||||
|     pub fn iter_all( | ||||
|         &self, | ||||
|     ) -> Result<impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ + Send> { | ||||
|         Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { | ||||
|     pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> { | ||||
|         self.iter_ids()? | ||||
|             .filter_map(|id| id.ok()) | ||||
|             .map(move |id| { | ||||
|                 Ok(( | ||||
|                     id.clone(), | ||||
|                     self.get_registration(&id)? | ||||
|                         .expect("iter_ids only returns appservices that exist"), | ||||
|                 )) | ||||
|         })) | ||||
|             }) | ||||
|             .collect() | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -15,7 +15,7 @@ use std::{ | |||
|     sync::{Arc, RwLock}, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex, Semaphore}; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | ||||
| use tracing::{error, info}; | ||||
| use trust_dns_resolver::TokioAsyncResolver; | ||||
| 
 | ||||
|  | @ -45,8 +45,8 @@ pub struct Globals { | |||
|     pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>, | ||||
|     pub servername_ratelimiter: Arc<RwLock<HashMap<Box<ServerName>, Arc<Semaphore>>>>, | ||||
|     pub sync_receivers: RwLock<HashMap<(UserId, Box<DeviceId>), SyncHandle>>, | ||||
|     pub roomid_mutex: RwLock<HashMap<RoomId, Arc<Mutex<()>>>>, | ||||
|     pub roomid_mutex_federation: RwLock<HashMap<RoomId, Arc<Mutex<()>>>>, // this lock will be held longer
 | ||||
|     pub roomid_mutex: RwLock<HashMap<RoomId, Arc<TokioMutex<()>>>>, | ||||
|     pub roomid_mutex_federation: RwLock<HashMap<RoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer
 | ||||
|     pub rotate: RotationHandler, | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -101,8 +101,8 @@ impl Media { | |||
|         prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 | ||||
|         prefix.push(0xff); | ||||
| 
 | ||||
|         let mut iter = self.mediaid_file.scan_prefix(prefix); | ||||
|         if let Some((key, _)) = iter.next() { | ||||
|         let first = self.mediaid_file.scan_prefix(prefix).next(); | ||||
|         if let Some((key, _)) = first { | ||||
|             let path = globals.get_media_file(&key); | ||||
|             let mut file = Vec::new(); | ||||
|             File::open(path).await?.read_to_end(&mut file).await?; | ||||
|  | @ -190,7 +190,9 @@ impl Media { | |||
|         original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 | ||||
|         original_prefix.push(0xff); | ||||
| 
 | ||||
|         if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { | ||||
|         let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next(); | ||||
|         let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next(); | ||||
|         if let Some((key, _)) = first_thumbnailprefix { | ||||
|             // Using saved thumbnail
 | ||||
|             let path = globals.get_media_file(&key); | ||||
|             let mut file = Vec::new(); | ||||
|  | @ -225,7 +227,7 @@ impl Media { | |||
|                 content_type, | ||||
|                 file: file.to_vec(), | ||||
|             })) | ||||
|         } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { | ||||
|         } else if let Some((key, _)) = first_originalprefix { | ||||
|             // Generate a thumbnail
 | ||||
|             let path = globals.get_media_file(&key); | ||||
|             let mut file = Vec::new(); | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ mod edus; | |||
| 
 | ||||
| pub use edus::RoomEdus; | ||||
| use member::MembershipState; | ||||
| use tokio::sync::MutexGuard; | ||||
| 
 | ||||
| use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; | ||||
| use lru_cache::LruCache; | ||||
|  | @ -28,6 +27,7 @@ use std::{ | |||
|     mem, | ||||
|     sync::{Arc, Mutex}, | ||||
| }; | ||||
| use tokio::sync::MutexGuard; | ||||
| use tracing::{debug, error, warn}; | ||||
| 
 | ||||
| use super::{abstraction::Tree, admin::AdminCommand, pusher}; | ||||
|  | @ -1496,7 +1496,7 @@ impl Rooms { | |||
|             db.sending.send_pdu(&server, &pdu_id)?; | ||||
|         } | ||||
| 
 | ||||
|         for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { | ||||
|         for appservice in db.appservice.all()? { | ||||
|             if let Some(namespaces) = appservice.1.get("namespaces") { | ||||
|                 let users = namespaces | ||||
|                     .get("users") | ||||
|  |  | |||
|  | @ -75,9 +75,9 @@ where | |||
|             registration, | ||||
|         )) = db | ||||
|             .appservice | ||||
|             .iter_all() | ||||
|             .all() | ||||
|             .unwrap() | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .iter() | ||||
|             .find(|(_id, registration)| { | ||||
|                 registration | ||||
|                     .get("as_token") | ||||
|  |  | |||
|  | @ -806,7 +806,7 @@ pub async fn send_transaction_message_route( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(send_transaction_message::v1::Response { pdus: resolved_map }.into()) | ||||
| } | ||||
|  | @ -1343,7 +1343,6 @@ pub fn handle_incoming_pdu<'a>( | |||
|                     &state_at_incoming_event, | ||||
|                     &mutex_lock, | ||||
|                 ) | ||||
|                 .await | ||||
|                 .map_err(|_| "Failed to add pdu to db.".to_owned())?, | ||||
|             ); | ||||
|             debug!("Appended incoming pdu."); | ||||
|  | @ -1643,7 +1642,7 @@ pub(crate) async fn fetch_signing_keys( | |||
| /// Append the incoming event setting the state snapshot to the state from the
 | ||||
| /// server that sent the event.
 | ||||
| #[tracing::instrument(skip(db, pdu, pdu_json, new_room_leaves, state, _mutex_lock))] | ||||
| async fn append_incoming_pdu( | ||||
| fn append_incoming_pdu( | ||||
|     db: &Database, | ||||
|     pdu: &PduEvent, | ||||
|     pdu_json: CanonicalJsonObject, | ||||
|  | @ -1663,7 +1662,7 @@ async fn append_incoming_pdu( | |||
|         &db, | ||||
|     )?; | ||||
| 
 | ||||
|     for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { | ||||
|     for appservice in db.appservice.all()? { | ||||
|         if let Some(namespaces) = appservice.1.get("namespaces") { | ||||
|             let users = namespaces | ||||
|                 .get("users") | ||||
|  | @ -2208,7 +2207,7 @@ pub async fn create_join_event_route( | |||
|         db.sending.send_pdu(&server, &pdu_id)?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_join_event::v2::Response { | ||||
|         room_state: RoomState { | ||||
|  | @ -2327,7 +2326,7 @@ pub async fn create_invite_route( | |||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(create_invite::v2::Response { | ||||
|         event: PduEvent::convert_to_outgoing_federation_event(signed_event), | ||||
|  | @ -2464,7 +2463,7 @@ pub async fn get_keys_route( | |||
|     ) | ||||
|     .await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(get_keys::v1::Response { | ||||
|         device_keys: result.device_keys, | ||||
|  | @ -2489,7 +2488,7 @@ pub async fn claim_keys_route( | |||
| 
 | ||||
|     let result = claim_keys_helper(&body.one_time_keys, &db).await?; | ||||
| 
 | ||||
|     db.flush().await?; | ||||
|     db.flush()?; | ||||
| 
 | ||||
|     Ok(claim_keys::v1::Response { | ||||
|         one_time_keys: result.one_time_keys, | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue