crypto: Make the session stores thread safe.
This commit is contained in:
parent
50167e7988
commit
abe13d7a2d
5 changed files with 50 additions and 30 deletions
|
@ -398,7 +398,7 @@ impl OlmMachine {
|
|||
sender_key: &str,
|
||||
message: &OlmMessage,
|
||||
) -> Result<Option<String>> {
|
||||
let mut s = self.store.sessions_mut(sender_key).await?;
|
||||
let mut s = self.sessions.get_mut(sender_key).await;
|
||||
|
||||
let sessions = if let Some(s) = s {
|
||||
s
|
||||
|
@ -406,7 +406,7 @@ impl OlmMachine {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
for session in sessions.iter_mut() {
|
||||
for session in sessions.lock().await.iter_mut() {
|
||||
let mut matches = false;
|
||||
|
||||
if let OlmMessage::PreKey(m) = &message {
|
||||
|
@ -448,8 +448,10 @@ impl OlmMachine {
|
|||
}
|
||||
};
|
||||
|
||||
session.decrypt(message)?
|
||||
let plaintext = session.decrypt(message)?;
|
||||
self.sessions.add(session).await;
|
||||
// TODO save the session
|
||||
plaintext
|
||||
};
|
||||
|
||||
// TODO convert the plaintext to a ruma event.
|
||||
|
@ -644,7 +646,7 @@ impl OlmMachine {
|
|||
// TODO check if the olm session is wedged and re-request the key.
|
||||
let session = session.ok_or(OlmError::MissingSession)?;
|
||||
|
||||
let (plaintext, _) = session.decrypt(content.ciphertext.clone())?;
|
||||
let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?;
|
||||
// TODO check the message index.
|
||||
// TODO check if this is from a verified device.
|
||||
|
||||
|
|
|
@ -12,12 +12,46 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use super::olm::InboundGroupSession;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::olm::{InboundGroupSession, Session};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SessionStore {
|
||||
entries: Mutex<HashMap<String, Arc<Mutex<Vec<Session>>>>>,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
pub fn new() -> Self {
|
||||
SessionStore {
|
||||
entries: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add(&mut self, session: Session) {
|
||||
let mut entries = self.entries.lock().await;
|
||||
|
||||
if !entries.contains_key(&session.sender_key) {
|
||||
entries.insert(
|
||||
session.sender_key.to_owned(),
|
||||
Arc::new(Mutex::new(Vec::new())),
|
||||
);
|
||||
}
|
||||
let mut sessions = entries.get_mut(&session.sender_key).unwrap();
|
||||
sessions.lock().await.push(session);
|
||||
}
|
||||
|
||||
pub async fn get_mut(&mut self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
|
||||
self.entries.lock().await.get_mut(sender_key).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GroupSessionStore {
|
||||
entries: HashMap<String, HashMap<String, HashMap<String, InboundGroupSession>>>,
|
||||
entries: HashMap<String, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>,
|
||||
}
|
||||
|
||||
impl GroupSessionStore {
|
||||
|
@ -40,7 +74,7 @@ impl GroupSessionStore {
|
|||
}
|
||||
|
||||
let mut sender_map = room_map.get_mut(&session.sender_key).unwrap();
|
||||
let ret = sender_map.insert(session.session_id(), session);
|
||||
let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session)));
|
||||
|
||||
ret.is_some()
|
||||
}
|
||||
|
@ -50,9 +84,9 @@ impl GroupSessionStore {
|
|||
room_id: &str,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Option<&InboundGroupSession> {
|
||||
) -> Option<Arc<Mutex<InboundGroupSession>>> {
|
||||
self.entries
|
||||
.get(room_id)
|
||||
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id)))
|
||||
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,6 +116,7 @@ impl Account {
|
|||
|
||||
Ok(Session {
|
||||
inner: session,
|
||||
sender_key: their_identity_key.to_owned(),
|
||||
creation_time: now.clone(),
|
||||
last_use_time: now,
|
||||
})
|
||||
|
@ -128,8 +129,10 @@ impl PartialEq for Account {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Session {
|
||||
inner: OlmSession,
|
||||
pub(crate) sender_key: String,
|
||||
creation_time: Instant,
|
||||
last_use_time: Instant,
|
||||
}
|
||||
|
|
|
@ -53,24 +53,19 @@ pub enum CryptoStoreError {
|
|||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CryptoStore: Debug {
|
||||
pub trait CryptoStore: Debug + Sync + Sync {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||
async fn sessions_mut(&mut self, sender_key: &str) -> Result<Option<&mut Vec<Session>>>;
|
||||
}
|
||||
|
||||
pub struct MemoryStore {
|
||||
pub(crate) account_info: Option<(String, bool)>,
|
||||
sessions: HashMap<String, Vec<Session>>,
|
||||
}
|
||||
|
||||
impl MemoryStore {
|
||||
/// Create a new empty memory store.
|
||||
pub fn new() -> Self {
|
||||
MemoryStore {
|
||||
account_info: None,
|
||||
sessions: HashMap::new(),
|
||||
}
|
||||
MemoryStore { account_info: None }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,13 +89,6 @@ impl CryptoStore for MemoryStore {
|
|||
self.account_info = Some((pickle, acc.shared));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sessions_mut<'a>(
|
||||
&'a mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<&'a mut Vec<Session>>> {
|
||||
Ok(self.sessions.get_mut(sender_key))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MemoryStore {
|
||||
|
|
|
@ -168,13 +168,6 @@ impl CryptoStore for SqliteStore {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sessions_mut<'a>(
|
||||
&'a mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<&'a mut Vec<Session>>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SqliteStore {
|
||||
|
|
Loading…
Reference in a new issue