crypto: Change the crypto store so we can save multiple group sessions at once.
parent
9617d9aac9
commit
127d4c225b
|
@ -984,7 +984,7 @@ impl OlmMachine {
|
|||
&event.content.room_id,
|
||||
session_key,
|
||||
)?;
|
||||
let _ = self.store.save_inbound_group_session(session).await?;
|
||||
let _ = self.store.save_inbound_group_sessions(&[session]).await?;
|
||||
|
||||
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
||||
Ok(Some(event))
|
||||
|
@ -1014,7 +1014,7 @@ impl OlmMachine {
|
|||
.await
|
||||
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
||||
|
||||
let _ = self.store.save_inbound_group_session(inbound).await?;
|
||||
let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
|
||||
|
||||
let _ = self
|
||||
.outbound_group_sessions
|
||||
|
|
|
@ -80,8 +80,12 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.sessions.get(sender_key))
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||
Ok(self.inbound_group_sessions.add(session))
|
||||
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
|
||||
for session in sessions {
|
||||
self.inbound_group_sessions.add(session.clone());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
|
@ -208,7 +212,7 @@ mod test {
|
|||
|
||||
let store = MemoryStore::new();
|
||||
let _ = store
|
||||
.save_inbound_group_session(inbound.clone())
|
||||
.save_inbound_group_sessions(&[inbound.clone()])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -157,15 +157,12 @@ pub trait CryptoStore: Debug {
|
|||
/// * `sender_key` - The sender key that was used to establish the sessions.
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
|
||||
|
||||
/// Save the given inbound group session in the store.
|
||||
///
|
||||
/// If the session wasn't already in the store true is returned, false
|
||||
/// otherwise.
|
||||
/// Save the given inbound group sessions in the store.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session` - The session that should be stored.
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>;
|
||||
/// * `sessions` - The sessions that should be stored.
|
||||
async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
|
||||
|
||||
/// Get the inbound group session from our store.
|
||||
///
|
||||
|
|
|
@ -796,6 +796,64 @@ impl SqliteStore {
|
|||
}
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session_helper(
|
||||
&self,
|
||||
account_id: i64,
|
||||
connection: &mut SqliteConnection,
|
||||
session: &InboundGroupSession,
|
||||
) -> Result<()> {
|
||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||
let session_id = session.session_id();
|
||||
|
||||
// FIXME we need to store/restore the forwarding chains.
|
||||
// FIXME this should be converted so it accepts an array of sessions for
|
||||
// the key import feature.
|
||||
|
||||
query(
|
||||
"REPLACE INTO inbound_group_sessions (
|
||||
session_id, account_id, sender_key,
|
||||
room_id, pickle, imported
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
",
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(account_id)
|
||||
.bind(&pickle.sender_key)
|
||||
.bind(pickle.room_id.as_str())
|
||||
.bind(pickle.pickle.as_str())
|
||||
.bind(pickle.imported)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let row: (i64,) = query_as(
|
||||
"SELECT id FROM inbound_group_sessions
|
||||
WHERE account_id = ? and session_id = ? and sender_key = ?",
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(session_id)
|
||||
.bind(pickle.sender_key)
|
||||
.fetch_one(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let session_row_id = row.0;
|
||||
|
||||
for (key_id, key) in pickle.signing_key {
|
||||
query(
|
||||
"INSERT OR IGNORE INTO group_session_claimed_keys (
|
||||
session_id, algorithm, key
|
||||
) VALUES (?1, ?2, ?3)
|
||||
",
|
||||
)
|
||||
.bind(session_row_id)
|
||||
.bind(serde_json::to_string(&key_id)?)
|
||||
.bind(key)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_cross_signing_key(
|
||||
connection: &mut SqliteConnection,
|
||||
user_id: &UserId,
|
||||
|
@ -1165,59 +1223,19 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(self.get_sessions_for(sender_key).await?)
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
|
||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||
let mut connection = self.connection.lock().await;
|
||||
let session_id = session.session_id();
|
||||
|
||||
// FIXME we need to store/restore the forwarding chains.
|
||||
// FIXME this should be converted so it accepts an array of sessions for
|
||||
// the key import feature.
|
||||
// FIXME use a transaction here once sqlx gets better support for them.
|
||||
|
||||
query(
|
||||
"REPLACE INTO inbound_group_sessions (
|
||||
session_id, account_id, sender_key,
|
||||
room_id, pickle, imported
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
",
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(account_id)
|
||||
.bind(&pickle.sender_key)
|
||||
.bind(pickle.room_id.as_str())
|
||||
.bind(pickle.pickle.as_str())
|
||||
.bind(pickle.imported)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let row: (i64,) = query_as(
|
||||
"SELECT id FROM inbound_group_sessions
|
||||
WHERE account_id = ? and session_id = ? and sender_key = ?",
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(session_id)
|
||||
.bind(pickle.sender_key)
|
||||
.fetch_one(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let session_row_id = row.0;
|
||||
|
||||
for (key_id, key) in pickle.signing_key {
|
||||
query(
|
||||
"INSERT OR IGNORE INTO group_session_claimed_keys (
|
||||
session_id, algorithm, key
|
||||
) VALUES (?1, ?2, ?3)
|
||||
",
|
||||
)
|
||||
.bind(session_row_id)
|
||||
.bind(serde_json::to_string(&key_id)?)
|
||||
.bind(key)
|
||||
.execute(&mut *connection)
|
||||
for session in sessions {
|
||||
self.save_inbound_group_session_helper(account_id, &mut connection, session)
|
||||
.await?;
|
||||
self.inbound_group_sessions.add(session.clone());
|
||||
}
|
||||
|
||||
Ok(self.inbound_group_sessions.add(session))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
|
@ -1581,7 +1599,7 @@ mod test {
|
|||
.expect("Can't create session");
|
||||
|
||||
store
|
||||
.save_inbound_group_session(session)
|
||||
.save_inbound_group_sessions(&[session])
|
||||
.await
|
||||
.expect("Can't save group session");
|
||||
}
|
||||
|
@ -1601,7 +1619,7 @@ mod test {
|
|||
.expect("Can't create session");
|
||||
|
||||
store
|
||||
.save_inbound_group_session(session.clone())
|
||||
.save_inbound_group_sessions(&[session.clone()])
|
||||
.await
|
||||
.expect("Can't save group session");
|
||||
|
||||
|
|
Loading…
Reference in New Issue