From 7b862384a779f067f07ffeb2151856f89d372732 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 23 Jul 2020 12:26:31 +0100 Subject: [PATCH] currentstate: Add QuerySharedUsers (#1217) This will be used to determine who to send device list updates to. It can also be used to determine who to send presence info to. --- currentstateserver/api/api.go | 10 ++ currentstateserver/currentstateserver_test.go | 111 ++++++++++++++++++ currentstateserver/internal/api.go | 13 ++ currentstateserver/inthttp/client.go | 11 ++ currentstateserver/inthttp/server.go | 13 ++ currentstateserver/storage/interface.go | 2 + .../postgres/current_room_state_table.go | 38 ++++-- currentstateserver/storage/shared/storage.go | 4 + .../sqlite3/current_room_state_table.go | 41 ++++++- .../storage/tables/interface.go | 2 + 10 files changed, 232 insertions(+), 13 deletions(-) diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index 729a66ba..520ce8d6 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -31,6 +31,16 @@ type CurrentStateInternalAPI interface { QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error // QueryBulkStateContent does a bulk query for state event content in the given rooms. QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error + // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. + QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error +} + +type QuerySharedUsersRequest struct { + UserID string +} + +type QuerySharedUsersResponse struct { + UserIDs []string } type QueryRoomsForUserRequest struct { diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index a0627fea..4dac742f 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -16,9 +16,11 @@ package currentstateserver import ( "context" + "crypto/ed25519" "encoding/json" "net/http" "reflect" + "sort" "testing" "time" @@ -178,3 +180,112 @@ func TestQueryCurrentState(t *testing.T) { runCases(currStateAPI) }) } + +func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent { + eb := gomatrixserverlib.EventBuilder{ + RoomID: roomID, + Sender: userID, + StateKey: &userID, + Type: "m.room.member", + Content: []byte(`{"membership":"` + membership + `"}`), + } + _, pkey, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("failed to make ed25519 key: %s", err) + } + roomVer := gomatrixserverlib.RoomVersionV5 + ev, err := eb.Build( + time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"), + pkey, roomVer, + ) + if err != nil { + t.Fatalf("mustMakeMembershipEvent failed: %s", err) + } + + return &roomserverAPI.OutputNewRoomEvent{ + Event: ev.Headered(roomVer), + AddsStateEventIDs: []string{ev.EventID()}, + } +} + +// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets. +func TestQuerySharedUsers(t *testing.T) { + currStateAPI, producer := MustMakeInternalAPI(t) + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join")) + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join")) + + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join")) + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join")) + + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join")) + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join")) + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave")) + + MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join")) + + testCases := []struct { + req api.QuerySharedUsersRequest + wantRes api.QuerySharedUsersResponse + }{ + // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C) + { + req: api.QuerySharedUsersRequest{ + UserID: "@alice:localhost", + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"}, + }, + }, + + // Unknown user has no shared users + { + req: api.QuerySharedUsersRequest{ + UserID: "@unknownuser:localhost", + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDs: nil, + }, + }, + + // left real user produces no shared users + { + req: api.QuerySharedUsersRequest{ + UserID: "@dave:localhost", + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDs: nil, + }, + }, + } + + runCases := func(testAPI api.CurrentStateInternalAPI) { + for _, tc := range testCases { + var res api.QuerySharedUsersResponse + err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res) + if err != nil { + t.Errorf("QuerySharedUsers returned error: %s", err) + continue + } + sort.Strings(res.UserIDs) + sort.Strings(tc.wantRes.UserIDs) + if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) { + t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs) + } + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + AddInternalRoutes(router, currStateAPI) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + runCases(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + runCases(currStateAPI) + }) +} diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index c2876047..e945d0c1 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -68,3 +68,16 @@ func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req } return nil } + +func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join") + if err != nil { + return err + } + users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs) + if err != nil { + return err + } + res.UserIDs = users + return nil +} diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go index b8c6a119..cce881ff 100644 --- a/currentstateserver/inthttp/client.go +++ b/currentstateserver/inthttp/client.go @@ -29,6 +29,7 @@ const ( QueryCurrentStatePath = "/currentstateserver/queryCurrentState" QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" + QuerySharedUsersPath = "/currentstateserver/querySharedUsers" ) // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. @@ -86,3 +87,13 @@ func (h *httpCurrentStateInternalAPI) QueryBulkStateContent( apiURL := h.apiURL + QueryBulkStateContentPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +func (h *httpCurrentStateInternalAPI) QuerySharedUsers( + ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") + defer span.Finish() + + apiURL := h.apiURL + QuerySharedUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go index dafb9f64..f4e93dcd 100644 --- a/currentstateserver/inthttp/server.go +++ b/currentstateserver/inthttp/server.go @@ -64,4 +64,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QuerySharedUsersPath, + httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { + request := api.QuerySharedUsersRequest{} + response := api.QuerySharedUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index 0e95cde8..1c4635be 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -37,4 +37,6 @@ type Database interface { GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) // Redact a state event RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error + // JoinedUsersSetInRooms returns all joined users in the rooms given. + JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index 79c9f967..9e0070f1 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -77,14 +77,18 @@ const selectBulkStateContentSQL = "" + const selectBulkStateContentWildSQL = "" + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" +const selectJoinedUsersSetForRoomsSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'" + type currentRoomStateStatements struct { - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectEventsWithEventIDsStmt *sql.Stmt - selectStateEventStmt *sql.Stmt - selectBulkStateContentStmt *sql.Stmt - selectBulkStateContentWildStmt *sql.Stmt + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectEventsWithEventIDsStmt *sql.Stmt + selectStateEventStmt *sql.Stmt + selectBulkStateContentStmt *sql.Stmt + selectBulkStateContentWildStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -114,9 +118,29 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil { return nil, err } + if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { + return nil, err + } return s, nil } +func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { + rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + var userIDs []string + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index 66b979d8..aafb5fdd 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -85,3 +85,7 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) } + +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) { + return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) +} diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index b95fb435..4d3803b6 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -66,13 +66,17 @@ const selectBulkStateContentSQL = "" + const selectBulkStateContentWildSQL = "" + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" +const selectJoinedUsersSetForRoomsSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'" + type currentRoomStateStatements struct { - db *sql.DB - writer *sqlutil.TransactionWriter - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + db *sql.DB + writer *sqlutil.TransactionWriter + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectStateEventStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -96,9 +100,34 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { return nil, err } + if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { + return nil, err + } return s, nil } +func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + rows, err := s.db.QueryContext(ctx, query, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + var userIDs []string + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 12884b68..88e7a31b 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -36,6 +36,8 @@ type CurrentRoomState interface { // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) + // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms. + SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) } // StrippedEvent represents a stripped event for returning extracted content values.