crypto: Make the session stores thread safe.
parent
50167e7988
commit
abe13d7a2d
|
@ -398,7 +398,7 @@ impl OlmMachine {
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
message: &OlmMessage,
|
message: &OlmMessage,
|
||||||
) -> Result<Option<String>> {
|
) -> Result<Option<String>> {
|
||||||
let mut s = self.store.sessions_mut(sender_key).await?;
|
let mut s = self.sessions.get_mut(sender_key).await;
|
||||||
|
|
||||||
let sessions = if let Some(s) = s {
|
let sessions = if let Some(s) = s {
|
||||||
s
|
s
|
||||||
|
@ -406,7 +406,7 @@ impl OlmMachine {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
for session in sessions.iter_mut() {
|
for session in sessions.lock().await.iter_mut() {
|
||||||
let mut matches = false;
|
let mut matches = false;
|
||||||
|
|
||||||
if let OlmMessage::PreKey(m) = &message {
|
if let OlmMessage::PreKey(m) = &message {
|
||||||
|
@ -448,8 +448,10 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
session.decrypt(message)?
|
let plaintext = session.decrypt(message)?;
|
||||||
|
self.sessions.add(session).await;
|
||||||
// TODO save the session
|
// TODO save the session
|
||||||
|
plaintext
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO convert the plaintext to a ruma event.
|
// TODO convert the plaintext to a ruma event.
|
||||||
|
@ -644,7 +646,7 @@ impl OlmMachine {
|
||||||
// TODO check if the olm session is wedged and re-request the key.
|
// TODO check if the olm session is wedged and re-request the key.
|
||||||
let session = session.ok_or(OlmError::MissingSession)?;
|
let session = session.ok_or(OlmError::MissingSession)?;
|
||||||
|
|
||||||
let (plaintext, _) = session.decrypt(content.ciphertext.clone())?;
|
let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?;
|
||||||
// TODO check the message index.
|
// TODO check the message index.
|
||||||
// TODO check if this is from a verified device.
|
// TODO check if this is from a verified device.
|
||||||
|
|
||||||
|
|
|
@ -12,12 +12,46 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use super::olm::InboundGroupSession;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use super::olm::{InboundGroupSession, Session};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SessionStore {
|
||||||
|
entries: Mutex<HashMap<String, Arc<Mutex<Vec<Session>>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
SessionStore {
|
||||||
|
entries: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add(&mut self, session: Session) {
|
||||||
|
let mut entries = self.entries.lock().await;
|
||||||
|
|
||||||
|
if !entries.contains_key(&session.sender_key) {
|
||||||
|
entries.insert(
|
||||||
|
session.sender_key.to_owned(),
|
||||||
|
Arc::new(Mutex::new(Vec::new())),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let mut sessions = entries.get_mut(&session.sender_key).unwrap();
|
||||||
|
sessions.lock().await.push(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_mut(&mut self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
|
||||||
|
self.entries.lock().await.get_mut(sender_key).cloned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GroupSessionStore {
|
pub struct GroupSessionStore {
|
||||||
entries: HashMap<String, HashMap<String, HashMap<String, InboundGroupSession>>>,
|
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GroupSessionStore {
|
impl GroupSessionStore {
|
||||||
|
@ -40,7 +74,7 @@ impl GroupSessionStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut sender_map = room_map.get_mut(&session.sender_key).unwrap();
|
let mut sender_map = room_map.get_mut(&session.sender_key).unwrap();
|
||||||
let ret = sender_map.insert(session.session_id(), session);
|
let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session)));
|
||||||
|
|
||||||
ret.is_some()
|
ret.is_some()
|
||||||
}
|
}
|
||||||
|
@ -50,9 +84,9 @@ impl GroupSessionStore {
|
||||||
room_id: &str,
|
room_id: &str,
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Option<&InboundGroupSession> {
|
) -> Option<Arc<Mutex<InboundGroupSession>>> {
|
||||||
self.entries
|
self.entries
|
||||||
.get(room_id)
|
.get(room_id)
|
||||||
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id)))
|
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,6 +116,7 @@ impl Account {
|
||||||
|
|
||||||
Ok(Session {
|
Ok(Session {
|
||||||
inner: session,
|
inner: session,
|
||||||
|
sender_key: their_identity_key.to_owned(),
|
||||||
creation_time: now.clone(),
|
creation_time: now.clone(),
|
||||||
last_use_time: now,
|
last_use_time: now,
|
||||||
})
|
})
|
||||||
|
@ -128,8 +129,10 @@ impl PartialEq for Account {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
inner: OlmSession,
|
inner: OlmSession,
|
||||||
|
pub(crate) sender_key: String,
|
||||||
creation_time: Instant,
|
creation_time: Instant,
|
||||||
last_use_time: Instant,
|
last_use_time: Instant,
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,24 +53,19 @@ pub enum CryptoStoreError {
|
||||||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait CryptoStore: Debug {
|
pub trait CryptoStore: Debug + Sync + Sync {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||||
async fn sessions_mut(&mut self, sender_key: &str) -> Result<Option<&mut Vec<Session>>>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
pub(crate) account_info: Option<(String, bool)>,
|
pub(crate) account_info: Option<(String, bool)>,
|
||||||
sessions: HashMap<String, Vec<Session>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MemoryStore {
|
impl MemoryStore {
|
||||||
/// Create a new empty memory store.
|
/// Create a new empty memory store.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
MemoryStore {
|
MemoryStore { account_info: None }
|
||||||
account_info: None,
|
|
||||||
sessions: HashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,13 +89,6 @@ impl CryptoStore for MemoryStore {
|
||||||
self.account_info = Some((pickle, acc.shared));
|
self.account_info = Some((pickle, acc.shared));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn sessions_mut<'a>(
|
|
||||||
&'a mut self,
|
|
||||||
sender_key: &str,
|
|
||||||
) -> Result<Option<&'a mut Vec<Session>>> {
|
|
||||||
Ok(self.sessions.get_mut(sender_key))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for MemoryStore {
|
impl std::fmt::Debug for MemoryStore {
|
||||||
|
|
|
@ -168,13 +168,6 @@ impl CryptoStore for SqliteStore {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn sessions_mut<'a>(
|
|
||||||
&'a mut self,
|
|
||||||
sender_key: &str,
|
|
||||||
) -> Result<Option<&'a mut Vec<Session>>> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for SqliteStore {
|
impl std::fmt::Debug for SqliteStore {
|
||||||
|
|
Loading…
Reference in New Issue