crypto: Change the crypto store so we can save multiple group sessions at once.

master
Damir Jelić 2020-09-09 16:34:18 +02:00
parent 9617d9aac9
commit 127d4c225b
4 changed files with 79 additions and 60 deletions

View File

@ -984,7 +984,7 @@ impl OlmMachine {
&event.content.room_id, &event.content.room_id,
session_key, session_key,
)?; )?;
let _ = self.store.save_inbound_group_session(session).await?; let _ = self.store.save_inbound_group_sessions(&[session]).await?;
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
Ok(Some(event)) Ok(Some(event))
@ -1014,7 +1014,7 @@ impl OlmMachine {
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self.store.save_inbound_group_session(inbound).await?; let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
let _ = self let _ = self
.outbound_group_sessions .outbound_group_sessions

View File

@ -80,8 +80,12 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
Ok(self.inbound_group_sessions.add(session)) for session in sessions {
self.inbound_group_sessions.add(session.clone());
}
Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -208,7 +212,7 @@ mod test {
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store
.save_inbound_group_session(inbound.clone()) .save_inbound_group_sessions(&[inbound.clone()])
.await .await
.unwrap(); .unwrap();

View File

@ -157,15 +157,12 @@ pub trait CryptoStore: Debug {
/// * `sender_key` - The sender key that was used to establish the sessions. /// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>; async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group session in the store. /// Save the given inbound group sessions in the store.
///
/// If the session wasn't already in the store true is returned, false
/// otherwise.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `session` - The session that should be stored. /// * `sessions` - The sessions that should be stored.
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
/// Get the inbound group session from our store. /// Get the inbound group session from our store.
/// ///

View File

@ -796,6 +796,64 @@ impl SqliteStore {
} }
} }
async fn save_inbound_group_session_helper(
&self,
account_id: i64,
connection: &mut SqliteConnection,
session: &InboundGroupSession,
) -> Result<()> {
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
// FIXME we need to store/restore the forwarding chains.
// FIXME this should be converted so it accepts an array of sessions for
// the key import feature.
query(
"REPLACE INTO inbound_group_sessions (
session_id, account_id, sender_key,
room_id, pickle, imported
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)
.bind(session_id)
.bind(account_id)
.bind(&pickle.sender_key)
.bind(pickle.room_id.as_str())
.bind(pickle.pickle.as_str())
.bind(pickle.imported)
.execute(&mut *connection)
.await?;
let row: (i64,) = query_as(
"SELECT id FROM inbound_group_sessions
WHERE account_id = ? and session_id = ? and sender_key = ?",
)
.bind(account_id)
.bind(session_id)
.bind(pickle.sender_key)
.fetch_one(&mut *connection)
.await?;
let session_row_id = row.0;
for (key_id, key) in pickle.signing_key {
query(
"INSERT OR IGNORE INTO group_session_claimed_keys (
session_id, algorithm, key
) VALUES (?1, ?2, ?3)
",
)
.bind(session_row_id)
.bind(serde_json::to_string(&key_id)?)
.bind(key)
.execute(&mut *connection)
.await?;
}
Ok(())
}
async fn load_cross_signing_key( async fn load_cross_signing_key(
connection: &mut SqliteConnection, connection: &mut SqliteConnection,
user_id: &UserId, user_id: &UserId,
@ -1165,59 +1223,19 @@ impl CryptoStore for SqliteStore {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let session_id = session.session_id();
// FIXME we need to store/restore the forwarding chains. // FIXME use a transaction here once sqlx gets better support for them.
// FIXME this should be converted so it accepts an array of sessions for
// the key import feature.
query( for session in sessions {
"REPLACE INTO inbound_group_sessions ( self.save_inbound_group_session_helper(account_id, &mut connection, session)
session_id, account_id, sender_key,
room_id, pickle, imported
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)
.bind(session_id)
.bind(account_id)
.bind(&pickle.sender_key)
.bind(pickle.room_id.as_str())
.bind(pickle.pickle.as_str())
.bind(pickle.imported)
.execute(&mut *connection)
.await?;
let row: (i64,) = query_as(
"SELECT id FROM inbound_group_sessions
WHERE account_id = ? and session_id = ? and sender_key = ?",
)
.bind(account_id)
.bind(session_id)
.bind(pickle.sender_key)
.fetch_one(&mut *connection)
.await?;
let session_row_id = row.0;
for (key_id, key) in pickle.signing_key {
query(
"INSERT OR IGNORE INTO group_session_claimed_keys (
session_id, algorithm, key
) VALUES (?1, ?2, ?3)
",
)
.bind(session_row_id)
.bind(serde_json::to_string(&key_id)?)
.bind(key)
.execute(&mut *connection)
.await?; .await?;
self.inbound_group_sessions.add(session.clone());
} }
Ok(self.inbound_group_sessions.add(session)) Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -1581,7 +1599,7 @@ mod test {
.expect("Can't create session"); .expect("Can't create session");
store store
.save_inbound_group_session(session) .save_inbound_group_sessions(&[session])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
} }
@ -1601,7 +1619,7 @@ mod test {
.expect("Can't create session"); .expect("Can't create session");
store store
.save_inbound_group_session(session.clone()) .save_inbound_group_sessions(&[session.clone()])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");