diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index a0935c27..55ee480e 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -984,7 +984,7 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self.store.save_inbound_group_session(session).await?; + let _ = self.store.save_inbound_group_sessions(&[session]).await?; let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); Ok(Some(event)) @@ -1014,7 +1014,7 @@ impl OlmMachine { .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self.store.save_inbound_group_session(inbound).await?; + let _ = self.store.save_inbound_group_sessions(&[inbound]).await?; let _ = self .outbound_group_sessions diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index f7d3d753..ec4a5246 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -80,8 +80,12 @@ impl CryptoStore for MemoryStore { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { - Ok(self.inbound_group_sessions.add(session)) + async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { + for session in sessions { + self.inbound_group_sessions.add(session.clone()); + } + + Ok(()) } async fn get_inbound_group_session( @@ -208,7 +212,7 @@ mod test { let store = MemoryStore::new(); let _ = store - .save_inbound_group_session(inbound.clone()) + .save_inbound_group_sessions(&[inbound.clone()]) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index a7cce545..2efd7d29 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -157,15 +157,12 @@ pub trait CryptoStore: Debug { /// * `sender_key` - The sender key that was used to establish the sessions. async fn get_sessions(&self, sender_key: &str) -> Result>>>>; - /// Save the given inbound group session in the store. - /// - /// If the session wasn't already in the store true is returned, false - /// otherwise. + /// Save the given inbound group sessions in the store. /// /// # Arguments /// - /// * `session` - The session that should be stored. - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result; + /// * `sessions` - The sessions that should be stored. + async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>; /// Get the inbound group session from our store. /// diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 219ffb06..00dfccd5 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -796,6 +796,64 @@ impl SqliteStore { } } + async fn save_inbound_group_session_helper( + &self, + account_id: i64, + connection: &mut SqliteConnection, + session: &InboundGroupSession, + ) -> Result<()> { + let pickle = session.pickle(self.get_pickle_mode()).await; + let session_id = session.session_id(); + + // FIXME we need to store/restore the forwarding chains. + // FIXME this should be converted so it accepts an array of sessions for + // the key import feature. + + query( + "REPLACE INTO inbound_group_sessions ( + session_id, account_id, sender_key, + room_id, pickle, imported + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + ) + .bind(session_id) + .bind(account_id) + .bind(&pickle.sender_key) + .bind(pickle.room_id.as_str()) + .bind(pickle.pickle.as_str()) + .bind(pickle.imported) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM inbound_group_sessions + WHERE account_id = ? and session_id = ? and sender_key = ?", + ) + .bind(account_id) + .bind(session_id) + .bind(pickle.sender_key) + .fetch_one(&mut *connection) + .await?; + + let session_row_id = row.0; + + for (key_id, key) in pickle.signing_key { + query( + "INSERT OR IGNORE INTO group_session_claimed_keys ( + session_id, algorithm, key + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(session_row_id) + .bind(serde_json::to_string(&key_id)?) + .bind(key) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + async fn load_cross_signing_key( connection: &mut SqliteConnection, user_id: &UserId, @@ -1165,59 +1223,19 @@ impl CryptoStore for SqliteStore { Ok(self.get_sessions_for(sender_key).await?) } - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { + async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let pickle = session.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; - let session_id = session.session_id(); - // FIXME we need to store/restore the forwarding chains. - // FIXME this should be converted so it accepts an array of sessions for - // the key import feature. + // FIXME use a transaction here once sqlx gets better support for them. - query( - "REPLACE INTO inbound_group_sessions ( - session_id, account_id, sender_key, - room_id, pickle, imported - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) - ", - ) - .bind(session_id) - .bind(account_id) - .bind(&pickle.sender_key) - .bind(pickle.room_id.as_str()) - .bind(pickle.pickle.as_str()) - .bind(pickle.imported) - .execute(&mut *connection) - .await?; - - let row: (i64,) = query_as( - "SELECT id FROM inbound_group_sessions - WHERE account_id = ? and session_id = ? and sender_key = ?", - ) - .bind(account_id) - .bind(session_id) - .bind(pickle.sender_key) - .fetch_one(&mut *connection) - .await?; - - let session_row_id = row.0; - - for (key_id, key) in pickle.signing_key { - query( - "INSERT OR IGNORE INTO group_session_claimed_keys ( - session_id, algorithm, key - ) VALUES (?1, ?2, ?3) - ", - ) - .bind(session_row_id) - .bind(serde_json::to_string(&key_id)?) - .bind(key) - .execute(&mut *connection) - .await?; + for session in sessions { + self.save_inbound_group_session_helper(account_id, &mut connection, session) + .await?; + self.inbound_group_sessions.add(session.clone()); } - Ok(self.inbound_group_sessions.add(session)) + Ok(()) } async fn get_inbound_group_session( @@ -1581,7 +1599,7 @@ mod test { .expect("Can't create session"); store - .save_inbound_group_session(session) + .save_inbound_group_sessions(&[session]) .await .expect("Can't save group session"); } @@ -1601,7 +1619,7 @@ mod test { .expect("Can't create session"); store - .save_inbound_group_session(session.clone()) + .save_inbound_group_sessions(&[session.clone()]) .await .expect("Can't save group session");