currentstate: Add QuerySharedUsers (#1217)

This will be used to determine who to send device list updates to. It
can also be used to determine who to send presence info to.
main
Kegsay 2020-07-23 12:26:31 +01:00 committed by GitHub
parent cfeb1b2f42
commit 7b862384a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 232 additions and 13 deletions

View File

@ -31,6 +31,16 @@ type CurrentStateInternalAPI interface {
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
}
type QuerySharedUsersRequest struct {
UserID string
}
type QuerySharedUsersResponse struct {
UserIDs []string
} }
type QueryRoomsForUserRequest struct { type QueryRoomsForUserRequest struct {

View File

@ -16,9 +16,11 @@ package currentstateserver
import ( import (
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
@ -178,3 +180,112 @@ func TestQueryCurrentState(t *testing.T) {
runCases(currStateAPI) runCases(currStateAPI)
}) })
} }
func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent {
eb := gomatrixserverlib.EventBuilder{
RoomID: roomID,
Sender: userID,
StateKey: &userID,
Type: "m.room.member",
Content: []byte(`{"membership":"` + membership + `"}`),
}
_, pkey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("failed to make ed25519 key: %s", err)
}
roomVer := gomatrixserverlib.RoomVersionV5
ev, err := eb.Build(
time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"),
pkey, roomVer,
)
if err != nil {
t.Fatalf("mustMakeMembershipEvent failed: %s", err)
}
return &roomserverAPI.OutputNewRoomEvent{
Event: ev.Headered(roomVer),
AddsStateEventIDs: []string{ev.EventID()},
}
}
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
func TestQuerySharedUsers(t *testing.T) {
currStateAPI, producer := MustMakeInternalAPI(t)
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
testCases := []struct {
req api.QuerySharedUsersRequest
wantRes api.QuerySharedUsersResponse
}{
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
{
req: api.QuerySharedUsersRequest{
UserID: "@alice:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
},
},
// Unknown user has no shared users
{
req: api.QuerySharedUsersRequest{
UserID: "@unknownuser:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: nil,
},
},
// left real user produces no shared users
{
req: api.QuerySharedUsersRequest{
UserID: "@dave:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: nil,
},
},
}
runCases := func(testAPI api.CurrentStateInternalAPI) {
for _, tc := range testCases {
var res api.QuerySharedUsersResponse
err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res)
if err != nil {
t.Errorf("QuerySharedUsers returned error: %s", err)
continue
}
sort.Strings(res.UserIDs)
sort.Strings(tc.wantRes.UserIDs)
if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
}
}
}
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
AddInternalRoutes(router, currStateAPI)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{})
if err != nil {
t.Fatalf("failed to create HTTP client")
}
runCases(httpAPI)
})
t.Run("Monolith", func(t *testing.T) {
runCases(currStateAPI)
})
}

View File

@ -68,3 +68,16 @@ func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req
} }
return nil return nil
} }
func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil {
return err
}
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
if err != nil {
return err
}
res.UserIDs = users
return nil
}

View File

@ -29,6 +29,7 @@ const (
QueryCurrentStatePath = "/currentstateserver/queryCurrentState" QueryCurrentStatePath = "/currentstateserver/queryCurrentState"
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
) )
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@ -86,3 +87,13 @@ func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
apiURL := h.apiURL + QueryBulkStateContentPath apiURL := h.apiURL + QueryBulkStateContentPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
defer span.Finish()
apiURL := h.apiURL + QuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View File

@ -64,4 +64,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QuerySharedUsersPath,
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
request := api.QuerySharedUsersRequest{}
response := api.QuerySharedUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View File

@ -37,4 +37,6 @@ type Database interface {
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// Redact a state event // Redact a state event
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
// JoinedUsersSetInRooms returns all joined users in the rooms given.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
} }

View File

@ -77,14 +77,18 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" + const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectBulkStateContentStmt *sql.Stmt selectBulkStateContentStmt *sql.Stmt
selectBulkStateContentWildStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -114,9 +118,29 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil { if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View File

@ -85,3 +85,7 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
}

View File

@ -66,13 +66,17 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" + const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -96,9 +100,34 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View File

@ -36,6 +36,8 @@ type CurrentRoomState interface {
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.