Modify QuerySharedUsers to handle counts/include/exclude (#1219)
* Modify QuerySharedUsers to handle counts/include/exclude We will need this functionality when working out whether to send device list changes to users who have joined/left a room. * Lintingmain
parent
98f2f09bb4
commit
af5b4d1f6b
|
@ -37,10 +37,12 @@ type CurrentStateInternalAPI interface {
|
||||||
|
|
||||||
type QuerySharedUsersRequest struct {
|
type QuerySharedUsersRequest struct {
|
||||||
UserID string
|
UserID string
|
||||||
|
ExcludeRoomIDs []string
|
||||||
|
IncludeRoomIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type QuerySharedUsersResponse struct {
|
type QuerySharedUsersResponse struct {
|
||||||
UserIDs []string
|
UserIDsToCount map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryRoomsForUserRequest struct {
|
type QueryRoomsForUserRequest struct {
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -227,13 +226,31 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
req api.QuerySharedUsersRequest
|
req api.QuerySharedUsersRequest
|
||||||
wantRes api.QuerySharedUsersResponse
|
wantRes api.QuerySharedUsersResponse
|
||||||
}{
|
}{
|
||||||
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
|
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A:4,B:2,C:1)
|
||||||
{
|
{
|
||||||
req: api.QuerySharedUsersRequest{
|
req: api.QuerySharedUsersRequest{
|
||||||
UserID: "@alice:localhost",
|
UserID: "@alice:localhost",
|
||||||
},
|
},
|
||||||
wantRes: api.QuerySharedUsersResponse{
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
|
UserIDsToCount: map[string]int{
|
||||||
|
"@alice:localhost": 4,
|
||||||
|
"@bob:localhost": 2,
|
||||||
|
"@charlie:localhost": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Exclude (A,C): sharing (A,B) (A,B) (A) produces (A:3,B:2)
|
||||||
|
{
|
||||||
|
req: api.QuerySharedUsersRequest{
|
||||||
|
UserID: "@alice:localhost",
|
||||||
|
ExcludeRoomIDs: []string{"!foo2:bar"},
|
||||||
|
},
|
||||||
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
|
UserIDsToCount: map[string]int{
|
||||||
|
"@alice:localhost": 3,
|
||||||
|
"@bob:localhost": 2,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -243,7 +260,7 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
UserID: "@unknownuser:localhost",
|
UserID: "@unknownuser:localhost",
|
||||||
},
|
},
|
||||||
wantRes: api.QuerySharedUsersResponse{
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
UserIDs: nil,
|
UserIDsToCount: map[string]int{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -253,7 +270,35 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
UserID: "@dave:localhost",
|
UserID: "@dave:localhost",
|
||||||
},
|
},
|
||||||
wantRes: api.QuerySharedUsersResponse{
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
UserIDs: nil,
|
UserIDsToCount: map[string]int{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// left real user but with included room returns the included room member
|
||||||
|
{
|
||||||
|
req: api.QuerySharedUsersRequest{
|
||||||
|
UserID: "@dave:localhost",
|
||||||
|
IncludeRoomIDs: []string{"!foo:bar"},
|
||||||
|
},
|
||||||
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
|
UserIDsToCount: map[string]int{
|
||||||
|
"@alice:localhost": 1,
|
||||||
|
"@bob:localhost": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// including a room more than once doesn't double counts
|
||||||
|
{
|
||||||
|
req: api.QuerySharedUsersRequest{
|
||||||
|
UserID: "@dave:localhost",
|
||||||
|
IncludeRoomIDs: []string{"!foo:bar", "!foo:bar", "!foo:bar"},
|
||||||
|
},
|
||||||
|
wantRes: api.QuerySharedUsersResponse{
|
||||||
|
UserIDsToCount: map[string]int{
|
||||||
|
"@alice:localhost": 1,
|
||||||
|
"@bob:localhost": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -266,10 +311,8 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
t.Errorf("QuerySharedUsers returned error: %s", err)
|
t.Errorf("QuerySharedUsers returned error: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sort.Strings(res.UserIDs)
|
if !reflect.DeepEqual(res.UserIDsToCount, tc.wantRes.UserIDsToCount) {
|
||||||
sort.Strings(tc.wantRes.UserIDs)
|
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDsToCount, tc.wantRes.UserIDsToCount)
|
||||||
if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
|
|
||||||
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,10 +74,27 @@ func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
roomIDs = append(roomIDs, req.IncludeRoomIDs...)
|
||||||
|
excludeMap := make(map[string]bool)
|
||||||
|
for _, roomID := range req.ExcludeRoomIDs {
|
||||||
|
excludeMap[roomID] = true
|
||||||
|
}
|
||||||
|
// filter out excluded rooms
|
||||||
|
j := 0
|
||||||
|
for i := range roomIDs {
|
||||||
|
// move elements to include to the beginning of the slice
|
||||||
|
// then trim elements on the right
|
||||||
|
if !excludeMap[roomIDs[i]] {
|
||||||
|
roomIDs[j] = roomIDs[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
roomIDs = roomIDs[:j]
|
||||||
|
|
||||||
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.UserIDs = users
|
res.UserIDsToCount = users
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,6 @@ type Database interface {
|
||||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||||
// Redact a state event
|
// Redact a state event
|
||||||
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
|
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
|
||||||
// JoinedUsersSetInRooms returns all joined users in the rooms given.
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
|
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,8 @@ const selectBulkStateContentWildSQL = "" +
|
||||||
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
|
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
|
||||||
|
|
||||||
const selectJoinedUsersSetForRoomsSQL = "" +
|
const selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
|
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
|
||||||
|
" type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
|
@ -124,21 +125,22 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
var userIDs []string
|
result := make(map[string]int)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userID string
|
var userID string
|
||||||
if err := rows.Scan(&userID); err != nil {
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userIDs = append(userIDs, userID)
|
result[userID] = count
|
||||||
}
|
}
|
||||||
return userIDs, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
|
|
|
@ -86,6 +86,6 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
||||||
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
|
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
|
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ const selectBulkStateContentWildSQL = "" +
|
||||||
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
|
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
|
||||||
|
|
||||||
const selectJoinedUsersSetForRoomsSQL = "" +
|
const selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
|
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -106,7 +106,7 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
for i, v := range roomIDs {
|
for i, v := range roomIDs {
|
||||||
iRoomIDs[i] = v
|
iRoomIDs[i] = v
|
||||||
|
@ -117,15 +117,16 @@ func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Co
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
var userIDs []string
|
result := make(map[string]int)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userID string
|
var userID string
|
||||||
if err := rows.Scan(&userID); err != nil {
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userIDs = append(userIDs, userID)
|
result[userID] = count
|
||||||
}
|
}
|
||||||
return userIDs, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
|
|
|
@ -36,8 +36,9 @@ type CurrentRoomState interface {
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
|
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
|
||||||
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
|
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
|
||||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
|
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
|
// counts of how many rooms they are joined.
|
||||||
|
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||||
|
|
Loading…
Reference in New Issue