From acc8e80a51515c953c6710cb24f36fd9d1f7aeb1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 28 Jul 2020 10:53:17 +0100 Subject: [PATCH] User directory (#1225) * User directory * Fix syncapi unit test * Make user directory only show remote users you know about from your joined rooms * Update sytest-whitelist * Review comments --- clientapi/auth/authtypes/profile.go | 13 +- clientapi/routing/routing.go | 21 ++++ clientapi/routing/userdirectory.go | 115 ++++++++++++++++++ currentstateserver/api/api.go | 13 ++ currentstateserver/internal/api.go | 14 +++ currentstateserver/inthttp/client.go | 11 ++ currentstateserver/inthttp/server.go | 13 ++ currentstateserver/storage/interface.go | 2 + .../postgres/current_room_state_table.go | 30 +++++ currentstateserver/storage/shared/storage.go | 4 + .../sqlite3/current_room_state_table.go | 30 +++++ .../storage/tables/interface.go | 2 + syncapi/consumers/keychange_test.go | 4 + sytest-whitelist | 4 + userapi/api/api.go | 16 +++ userapi/internal/api.go | 9 ++ userapi/inthttp/client.go | 19 ++- userapi/inthttp/server.go | 13 ++ userapi/storage/accounts/interface.go | 1 + .../accounts/postgres/profile_table.go | 31 +++++ userapi/storage/accounts/postgres/storage.go | 7 ++ .../storage/accounts/sqlite3/profile_table.go | 31 +++++ userapi/storage/accounts/sqlite3/storage.go | 7 ++ 23 files changed, 402 insertions(+), 8 deletions(-) create mode 100644 clientapi/routing/userdirectory.go diff --git a/clientapi/auth/authtypes/profile.go b/clientapi/auth/authtypes/profile.go index 0bc49658..902850bc 100644 --- a/clientapi/auth/authtypes/profile.go +++ b/clientapi/auth/authtypes/profile.go @@ -16,7 +16,14 @@ package authtypes // Profile represents the profile for a Matrix account. type Profile struct { - Localpart string - DisplayName string - AvatarURL string + Localpart string `json:"local_part"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + +// FullyQualifiedProfile represents the profile for a Matrix account. +type FullyQualifiedProfile struct { + UserID string `json:"user_id"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 5724a20c..ebb141ef 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -574,6 +574,27 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/user_directory/search", + httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postContent := struct { + SearchString string `json:"search_term"` + Limit int `json:"limit"` + }{} + if err := json.NewDecoder(req.Body).Decode(&postContent); err != nil { + return util.ErrorResponse(err) + } + return *SearchUserDirectory( + req.Context(), + device, + userAPI, + stateAPI, + cfg.Matrix.ServerName, + postContent.SearchString, + postContent.Limit, + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go new file mode 100644 index 00000000..db81ffea --- /dev/null +++ b/clientapi/routing/userdirectory.go @@ -0,0 +1,115 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type UserDirectoryResponse struct { + Results []authtypes.FullyQualifiedProfile `json:"results"` + Limited bool `json:"limited"` +} + +func SearchUserDirectory( + ctx context.Context, + device *userapi.Device, + userAPI userapi.UserInternalAPI, + stateAPI currentstateAPI.CurrentStateInternalAPI, + serverName gomatrixserverlib.ServerName, + searchString string, + limit int, +) *util.JSONResponse { + if limit < 10 { + limit = 10 + } + + results := map[string]authtypes.FullyQualifiedProfile{} + response := &UserDirectoryResponse{ + Results: []authtypes.FullyQualifiedProfile{}, + Limited: false, + } + + // First start searching local users. + + userReq := &userapi.QuerySearchProfilesRequest{ + SearchString: searchString, + Limit: limit, + } + userRes := &userapi.QuerySearchProfilesResponse{} + if err := userAPI.QuerySearchProfiles(ctx, userReq, userRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) + return &errRes + } + + for _, user := range userRes.Profiles { + if len(results) == limit { + response.Limited = true + break + } + + userID := fmt.Sprintf("@%s:%s", user.Localpart, serverName) + if _, ok := results[userID]; !ok { + results[userID] = authtypes.FullyQualifiedProfile{ + UserID: userID, + DisplayName: user.DisplayName, + AvatarURL: user.AvatarURL, + } + } + } + + // Then, if we have enough room left in the response, + // start searching for known users from joined rooms. + + if len(results) <= limit { + stateReq := ¤tstateAPI.QueryKnownUsersRequest{ + UserID: device.UserID, + SearchString: searchString, + Limit: limit - len(results), + } + stateRes := ¤tstateAPI.QueryKnownUsersResponse{} + if err := stateAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("stateAPI.QueryKnownUsers: %w", err)) + return &errRes + } + + for _, user := range stateRes.Users { + if len(results) == limit { + response.Limited = true + break + } + + if _, ok := results[user.UserID]; !ok { + results[user.UserID] = user + } + } + } + + for _, result := range results { + response.Results = append(response.Results, result) + } + + return &util.JSONResponse{ + Code: 200, + JSON: response, + } +} diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index b778acb2..4ebe2968 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -33,6 +34,8 @@ type CurrentStateInternalAPI interface { QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error } type QuerySharedUsersRequest struct { @@ -88,6 +91,16 @@ type QueryCurrentStateResponse struct { StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent } +type QueryKnownUsersRequest struct { + UserID string `json:"user_id"` + SearchString string `json:"search_string"` + Limit int `json:"limit"` +} + +type QueryKnownUsersResponse struct { + Users []authtypes.FullyQualifiedProfile `json:"profiles"` +} + // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index c581c524..dc255412 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -17,6 +17,7 @@ package internal import ( "context" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/gomatrixserverlib" @@ -49,6 +50,19 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap return nil } +func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + users, err := a.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) + if err != nil { + return err + } + for _, user := range users { + res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ + UserID: user, + }) + } + return nil +} + func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) if err != nil { diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go index cce881ff..37d289ea 100644 --- a/currentstateserver/inthttp/client.go +++ b/currentstateserver/inthttp/client.go @@ -30,6 +30,7 @@ const ( QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QuerySharedUsersPath = "/currentstateserver/querySharedUsers" + QueryKnownUsersPath = "/currentstateserver/queryKnownUsers" ) // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. @@ -97,3 +98,13 @@ func (h *httpCurrentStateInternalAPI) QuerySharedUsers( apiURL := h.apiURL + QuerySharedUsersPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpCurrentStateInternalAPI) QueryKnownUsers( + ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") + defer span.Finish() + + apiURL := h.apiURL + QueryKnownUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go index f4e93dcd..aee900e0 100644 --- a/currentstateserver/inthttp/server.go +++ b/currentstateserver/inthttp/server.go @@ -77,4 +77,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QuerySharedUsersPath, + httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryKnownUsersRequest{} + response := api.QueryKnownUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.QueryKnownUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index 8deaa348..5a754b9e 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -39,4 +39,6 @@ type Database interface { RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetKnownUsers searches all users that userID knows about. + GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index 294f757c..e29fa703 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "github.com/lib/pq" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" @@ -81,6 +82,14 @@ const selectJoinedUsersSetForRoomsSQL = "" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + " type = 'm.room.member' and content_value = 'join' GROUP BY state_key" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +const selectKnownUsersSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY(" + + " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + + ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -90,6 +99,7 @@ type currentRoomStateStatements struct { selectBulkStateContentStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -122,6 +132,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { return nil, err } + if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { + return nil, err + } return s, nil } @@ -295,3 +308,20 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( } return strippedEvents, rows.Err() } + +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index dac38790..bd4329a7 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -89,3 +89,7 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) } + +func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit) +} diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 5706fa35..a2989364 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "strings" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" @@ -69,6 +70,14 @@ const selectBulkStateContentWildSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +const selectKnownUsersSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN (" + + " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + + ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" + type currentRoomStateStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -77,6 +86,7 @@ type currentRoomStateStatements struct { selectRoomIDsWithMembershipStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -103,6 +113,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { return nil, err } + if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { + return nil, err + } return s, nil } @@ -315,3 +328,20 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( } return strippedEvents, rows.Err() } + +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 121bf4fd..6290e7b3 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -39,6 +39,8 @@ type CurrentRoomState interface { // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // SelectKnownUsers searches all users that userID knows about. + SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/syncapi/consumers/keychange_test.go b/syncapi/consumers/keychange_test.go index 9e7ede1f..7322e208 100644 --- a/syncapi/consumers/keychange_test.go +++ b/syncapi/consumers/keychange_test.go @@ -23,6 +23,10 @@ func (s *mockCurrentStateAPI) QueryCurrentState(ctx context.Context, req *api.Qu return nil } +func (s *mockCurrentStateAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + return nil +} + // QueryRoomsForUser retrieves a list of room IDs matching the given query. func (s *mockCurrentStateAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { return nil diff --git a/sytest-whitelist b/sytest-whitelist index 5087186b..388f95e0 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -415,4 +415,8 @@ We don't send redundant membership state across incremental syncs by default Typing notifications don't leak Users cannot kick users from a room they are not in Users cannot kick users who have already left a room +User appears in user directory +User directory correctly update on display name change +User in shared private room does appear in user directory +User in dir while user still shares private rooms Can get 'm.room.name' state for a departed room (SPEC-216) diff --git a/userapi/api/api.go b/userapi/api/api.go index bd0773f8..5791403f 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,6 +32,7 @@ type UserInternalAPI interface { QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error + QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -112,6 +114,20 @@ type QueryProfileResponse struct { AvatarURL string } +// QuerySearchProfilesRequest is the request for QueryProfile +type QuerySearchProfilesRequest struct { + // The search string to match + SearchString string + // How many results to return + Limit int +} + +// QuerySearchProfilesResponse is the response for QuerySearchProfilesRequest +type QuerySearchProfilesResponse struct { + // Profiles matching the search + Profiles []authtypes.Profile +} + // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { AccountType AccountType // Required: whether this is a guest or user account diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 2de8f960..5b154196 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -125,6 +125,15 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil return nil } +func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { + profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + if err != nil { + return err + } + res.Profiles = profiles + return nil +} + func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index b2b42823..3e1ac066 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -31,11 +31,12 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QuerySearchProfilesPath = "/userapi/querySearchProfiles" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -141,3 +142,11 @@ func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryAccountDataPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles") + defer span.Finish() + + apiURL := h.apiURL + QuerySearchProfilesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d8e151ad..d29f4d44 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -117,4 +117,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceInfosPath, + httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse { + request := api.QuerySearchProfilesRequest{} + response := api.QuerySearchProfilesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 6f6caf11..86b91e60 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -49,6 +49,7 @@ type Database interface { GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index d2cbeb8e..14b12c35 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -17,8 +17,10 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" ) const profilesSchema = ` @@ -45,11 +47,15 @@ const setAvatarURLSQL = "" + const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + type profilesStatements struct { insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt } func (s *profilesStatements) prepare(db *sql.DB) (err error) { @@ -69,6 +75,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { return } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } return } @@ -105,3 +114,25 @@ func (s *profilesStatements) setDisplayName( _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + // The fmt.Sprintf directive below is building a parameter for the + // "LIKE" condition in the SQL query. %% escapes the % char, so the + // statement in the end will look like "LIKE %searchString%". + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c76b92f1..f56fb6d8 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -298,3 +298,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +} diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 68cea516..d4c404ca 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -17,8 +17,10 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -46,6 +48,9 @@ const setAvatarURLSQL = "" + const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + type profilesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -53,6 +58,7 @@ type profilesStatements struct { selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt } func (s *profilesStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +80,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { return } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } return } @@ -112,3 +121,25 @@ func (s *profilesStatements) setDisplayName( _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + // The fmt.Sprintf directive below is building a parameter for the + // "LIKE" condition in the SQL query. %% escapes the % char, so the + // statement in the end will look like "LIKE %searchString%". + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 2d09090f..72239014 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -343,3 +343,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +}