crypto: Don't cache inbound group sessions in the sqlite store.

master
Damir Jelić 2020-10-16 15:54:50 +02:00
parent b5560d3cb6
commit 4262f1d3b0
1 changed files with 124 additions and 56 deletions

View File

@ -34,7 +34,7 @@ use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor,
use zeroize::Zeroizing;
use super::{
caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore},
caches::{DeviceStore, ReadOnlyUserDevices, SessionStore},
CryptoStore, CryptoStoreError, Result,
};
use crate::{
@ -56,7 +56,6 @@ pub struct SqliteStore {
path: Arc<PathBuf>,
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
devices: DeviceStore,
tracked_users: Arc<DashSet<UserId>>,
users_for_key_query: Arc<DashSet<UserId>>,
@ -150,7 +149,6 @@ impl SqliteStore {
device_id: Arc::new(device_id.into()),
account_info: Arc::new(SyncMutex::new(None)),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
devices: DeviceStore::new(),
path: Arc::new(path),
connection: Arc::new(Mutex::new(connection)),
@ -525,7 +523,113 @@ impl SqliteStore {
.collect::<Result<Vec<Session>>>()?)
}
async fn load_inbound_group_sessions(&self) -> Result<()> {
async fn load_inbound_session_data(
&self,
connection: &mut SqliteConnection,
session_row_id: i64,
pickle: String,
sender_key: String,
room_id: RoomId,
imported: bool,
) -> Result<InboundGroupSession> {
let key_rows: Vec<(String, String)> =
query_as("SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?")
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = key_rows
.into_iter()
.filter_map(|row| {
let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?;
let key = row.1;
Some((algorithm, key))
})
.collect();
let mut chain_rows: Vec<(String,)> =
query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?")
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let chains: Vec<String> = chain_rows.drain(..).map(|r| r.0).collect();
let chains = if chains.is_empty() {
None
} else {
Some(chains)
};
let pickle = PickledInboundGroupSession {
pickle: InboundGroupSessionPickle::from(pickle),
sender_key,
signing_key: claimed_keys,
room_id,
forwarding_chains: chains,
imported,
};
Ok(InboundGroupSession::from_pickle(
pickle,
self.get_pickle_mode(),
)?)
}
async fn load_inbound_group_session_helper(
&self,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let row: Option<(i64, String, bool)> = query_as(
"SELECT id, pickle, imported
FROM inbound_group_sessions
WHERE (
account_id = ? and
room_id = ? and
sender_key = ? and
session_id = ?
)",
)
.bind(account_id)
.bind(room_id.as_str())
.bind(sender_key)
.bind(session_id)
.fetch_optional(&mut *connection)
.await?;
let row = if let Some(r) = row {
r
} else {
return Ok(None);
};
let session_row_id = row.0;
let pickle = row.1;
let imported = row.2;
let session = self
.load_inbound_session_data(
&mut connection,
session_row_id,
pickle,
sender_key.to_owned(),
room_id.to_owned(),
imported,
)
.await?;
Ok(Some(session))
}
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
let mut sessions = Vec::new();
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
@ -541,57 +645,24 @@ impl SqliteStore {
let session_row_id = row.0;
let pickle = row.1;
let sender_key = row.2;
let room_id = row.3;
let room_id = RoomId::try_from(row.3)?;
let imported = row.4;
let key_rows: Vec<(String, String)> = query_as(
"SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?",
)
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = key_rows
.into_iter()
.filter_map(|row| {
let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?;
let key = row.1;
Some((algorithm, key))
})
.collect();
let mut chain_rows: Vec<(String,)> =
query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?")
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let chains: Vec<String> = chain_rows.drain(..).map(|r| r.0).collect();
let chains = if chains.is_empty() {
None
} else {
Some(chains)
};
let pickle = PickledInboundGroupSession {
pickle: InboundGroupSessionPickle::from(pickle),
sender_key,
signing_key: claimed_keys,
room_id: RoomId::try_from(room_id)?,
forwarding_chains: chains,
imported,
};
self.inbound_group_sessions
.add(InboundGroupSession::from_pickle(
let session = self
.load_inbound_session_data(
&mut connection,
session_row_id,
pickle,
self.get_pickle_mode(),
)?);
sender_key,
room_id.to_owned(),
imported,
)
.await?;
sessions.push(session);
}
Ok(())
Ok(sessions)
}
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
@ -1205,7 +1276,6 @@ impl CryptoStore for SqliteStore {
drop(connection);
self.load_inbound_group_sessions().await?;
self.load_devices().await?;
self.load_tracked_users().await?;
@ -1300,7 +1370,6 @@ impl CryptoStore for SqliteStore {
for session in sessions {
self.save_inbound_group_session_helper(account_id, &mut transaction, session)
.await?;
self.inbound_group_sessions.add(session.clone());
}
transaction.commit().await?;
@ -1315,12 +1384,12 @@ impl CryptoStore for SqliteStore {
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
.load_inbound_group_session_helper(room_id, sender_key, session_id)
.await?)
}
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
Ok(self.inbound_group_sessions.get_all())
Ok(self.load_inbound_group_sessions().await?)
}
fn is_user_tracked(&self, user_id: &UserId) -> bool {
@ -1768,7 +1837,6 @@ mod test {
.expect("Can't create store");
store.load_account().await.unwrap();
store.load_inbound_group_sessions().await.unwrap();
let loaded_session = store
.get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id())