currentstate: Add QuerySharedUsers (#1217)
This will be used to determine who to send device list updates to. It can also be used to determine who to send presence info to.
This commit is contained in:
parent
cfeb1b2f42
commit
7b862384a7
10 changed files with 232 additions and 13 deletions
|
@ -31,6 +31,16 @@ type CurrentStateInternalAPI interface {
|
|||
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
|
||||
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||
}
|
||||
|
||||
type QuerySharedUsersRequest struct {
|
||||
UserID string
|
||||
}
|
||||
|
||||
type QuerySharedUsersResponse struct {
|
||||
UserIDs []string
|
||||
}
|
||||
|
||||
type QueryRoomsForUserRequest struct {
|
||||
|
|
|
@ -16,9 +16,11 @@ package currentstateserver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -178,3 +180,112 @@ func TestQueryCurrentState(t *testing.T) {
|
|||
runCases(currStateAPI)
|
||||
})
|
||||
}
|
||||
|
||||
func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent {
|
||||
eb := gomatrixserverlib.EventBuilder{
|
||||
RoomID: roomID,
|
||||
Sender: userID,
|
||||
StateKey: &userID,
|
||||
Type: "m.room.member",
|
||||
Content: []byte(`{"membership":"` + membership + `"}`),
|
||||
}
|
||||
_, pkey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make ed25519 key: %s", err)
|
||||
}
|
||||
roomVer := gomatrixserverlib.RoomVersionV5
|
||||
ev, err := eb.Build(
|
||||
time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"),
|
||||
pkey, roomVer,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("mustMakeMembershipEvent failed: %s", err)
|
||||
}
|
||||
|
||||
return &roomserverAPI.OutputNewRoomEvent{
|
||||
Event: ev.Headered(roomVer),
|
||||
AddsStateEventIDs: []string{ev.EventID()},
|
||||
}
|
||||
}
|
||||
|
||||
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
||||
func TestQuerySharedUsers(t *testing.T) {
|
||||
currStateAPI, producer := MustMakeInternalAPI(t)
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
||||
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join"))
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join"))
|
||||
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join"))
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
|
||||
|
||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
||||
|
||||
testCases := []struct {
|
||||
req api.QuerySharedUsersRequest
|
||||
wantRes api.QuerySharedUsersResponse
|
||||
}{
|
||||
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
|
||||
{
|
||||
req: api.QuerySharedUsersRequest{
|
||||
UserID: "@alice:localhost",
|
||||
},
|
||||
wantRes: api.QuerySharedUsersResponse{
|
||||
UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
|
||||
},
|
||||
},
|
||||
|
||||
// Unknown user has no shared users
|
||||
{
|
||||
req: api.QuerySharedUsersRequest{
|
||||
UserID: "@unknownuser:localhost",
|
||||
},
|
||||
wantRes: api.QuerySharedUsersResponse{
|
||||
UserIDs: nil,
|
||||
},
|
||||
},
|
||||
|
||||
// left real user produces no shared users
|
||||
{
|
||||
req: api.QuerySharedUsersRequest{
|
||||
UserID: "@dave:localhost",
|
||||
},
|
||||
wantRes: api.QuerySharedUsersResponse{
|
||||
UserIDs: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
runCases := func(testAPI api.CurrentStateInternalAPI) {
|
||||
for _, tc := range testCases {
|
||||
var res api.QuerySharedUsersResponse
|
||||
err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res)
|
||||
if err != nil {
|
||||
t.Errorf("QuerySharedUsers returned error: %s", err)
|
||||
continue
|
||||
}
|
||||
sort.Strings(res.UserIDs)
|
||||
sort.Strings(tc.wantRes.UserIDs)
|
||||
if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
|
||||
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("HTTP API", func(t *testing.T) {
|
||||
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||
AddInternalRoutes(router, currStateAPI)
|
||||
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||
defer cancel()
|
||||
httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create HTTP client")
|
||||
}
|
||||
runCases(httpAPI)
|
||||
})
|
||||
t.Run("Monolith", func(t *testing.T) {
|
||||
runCases(currStateAPI)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -68,3 +68,16 @@ func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||
roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.UserIDs = users
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ const (
|
|||
QueryCurrentStatePath = "/currentstateserver/queryCurrentState"
|
||||
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
|
||||
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
|
||||
QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
|
||||
)
|
||||
|
||||
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
@ -86,3 +87,13 @@ func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
|
|||
apiURL := h.apiURL + QueryBulkStateContentPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
|
||||
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QuerySharedUsersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
|
|
@ -64,4 +64,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QuerySharedUsersPath,
|
||||
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QuerySharedUsersRequest{}
|
||||
response := api.QuerySharedUsersResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -37,4 +37,6 @@ type Database interface {
|
|||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||
// Redact a state event
|
||||
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
|
||||
}
|
||||
|
|
|
@ -77,14 +77,18 @@ const selectBulkStateContentSQL = "" +
|
|||
const selectBulkStateContentWildSQL = "" +
|
||||
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
|
||||
|
||||
const selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
selectEventsWithEventIDsStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
selectBulkStateContentStmt *sql.Stmt
|
||||
selectBulkStateContentWildStmt *sql.Stmt
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
selectEventsWithEventIDsStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
selectBulkStateContentStmt *sql.Stmt
|
||||
selectBulkStateContentWildStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||
|
@ -114,9 +118,29 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
|
|||
if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
||||
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
|
||||
ctx context.Context,
|
||||
|
|
|
@ -85,3 +85,7 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr
|
|||
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
|
||||
}
|
||||
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
||||
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
|
||||
}
|
||||
|
|
|
@ -66,13 +66,17 @@ const selectBulkStateContentSQL = "" +
|
|||
const selectBulkStateContentWildSQL = "" +
|
||||
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
|
||||
|
||||
const selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||
selectStateEventStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||
|
@ -96,9 +100,34 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
|
|||
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i, v := range roomIDs {
|
||||
iRoomIDs[i] = v
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, query, iRoomIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
|
||||
ctx context.Context,
|
||||
|
|
|
@ -36,6 +36,8 @@ type CurrentRoomState interface {
|
|||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
|
||||
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
|
||||
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
|
||||
}
|
||||
|
||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||
|
|
Loading…
Reference in a new issue