crypto: Move the account mutex into the account struct.
parent
7c20c79f32
commit
b8d6a4c49a
|
@ -68,7 +68,7 @@ pub struct OlmMachine {
|
||||||
/// The unique device id of the device that holds this account.
|
/// The unique device id of the device that holds this account.
|
||||||
device_id: DeviceId,
|
device_id: DeviceId,
|
||||||
/// Our underlying Olm Account holding our identity keys.
|
/// Our underlying Olm Account holding our identity keys.
|
||||||
account: Arc<Mutex<Account>>,
|
account: Account,
|
||||||
/// The number of signed one-time keys we have uploaded to the server. If
|
/// The number of signed one-time keys we have uploaded to the server. If
|
||||||
/// this is None, no action will be taken. After a sync request the client
|
/// this is None, no action will be taken. After a sync request the client
|
||||||
/// needs to set this for us, depending on the count we will suggest the
|
/// needs to set this for us, depending on the count we will suggest the
|
||||||
|
@ -98,7 +98,7 @@ impl OlmMachine {
|
||||||
Ok(OlmMachine {
|
Ok(OlmMachine {
|
||||||
user_id: user_id.clone(),
|
user_id: user_id.clone(),
|
||||||
device_id: device_id.to_owned(),
|
device_id: device_id.to_owned(),
|
||||||
account: Arc::new(Mutex::new(Account::new())),
|
account: Account::new(),
|
||||||
uploaded_signed_key_count: None,
|
uploaded_signed_key_count: None,
|
||||||
store: Box::new(MemoryStore::new()),
|
store: Box::new(MemoryStore::new()),
|
||||||
users_for_key_query: HashSet::new(),
|
users_for_key_query: HashSet::new(),
|
||||||
|
@ -132,7 +132,7 @@ impl OlmMachine {
|
||||||
Ok(OlmMachine {
|
Ok(OlmMachine {
|
||||||
user_id: user_id.clone(),
|
user_id: user_id.clone(),
|
||||||
device_id: device_id.to_owned(),
|
device_id: device_id.to_owned(),
|
||||||
account: Arc::new(Mutex::new(account)),
|
account,
|
||||||
uploaded_signed_key_count: None,
|
uploaded_signed_key_count: None,
|
||||||
store: Box::new(store),
|
store: Box::new(store),
|
||||||
users_for_key_query: HashSet::new(),
|
users_for_key_query: HashSet::new(),
|
||||||
|
@ -142,7 +142,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
/// Should account or one-time keys be uploaded to the server.
|
/// Should account or one-time keys be uploaded to the server.
|
||||||
pub async fn should_upload_keys(&self) -> bool {
|
pub async fn should_upload_keys(&self) -> bool {
|
||||||
if !self.account.lock().await.shared() {
|
if !self.account.shared() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ impl OlmMachine {
|
||||||
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
|
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
|
||||||
match self.uploaded_signed_key_count {
|
match self.uploaded_signed_key_count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
let max_keys = self.account.lock().await.max_one_time_keys() as u64;
|
let max_keys = self.account.max_one_time_keys().await as u64;
|
||||||
let key_count = (max_keys / 2) - count;
|
let key_count = (max_keys / 2) - count;
|
||||||
key_count > 0
|
key_count > 0
|
||||||
}
|
}
|
||||||
|
@ -169,11 +169,10 @@ impl OlmMachine {
|
||||||
&mut self,
|
&mut self,
|
||||||
response: &keys::upload_keys::Response,
|
response: &keys::upload_keys::Response,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut account = self.account.lock().await;
|
if !self.account.shared() {
|
||||||
if !account.shared {
|
|
||||||
debug!("Marking account as shared");
|
debug!("Marking account as shared");
|
||||||
}
|
}
|
||||||
account.shared = true;
|
self.account.mark_as_shared();
|
||||||
|
|
||||||
let one_time_key_count = response
|
let one_time_key_count = response
|
||||||
.one_time_key_counts
|
.one_time_key_counts
|
||||||
|
@ -187,10 +186,9 @@ impl OlmMachine {
|
||||||
);
|
);
|
||||||
self.uploaded_signed_key_count = Some(count);
|
self.uploaded_signed_key_count = Some(count);
|
||||||
|
|
||||||
account.mark_keys_as_published();
|
self.account.mark_keys_as_published().await;
|
||||||
drop(account);
|
|
||||||
|
|
||||||
self.store.save_account(self.account.clone()).await?;
|
// self.store.save_account(self.account.clone()).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -317,9 +315,8 @@ impl OlmMachine {
|
||||||
|
|
||||||
let session = match self
|
let session = match self
|
||||||
.account
|
.account
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.create_outbound_session(curve_key, &one_time_key)
|
.create_outbound_session(curve_key, &one_time_key)
|
||||||
|
.await
|
||||||
{
|
{
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -441,10 +438,9 @@ impl OlmMachine {
|
||||||
/// Returns the number of newly generated one-time keys. If no keys can be
|
/// Returns the number of newly generated one-time keys. If no keys can be
|
||||||
/// generated returns an empty error.
|
/// generated returns an empty error.
|
||||||
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
||||||
let account = self.account.lock().await;
|
|
||||||
match self.uploaded_signed_key_count {
|
match self.uploaded_signed_key_count {
|
||||||
Some(count) => {
|
Some(count) => {
|
||||||
let max_keys = account.max_one_time_keys() as u64;
|
let max_keys = self.account.max_one_time_keys().await as u64;
|
||||||
let max_on_server = max_keys / 2;
|
let max_on_server = max_keys / 2;
|
||||||
|
|
||||||
if count >= (max_on_server) {
|
if count >= (max_on_server) {
|
||||||
|
@ -453,11 +449,11 @@ impl OlmMachine {
|
||||||
|
|
||||||
let key_count = (max_on_server) - count;
|
let key_count = (max_on_server) - count;
|
||||||
|
|
||||||
let key_count: usize = key_count
|
let max_keys = self.account.max_one_time_keys().await;
|
||||||
.try_into()
|
|
||||||
.unwrap_or_else(|_| account.max_one_time_keys());
|
|
||||||
|
|
||||||
account.generate_one_time_keys(key_count);
|
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
|
||||||
|
|
||||||
|
self.account.generate_one_time_keys(key_count).await;
|
||||||
Ok(key_count as u64)
|
Ok(key_count as u64)
|
||||||
}
|
}
|
||||||
None => Err(()),
|
None => Err(()),
|
||||||
|
@ -466,7 +462,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
/// Sign the device keys and return a JSON Value to upload them.
|
/// Sign the device keys and return a JSON Value to upload them.
|
||||||
async fn device_keys(&self) -> DeviceKeys {
|
async fn device_keys(&self) -> DeviceKeys {
|
||||||
let identity_keys = self.account.lock().await.identity_keys();
|
let identity_keys = self.account.identity_keys();
|
||||||
|
|
||||||
let mut keys = HashMap::new();
|
let mut keys = HashMap::new();
|
||||||
|
|
||||||
|
@ -513,7 +509,7 @@ impl OlmMachine {
|
||||||
/// If no one-time keys need to be uploaded returns an empty error.
|
/// If no one-time keys need to be uploaded returns an empty error.
|
||||||
async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
|
async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
|
||||||
let _ = self.generate_one_time_keys().await?;
|
let _ = self.generate_one_time_keys().await?;
|
||||||
let one_time_keys = self.account.lock().await.one_time_keys();
|
let one_time_keys = self.account.one_time_keys().await;
|
||||||
let mut one_time_key_map = HashMap::new();
|
let mut one_time_key_map = HashMap::new();
|
||||||
|
|
||||||
for (key_id, key) in one_time_keys.curve25519().iter() {
|
for (key_id, key) in one_time_keys.curve25519().iter() {
|
||||||
|
@ -555,10 +551,9 @@ impl OlmMachine {
|
||||||
/// * `json` - The value that should be converted into a canonical JSON
|
/// * `json` - The value that should be converted into a canonical JSON
|
||||||
/// string.
|
/// string.
|
||||||
async fn sign_json(&self, json: &Value) -> String {
|
async fn sign_json(&self, json: &Value) -> String {
|
||||||
let account = self.account.lock().await;
|
|
||||||
let canonical_json = cjson::to_string(json)
|
let canonical_json = cjson::to_string(json)
|
||||||
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
|
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
|
||||||
account.sign(&canonical_json)
|
self.account.sign(&canonical_json).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify a signed JSON object.
|
/// Verify a signed JSON object.
|
||||||
|
@ -637,7 +632,7 @@ impl OlmMachine {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let shared = self.account.lock().await.shared();
|
let shared = self.account.shared();
|
||||||
|
|
||||||
let device_keys = if !shared {
|
let device_keys = if !shared {
|
||||||
Some(self.device_keys().await)
|
Some(self.device_keys().await)
|
||||||
|
@ -702,8 +697,9 @@ impl OlmMachine {
|
||||||
let mut session = match &message {
|
let mut session = match &message {
|
||||||
OlmMessage::Message(_) => return Err(OlmError::SessionWedged),
|
OlmMessage::Message(_) => return Err(OlmError::SessionWedged),
|
||||||
OlmMessage::PreKey(m) => {
|
OlmMessage::PreKey(m) => {
|
||||||
let account = self.account.lock().await;
|
self.account
|
||||||
account.create_inbound_session(sender_key, m.clone())?
|
.create_inbound_session(sender_key, m.clone())
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -740,7 +736,7 @@ impl OlmMachine {
|
||||||
return Err(OlmError::UnsupportedAlgorithm);
|
return Err(OlmError::UnsupportedAlgorithm);
|
||||||
};
|
};
|
||||||
|
|
||||||
let identity_keys = self.account.lock().await.identity_keys();
|
let identity_keys = self.account.identity_keys();
|
||||||
let own_key = identity_keys.curve25519();
|
let own_key = identity_keys.curve25519();
|
||||||
let own_ciphertext = content.ciphertext.get(own_key);
|
let own_ciphertext = content.ciphertext.get(own_key);
|
||||||
|
|
||||||
|
@ -799,8 +795,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> {
|
async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> {
|
||||||
let session = OutboundGroupSession::new(room_id);
|
let session = OutboundGroupSession::new(room_id);
|
||||||
let account = self.account.lock().await;
|
let identity_keys = self.account.identity_keys();
|
||||||
let identity_keys = account.identity_keys();
|
|
||||||
|
|
||||||
let sender_key = identity_keys.curve25519();
|
let sender_key = identity_keys.curve25519();
|
||||||
let signing_key = identity_keys.ed25519();
|
let signing_key = identity_keys.ed25519();
|
||||||
|
@ -855,13 +850,7 @@ impl OlmMachine {
|
||||||
Ok(MegolmV1AesSha2Content {
|
Ok(MegolmV1AesSha2Content {
|
||||||
algorithm: Algorithm::MegolmV1AesSha2,
|
algorithm: Algorithm::MegolmV1AesSha2,
|
||||||
ciphertext,
|
ciphertext,
|
||||||
sender_key: self
|
sender_key: self.account.identity_keys().curve25519().to_owned(),
|
||||||
.account
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.identity_keys()
|
|
||||||
.curve25519()
|
|
||||||
.to_owned(),
|
|
||||||
session_id: session.session_id().to_owned(),
|
session_id: session.session_id().to_owned(),
|
||||||
device_id: self.device_id.to_owned(),
|
device_id: self.device_id.to_owned(),
|
||||||
})
|
})
|
||||||
|
@ -874,7 +863,7 @@ impl OlmMachine {
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
content: Value,
|
content: Value,
|
||||||
) -> Result<OlmV1Curve25519AesSha2Content> {
|
) -> Result<OlmV1Curve25519AesSha2Content> {
|
||||||
let identity_keys = self.account.lock().await.identity_keys();
|
let identity_keys = self.account.identity_keys();
|
||||||
|
|
||||||
let recipient_signing_key = recipient_device
|
let recipient_signing_key = recipient_device
|
||||||
.keys(&KeyAlgorithm::Ed25519)
|
.keys(&KeyAlgorithm::Ed25519)
|
||||||
|
@ -1326,7 +1315,7 @@ mod test {
|
||||||
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||||
|
|
||||||
let mut device_keys = machine.device_keys().await;
|
let mut device_keys = machine.device_keys().await;
|
||||||
let identity_keys = machine.account.lock().await.identity_keys();
|
let identity_keys = machine.account.identity_keys();
|
||||||
let ed25519_key = identity_keys.ed25519();
|
let ed25519_key = identity_keys.ed25519();
|
||||||
|
|
||||||
let ret = machine.verify_json(
|
let ret = machine.verify_json(
|
||||||
|
@ -1359,7 +1348,7 @@ mod test {
|
||||||
machine.uploaded_signed_key_count = Some(49);
|
machine.uploaded_signed_key_count = Some(49);
|
||||||
|
|
||||||
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
|
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
|
||||||
let identity_keys = machine.account.lock().await.identity_keys();
|
let identity_keys = machine.account.identity_keys();
|
||||||
let ed25519_key = identity_keys.ed25519();
|
let ed25519_key = identity_keys.ed25519();
|
||||||
|
|
||||||
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
|
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
|
||||||
|
@ -1378,7 +1367,7 @@ mod test {
|
||||||
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||||
machine.uploaded_signed_key_count = Some(0);
|
machine.uploaded_signed_key_count = Some(0);
|
||||||
|
|
||||||
let identity_keys = machine.account.lock().await.identity_keys();
|
let identity_keys = machine.account.identity_keys();
|
||||||
let ed25519_key = identity_keys.ed25519();
|
let ed25519_key = identity_keys.ed25519();
|
||||||
|
|
||||||
let (device_keys, mut one_time_keys) = machine
|
let (device_keys, mut one_time_keys) = machine
|
||||||
|
|
|
@ -33,9 +33,11 @@ use crate::identifiers::RoomId;
|
||||||
/// The Olm account.
|
/// The Olm account.
|
||||||
/// An account is the central identity for encrypted communication between two
|
/// An account is the central identity for encrypted communication between two
|
||||||
/// devices. It holds the two identity key pairs for a device.
|
/// devices. It holds the two identity key pairs for a device.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Account {
|
pub struct Account {
|
||||||
inner: OlmAccount,
|
inner: Arc<Mutex<OlmAccount>>,
|
||||||
pub(crate) shared: bool,
|
identity_keys: Arc<IdentityKeys>,
|
||||||
|
pub(crate) shared: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Debug for Account {
|
impl fmt::Debug for Account {
|
||||||
|
@ -44,7 +46,7 @@ impl fmt::Debug for Account {
|
||||||
f,
|
f,
|
||||||
"Olm Account: {:?}, shared: {}",
|
"Olm Account: {:?}, shared: {}",
|
||||||
self.identity_keys(),
|
self.identity_keys(),
|
||||||
self.shared
|
self.shared()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -52,49 +54,61 @@ impl fmt::Debug for Account {
|
||||||
impl Account {
|
impl Account {
|
||||||
/// Create a new account.
|
/// Create a new account.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
let account = OlmAccount::new();
|
||||||
|
let identity_keys = account.parsed_identity_keys();
|
||||||
|
|
||||||
Account {
|
Account {
|
||||||
inner: OlmAccount::new(),
|
inner: Arc::new(Mutex::new(account)),
|
||||||
shared: false,
|
identity_keys: Arc::new(identity_keys),
|
||||||
|
shared: Arc::new(AtomicBool::new(false)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the public parts of the identity keys for the account.
|
/// Get the public parts of the identity keys for the account.
|
||||||
pub fn identity_keys(&self) -> IdentityKeys {
|
pub fn identity_keys(&self) -> &IdentityKeys {
|
||||||
self.inner.parsed_identity_keys()
|
&self.identity_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Has the account been shared with the server.
|
/// Has the account been shared with the server.
|
||||||
pub fn shared(&self) -> bool {
|
pub fn shared(&self) -> bool {
|
||||||
self.shared
|
self.shared.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark the account as shared.
|
||||||
|
///
|
||||||
|
/// Messages shouldn't be encrypted with the session before it has been
|
||||||
|
/// shared.
|
||||||
|
pub fn mark_as_shared(&self) {
|
||||||
|
self.shared.store(true, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the one-time keys of the account.
|
/// Get the one-time keys of the account.
|
||||||
///
|
///
|
||||||
/// This can be empty, keys need to be generated first.
|
/// This can be empty, keys need to be generated first.
|
||||||
pub fn one_time_keys(&self) -> OneTimeKeys {
|
pub async fn one_time_keys(&self) -> OneTimeKeys {
|
||||||
self.inner.parsed_one_time_keys()
|
self.inner.lock().await.parsed_one_time_keys()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate count number of one-time keys.
|
/// Generate count number of one-time keys.
|
||||||
pub fn generate_one_time_keys(&self, count: usize) {
|
pub async fn generate_one_time_keys(&self, count: usize) {
|
||||||
self.inner.generate_one_time_keys(count);
|
self.inner.lock().await.generate_one_time_keys(count);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the maximum number of one-time keys the account can hold.
|
/// Get the maximum number of one-time keys the account can hold.
|
||||||
pub fn max_one_time_keys(&self) -> usize {
|
pub async fn max_one_time_keys(&self) -> usize {
|
||||||
self.inner.max_number_of_one_time_keys()
|
self.inner.lock().await.max_number_of_one_time_keys()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark the current set of one-time keys as being published.
|
/// Mark the current set of one-time keys as being published.
|
||||||
pub fn mark_keys_as_published(&self) {
|
pub async fn mark_keys_as_published(&self) {
|
||||||
self.inner.mark_keys_as_published();
|
self.inner.lock().await.mark_keys_as_published();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sign the given string using the accounts signing key.
|
/// Sign the given string using the accounts signing key.
|
||||||
///
|
///
|
||||||
/// Returns the signature as a base64 encoded string.
|
/// Returns the signature as a base64 encoded string.
|
||||||
pub fn sign(&self, string: &str) -> String {
|
pub async fn sign(&self, string: &str) -> String {
|
||||||
self.inner.sign(string)
|
self.inner.lock().await.sign(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Store the account as a base64 encoded string.
|
/// Store the account as a base64 encoded string.
|
||||||
|
@ -103,8 +117,8 @@ impl Account {
|
||||||
///
|
///
|
||||||
/// * `pickle_mode` - The mode that was used to pickle the account, either an
|
/// * `pickle_mode` - The mode that was used to pickle the account, either an
|
||||||
/// unencrypted mode or an encrypted using passphrase.
|
/// unencrypted mode or an encrypted using passphrase.
|
||||||
pub fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
|
||||||
self.inner.pickle(pickle_mode)
|
self.inner.lock().await.pickle(pickle_mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Restore an account from a previously pickled string.
|
/// Restore an account from a previously pickled string.
|
||||||
|
@ -123,8 +137,14 @@ impl Account {
|
||||||
pickle_mode: PicklingMode,
|
pickle_mode: PicklingMode,
|
||||||
shared: bool,
|
shared: bool,
|
||||||
) -> Result<Self, OlmAccountError> {
|
) -> Result<Self, OlmAccountError> {
|
||||||
let acc = OlmAccount::unpickle(pickle, pickle_mode)?;
|
let account = OlmAccount::unpickle(pickle, pickle_mode)?;
|
||||||
Ok(Account { inner: acc, shared })
|
let identity_keys = account.parsed_identity_keys();
|
||||||
|
|
||||||
|
Ok(Account {
|
||||||
|
inner: Arc::new(Mutex::new(account)),
|
||||||
|
identity_keys: Arc::new(identity_keys),
|
||||||
|
shared: Arc::new(AtomicBool::from(shared)),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new session with another account given a one-time key.
|
/// Create a new session with another account given a one-time key.
|
||||||
|
@ -137,13 +157,15 @@ impl Account {
|
||||||
///
|
///
|
||||||
/// * `their_one_time_key` - A signed one-time key that the other account
|
/// * `their_one_time_key` - A signed one-time key that the other account
|
||||||
/// created and shared with us.
|
/// created and shared with us.
|
||||||
pub fn create_outbound_session(
|
pub async fn create_outbound_session(
|
||||||
&self,
|
&self,
|
||||||
their_identity_key: &str,
|
their_identity_key: &str,
|
||||||
their_one_time_key: &SignedKey,
|
their_one_time_key: &SignedKey,
|
||||||
) -> Result<Session, OlmSessionError> {
|
) -> Result<Session, OlmSessionError> {
|
||||||
let session = self
|
let session = self
|
||||||
.inner
|
.inner
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
.create_outbound_session(their_identity_key, &their_one_time_key.key)?;
|
.create_outbound_session(their_identity_key, &their_one_time_key.key)?;
|
||||||
|
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
@ -166,13 +188,15 @@ impl Account {
|
||||||
///
|
///
|
||||||
/// * `message` - A pre-key Olm message that was sent to us by the other
|
/// * `message` - A pre-key Olm message that was sent to us by the other
|
||||||
/// account.
|
/// account.
|
||||||
pub fn create_inbound_session(
|
pub async fn create_inbound_session(
|
||||||
&self,
|
&self,
|
||||||
their_identity_key: &str,
|
their_identity_key: &str,
|
||||||
message: PreKeyMessage,
|
message: PreKeyMessage,
|
||||||
) -> Result<Session, OlmSessionError> {
|
) -> Result<Session, OlmSessionError> {
|
||||||
let session = self
|
let session = self
|
||||||
.inner
|
.inner
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
.create_inbound_session_from(their_identity_key, message)?;
|
.create_inbound_session_from(their_identity_key, message)?;
|
||||||
|
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
@ -188,7 +212,7 @@ impl Account {
|
||||||
|
|
||||||
impl PartialEq for Account {
|
impl PartialEq for Account {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.identity_keys() == other.identity_keys() && self.shared == other.shared
|
self.identity_keys() == other.identity_keys() && self.shared() == other.shared()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,16 +590,16 @@ mod test {
|
||||||
assert!(!identyty_keys.curve25519().is_empty());
|
assert!(!identyty_keys.curve25519().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn one_time_keys_creation() {
|
async fn one_time_keys_creation() {
|
||||||
let account = Account::new();
|
let account = Account::new();
|
||||||
let one_time_keys = account.one_time_keys();
|
let one_time_keys = account.one_time_keys().await;
|
||||||
|
|
||||||
assert!(one_time_keys.curve25519().is_empty());
|
assert!(one_time_keys.curve25519().is_empty());
|
||||||
assert_ne!(account.max_one_time_keys(), 0);
|
assert_ne!(account.max_one_time_keys().await, 0);
|
||||||
|
|
||||||
account.generate_one_time_keys(10);
|
account.generate_one_time_keys(10).await;
|
||||||
let one_time_keys = account.one_time_keys();
|
let one_time_keys = account.one_time_keys().await;
|
||||||
|
|
||||||
assert!(!one_time_keys.curve25519().is_empty());
|
assert!(!one_time_keys.curve25519().is_empty());
|
||||||
assert_ne!(one_time_keys.values().len(), 0);
|
assert_ne!(one_time_keys.values().len(), 0);
|
||||||
|
@ -588,21 +612,19 @@ mod test {
|
||||||
one_time_keys.get("curve25519").unwrap()
|
one_time_keys.get("curve25519").unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
account.mark_keys_as_published();
|
account.mark_keys_as_published().await;
|
||||||
let one_time_keys = account.one_time_keys();
|
let one_time_keys = account.one_time_keys().await;
|
||||||
assert!(one_time_keys.curve25519().is_empty());
|
assert!(one_time_keys.curve25519().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn session_creation() {
|
async fn session_creation() {
|
||||||
let alice = Account::new();
|
let alice = Account::new();
|
||||||
let bob = Account::new();
|
let bob = Account::new();
|
||||||
let alice_keys = alice.identity_keys();
|
let alice_keys = alice.identity_keys();
|
||||||
let one_time_keys = alice.one_time_keys();
|
alice.generate_one_time_keys(1).await;
|
||||||
|
let one_time_keys = alice.one_time_keys().await;
|
||||||
alice.generate_one_time_keys(1);
|
alice.mark_keys_as_published().await;
|
||||||
let one_time_keys = alice.one_time_keys();
|
|
||||||
alice.mark_keys_as_published();
|
|
||||||
|
|
||||||
let one_time_key = one_time_keys
|
let one_time_key = one_time_keys
|
||||||
.curve25519()
|
.curve25519()
|
||||||
|
@ -619,6 +641,7 @@ mod test {
|
||||||
|
|
||||||
let mut bob_session = bob
|
let mut bob_session = bob
|
||||||
.create_outbound_session(alice_keys.curve25519(), &one_time_key)
|
.create_outbound_session(alice_keys.curve25519(), &one_time_key)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let plaintext = "Hello world";
|
let plaintext = "Hello world";
|
||||||
|
@ -633,6 +656,7 @@ mod test {
|
||||||
let bob_keys = bob.identity_keys();
|
let bob_keys = bob.identity_keys();
|
||||||
let mut alice_session = alice
|
let mut alice_session = alice
|
||||||
.create_inbound_session(bob_keys.curve25519(), prekey_message)
|
.create_inbound_session(bob_keys.curve25519(), prekey_message)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(bob_session.session_id(), alice_session.session_id());
|
assert_eq!(bob_session.session_id(), alice_session.session_id());
|
||||||
|
|
|
@ -48,7 +48,7 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_account(&mut self, _: Arc<Mutex<Account>>) -> Result<()> {
|
async fn save_account(&mut self, _: Account) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait CryptoStore: Debug + Send + Sync {
|
pub trait CryptoStore: Debug + Send + Sync {
|
||||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
async fn save_account(&mut self, account: Account) -> Result<()>;
|
||||||
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
|
||||||
async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
|
async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
|
||||||
async fn get_sessions(
|
async fn get_sessions(
|
||||||
|
|
|
@ -288,9 +288,8 @@ impl CryptoStore for SqliteStore {
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
|
async fn save_account(&mut self, account: Account) -> Result<()> {
|
||||||
let acc = account.lock().await;
|
let pickle = account.pickle(self.get_pickle_mode()).await;
|
||||||
let pickle = acc.pickle(self.get_pickle_mode());
|
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
query(
|
query(
|
||||||
|
@ -307,7 +306,7 @@ impl CryptoStore for SqliteStore {
|
||||||
.bind(&*self.user_id.to_string())
|
.bind(&*self.user_id.to_string())
|
||||||
.bind(&*self.device_id.to_string())
|
.bind(&*self.device_id.to_string())
|
||||||
.bind(&pickle)
|
.bind(&pickle)
|
||||||
.bind(acc.shared)
|
.bind(account.shared())
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -460,7 +459,7 @@ mod test {
|
||||||
.expect("Can't create store")
|
.expect("Can't create store")
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_loaded_store() -> (Arc<Mutex<Account>>, SqliteStore) {
|
async fn get_loaded_store() -> (Account, SqliteStore) {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let account = get_account();
|
let account = get_account();
|
||||||
store
|
store
|
||||||
|
@ -471,19 +470,19 @@ mod test {
|
||||||
(account, store)
|
(account, store)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_account() -> Arc<Mutex<Account>> {
|
fn get_account() -> Account {
|
||||||
let account = Account::new();
|
Account::new()
|
||||||
Arc::new(Mutex::new(account))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_account_and_session() -> (Arc<Mutex<Account>>, Session) {
|
async fn get_account_and_session() -> (Account, Session) {
|
||||||
let alice = Account::new();
|
let alice = Account::new();
|
||||||
|
|
||||||
let bob = Account::new();
|
let bob = Account::new();
|
||||||
|
|
||||||
bob.generate_one_time_keys(1);
|
bob.generate_one_time_keys(1).await;
|
||||||
let one_time_key = bob
|
let one_time_key = bob
|
||||||
.one_time_keys()
|
.one_time_keys()
|
||||||
|
.await
|
||||||
.curve25519()
|
.curve25519()
|
||||||
.iter()
|
.iter()
|
||||||
.nth(0)
|
.nth(0)
|
||||||
|
@ -497,9 +496,10 @@ mod test {
|
||||||
let sender_key = bob.identity_keys().curve25519().to_owned();
|
let sender_key = bob.identity_keys().curve25519().to_owned();
|
||||||
let session = alice
|
let session = alice
|
||||||
.create_outbound_session(&sender_key, &one_time_key)
|
.create_outbound_session(&sender_key, &one_time_key)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
(Arc::new(Mutex::new(alice)), session)
|
(alice, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -532,11 +532,10 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.expect("Can't save account");
|
.expect("Can't save account");
|
||||||
|
|
||||||
let acc = account.lock().await;
|
|
||||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||||
let loaded_account = loaded_account.unwrap();
|
let loaded_account = loaded_account.unwrap();
|
||||||
|
|
||||||
assert_eq!(*acc, loaded_account);
|
assert_eq!(account, loaded_account);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -549,7 +548,7 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.expect("Can't save account");
|
.expect("Can't save account");
|
||||||
|
|
||||||
account.lock().await.shared = true;
|
account.mark_as_shared();
|
||||||
|
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
|
@ -558,15 +557,14 @@ mod test {
|
||||||
|
|
||||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||||
let loaded_account = loaded_account.unwrap();
|
let loaded_account = loaded_account.unwrap();
|
||||||
let acc = account.lock().await;
|
|
||||||
|
|
||||||
assert_eq!(*acc, loaded_account);
|
assert_eq!(account, loaded_account);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn save_session() {
|
async fn save_session() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let (account, session) = get_account_and_session();
|
let (account, session) = get_account_and_session().await;
|
||||||
let session = Arc::new(Mutex::new(session));
|
let session = Arc::new(Mutex::new(session));
|
||||||
|
|
||||||
assert!(store.save_session(session.clone()).await.is_err());
|
assert!(store.save_session(session.clone()).await.is_err());
|
||||||
|
@ -582,7 +580,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_sessions() {
|
async fn load_sessions() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let (account, session) = get_account_and_session();
|
let (account, session) = get_account_and_session().await;
|
||||||
let session = Arc::new(Mutex::new(session));
|
let session = Arc::new(Mutex::new(session));
|
||||||
store
|
store
|
||||||
.save_account(account.clone())
|
.save_account(account.clone())
|
||||||
|
@ -604,7 +602,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn add_and_save_session() {
|
async fn add_and_save_session() {
|
||||||
let mut store = get_store().await;
|
let mut store = get_store().await;
|
||||||
let (account, session) = get_account_and_session();
|
let (account, session) = get_account_and_session().await;
|
||||||
let sender_key = session.sender_key.to_owned();
|
let sender_key = session.sender_key.to_owned();
|
||||||
let session_id = session.session_id();
|
let session_id = session.session_id();
|
||||||
|
|
||||||
|
@ -625,8 +623,7 @@ mod test {
|
||||||
async fn save_inbound_group_session() {
|
async fn save_inbound_group_session() {
|
||||||
let (account, mut store) = get_loaded_store().await;
|
let (account, mut store) = get_loaded_store().await;
|
||||||
|
|
||||||
let acc = account.lock().await;
|
let identity_keys = account.identity_keys();
|
||||||
let identity_keys = acc.identity_keys();
|
|
||||||
let outbound_session = OlmOutboundGroupSession::new();
|
let outbound_session = OlmOutboundGroupSession::new();
|
||||||
let session = InboundGroupSession::new(
|
let session = InboundGroupSession::new(
|
||||||
identity_keys.curve25519(),
|
identity_keys.curve25519(),
|
||||||
|
@ -646,8 +643,7 @@ mod test {
|
||||||
async fn load_inbound_group_session() {
|
async fn load_inbound_group_session() {
|
||||||
let (account, mut store) = get_loaded_store().await;
|
let (account, mut store) = get_loaded_store().await;
|
||||||
|
|
||||||
let acc = account.lock().await;
|
let identity_keys = account.identity_keys();
|
||||||
let identity_keys = acc.identity_keys();
|
|
||||||
let outbound_session = OlmOutboundGroupSession::new();
|
let outbound_session = OlmOutboundGroupSession::new();
|
||||||
let session = InboundGroupSession::new(
|
let session = InboundGroupSession::new(
|
||||||
identity_keys.curve25519(),
|
identity_keys.curve25519(),
|
||||||
|
|
Loading…
Reference in New Issue