base: Correctly get the user ids of all room members

master
Damir Jelić 2021-05-18 08:29:10 +02:00
parent cd77441d1b
commit c122549e0d
1 changed files with 28 additions and 13 deletions

View File

@ -83,8 +83,9 @@ impl From<SerializationError> for StoreError {
} }
} }
const ENCODE_SEPARATOR: u8 = 0xff;
trait EncodeKey { trait EncodeKey {
const SEPARATOR: u8 = 0xff;
fn encode(&self) -> Vec<u8>; fn encode(&self) -> Vec<u8>;
} }
@ -102,13 +103,13 @@ impl EncodeKey for &RoomId {
impl EncodeKey for &str { impl EncodeKey for &str {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[self.as_bytes(), &[Self::SEPARATOR]].concat() [self.as_bytes(), &[ENCODE_SEPARATOR]].concat()
} }
} }
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat() [self.0.as_bytes(), &[ENCODE_SEPARATOR], self.1.as_bytes(), &[ENCODE_SEPARATOR]].concat()
} }
} }
@ -116,11 +117,11 @@ impl EncodeKey for (&str, &str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [
self.0.as_bytes(), self.0.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
self.1.as_bytes(), self.1.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
self.2.as_bytes(), self.2.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
] ]
.concat() .concat()
} }
@ -506,11 +507,22 @@ impl SledStore {
.transpose()?) .transpose()?)
} }
pub async fn get_user_ids(&self, room_id: &RoomId) -> impl Stream<Item = Result<UserId>> { pub async fn get_user_ids_stream(
stream::iter(self.members.scan_prefix(room_id.encode()).map(|u| { &self,
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string()) room_id: &RoomId,
.map_err(StoreError::Identifier) ) -> impl Stream<Item = Result<UserId>> {
})) let decode = |key: &[u8]| -> Result<UserId> {
let mut iter = key.split(|c| c == &ENCODE_SEPARATOR);
// Our key is a the room id separated from the user id by a null
// byte, discard the first value of the split.
iter.next();
let user_id = iter.next().expect("User ids weren't properly encoded");
Ok(UserId::try_from(String::from_utf8_lossy(user_id).to_string())?)
};
stream::iter(self.members.scan_prefix(room_id.encode()).map(move |u| decode(&u?.0)))
} }
pub async fn get_invited_user_ids( pub async fn get_invited_user_ids(
@ -636,7 +648,7 @@ impl StateStore for SledStore {
} }
async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> { async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_user_ids(room_id).await.try_collect().await self.get_user_ids_stream(room_id).await.try_collect().await
} }
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> { async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
@ -698,7 +710,7 @@ mod test {
use serde_json::json; use serde_json::json;
use super::{SledStore, StateChanges}; use super::{SledStore, StateChanges};
use crate::deserialized_responses::MemberEvent; use crate::{deserialized_responses::MemberEvent, StateStore};
fn user_id() -> UserId { fn user_id() -> UserId {
user_id!("@example:localhost") user_id!("@example:localhost")
@ -748,6 +760,9 @@ mod test {
store.save_changes(&changes).await.unwrap(); store.save_changes(&changes).await.unwrap();
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some()); assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
let members = store.get_user_ids(&room_id).await.unwrap();
assert!(!members.is_empty())
} }
#[async_test] #[async_test]