From c8408a6387f6d155fe4e0547cbfdde7123832c84 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 9 Jul 2021 16:36:45 +0100 Subject: [PATCH] Add more optimised code path for checking if we're in a room (#1909) * Add more optimised code path for checking if we're in a room * Fix database queries * Fix federation API test * Fix logging * Review comments * Make separate API call for room membership --- federationapi/routing/send.go | 27 ++++++++++++------- federationapi/routing/send_test.go | 4 ++- roomserver/api/query.go | 6 +++-- roomserver/internal/api.go | 1 + roomserver/internal/query/query.go | 11 ++++++++ roomserver/storage/interface.go | 2 ++ .../storage/postgres/membership_table.go | 23 ++++++++++++++++ roomserver/storage/shared/storage.go | 5 ++++ .../storage/sqlite3/membership_table.go | 23 ++++++++++++++++ roomserver/storage/tables/interface.go | 1 + 10 files changed, 90 insertions(+), 13 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 1c9e72bf..5f214e0f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -573,6 +573,23 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) t.work = "" // reset from previous event + // Ask the roomserver if we know about the room and/or if we're joined + // to it. If we aren't then we won't bother processing the event. + joinedReq := api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + var joinedRes api.QueryServerJoinedToRoomResponse + if err := t.rsAPI.QueryServerJoinedToRoom(ctx, &joinedReq, &joinedRes); err != nil { + return fmt.Errorf("t.rsAPI.QueryServerJoinedToRoom: %w", err) + } + + if !joinedRes.RoomExists || !joinedRes.IsInRoom { + // We don't believe we're a member of this room, therefore there's + // no point in wasting work trying to figure out what to do with + // missing auth or prev events. Drop the event. + return roomNotFoundError{e.RoomID()} + } + // Work out if the roomserver knows everything it needs to know to auth // the event. This includes the prev_events and auth_events. // NOTE! This is going to include prev_events that have an empty state @@ -589,16 +606,6 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err) } - if !stateResp.RoomExists { - // TODO: When synapse receives a message for a room it is not in it - // asks the remote server for the state of the room so that it can - // check if the remote server knows of a join "m.room.member" event - // that this server is unaware of. - // However generally speaking we should reject events for rooms we - // aren't a member of. - return roomNotFoundError{e.RoomID()} - } - // Prepare a map of all the events we already had before this point, so // that we don't send them to the roomserver again. for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 98ff1a0a..0da06aa9 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -190,7 +190,9 @@ func (t *testRoomserverAPI) QueryServerJoinedToRoom( request *api.QueryServerJoinedToRoomRequest, response *api.QueryServerJoinedToRoomResponse, ) error { - return fmt.Errorf("not implemented") + response.RoomExists = true + response.IsInRoom = true + return nil } // Query whether a server is allowed to see an event diff --git a/roomserver/api/query.go b/roomserver/api/query.go index af35f7e7..c70db65c 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -170,7 +170,8 @@ type QueryMembershipsForRoomResponse struct { // QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom type QueryServerJoinedToRoomRequest struct { - // Server name of the server to find + // Server name of the server to find. If not specified, we will + // default to checking if the local server is joined. ServerName gomatrixserverlib.ServerName `json:"server_name"` // ID of the room to see if we are still joined to RoomID string `json:"room_id"` @@ -182,7 +183,8 @@ type QueryServerJoinedToRoomResponse struct { RoomExists bool `json:"room_exists"` // True if we still believe that we are participating in the room IsInRoom bool `json:"is_in_room"` - // List of servers that are also in the room + // List of servers that are also in the room. This will not be populated + // if the queried ServerName is the local server name. ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index c9f92f9f..b05a931f 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -59,6 +59,7 @@ func NewRoomserverAPI( Queryer: &query.Queryer{ DB: roomserverDB, Cache: caches, + ServerName: cfg.Matrix.ServerName, ServerACLs: serverACLs, }, Inputer: &input.Inputer{ diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 408f9766..ccd09372 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -36,6 +36,7 @@ import ( type Queryer struct { DB storage.Database Cache caching.RoomServerCaches + ServerName gomatrixserverlib.ServerName ServerACLs *acls.ServerACLs } @@ -328,6 +329,16 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true + if request.ServerName == r.ServerName || request.ServerName == "" { + var joined bool + joined, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + if err != nil { + return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) + } + response.IsInRoom = joined + return nil + } + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) if err != nil { return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index d2b0e75c..c25820aa 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -154,6 +154,8 @@ type Database interface { GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. + GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 3466da6d..9102f26a 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -124,6 +124,14 @@ var selectKnownUsersSQL = "" + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" +// selectLocalServerInRoomSQL is an optimised case for checking if we, the local server, +// are in the room by using the target_local column of the membership table. Normally when +// we want to know if a server is in a room, we have to unmarshal the entire room state which +// is expensive. The presence of a single row from this query suggests we're still in the +// room, no rows returned suggests we aren't. +const selectLocalServerInRoomSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE target_local = true AND membership_nid = $1 AND room_nid = $2 LIMIT 1" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -137,6 +145,7 @@ type membershipStatements struct { selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt + selectLocalServerInRoomStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -160,6 +169,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, + {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, }.Prepare(db) } @@ -324,3 +334,16 @@ func (s *membershipStatements) UpdateForgetMembership( ) return err } + +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { + var nid types.RoomNID + err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + found := nid > 0 + return found, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e77d62e0..9d9434cb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1059,6 +1059,11 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) return result, nil } +// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. +func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { + return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) +} + // GetKnownUsers searches all users that userID knows about. func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index d9fe32cf..82babe0d 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -100,6 +100,14 @@ var selectKnownUsersSQL = "" + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" +// selectLocalServerInRoomSQL is an optimised case for checking if we, the local server, +// are in the room by using the target_local column of the membership table. Normally when +// we want to know if a server is in a room, we have to unmarshal the entire room state which +// is expensive. The presence of a single row from this query suggests we're still in the +// room, no rows returned suggests we aren't. +const selectLocalServerInRoomSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -113,6 +121,7 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt + selectLocalServerInRoomStmt *sql.Stmt } func createMembershipTable(db *sql.DB) error { @@ -137,6 +146,7 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, + {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, }.Prepare(db) } @@ -304,3 +314,16 @@ func (s *membershipStatements) UpdateForgetMembership( ) return err } + +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { + var nid types.RoomNID + err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + found := nid > 0 + return found, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index dd486873..4a893663 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -135,6 +135,7 @@ type Membership interface { SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error + SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) } type Published interface {