diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go index b74514c1..28389ad9 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go @@ -16,9 +16,6 @@ package storage import ( "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/clientapi/events" "github.com/matrix-org/gomatrixserverlib" ) @@ -96,7 +93,7 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { } // JoinedMemberLists returns a map of room ID to a list of joined user IDs. -func (s *currentRoomStateStatements) JoinedMemberLists() (map[string][]string, error) { +func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, error) { rows, err := s.selectJoinedUsersStmt.Query() if err != nil { return nil, err @@ -118,7 +115,7 @@ func (s *currentRoomStateStatements) JoinedMemberLists() (map[string][]string, e } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { +func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership) if err != nil { return nil, err @@ -137,7 +134,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, us } // CurrentState returns all the current state events for the given room. -func (s *currentRoomStateStatements) CurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { +func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID) if err != nil { return nil, err @@ -160,34 +157,14 @@ func (s *currentRoomStateStatements) CurrentState(txn *sql.Tx, roomID string) ([ return result, nil } -func (s *currentRoomStateStatements) UpdateRoomState(txn *sql.Tx, added []gomatrixserverlib.Event, removedEventIDs []string) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removedEventIDs { - _, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID) - if err != nil { - return err - } - } - - for _, event := range added { - if event.StateKey() == nil { - // ignore non state events - continue - } - var membership *string - if event.Type() == "m.room.member" { - var memberContent events.MemberContent - if err := json.Unmarshal(event.Content(), &memberContent); err != nil { - return err - } - membership = &memberContent.Membership - } - _, err := txn.Stmt(s.upsertRoomStateStmt).Exec( - event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, - ) - if err != nil { - return err - } - } - return nil +func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error { + _, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID) + return err +} + +func (s *currentRoomStateStatements) upsertRoomState(txn *sql.Tx, event gomatrixserverlib.Event, membership *string) error { + _, err := txn.Stmt(s.upsertRoomStateStmt).Exec( + event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, + ) + return err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go index 6196fa75..b3cc3925 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go @@ -95,13 +95,15 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { return } -// StateBetween returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. +// selectStateInRange returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. -func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos types.StreamPosition) (map[string][]streamEvent, error) { +func (s *outputRoomEventsStatements) selectStateInRange( + txn *sql.Tx, oldPos, newPos types.StreamPosition, +) (map[string]map[string]bool, map[string]streamEvent, error) { rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos) if err != nil { - return nil, err + return nil, nil, err } // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) @@ -121,7 +123,7 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty delIDs pq.StringArray ) if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil { - return nil, err + return nil, nil, err } // Sanity check for deleted state and whine if we see it. We don't need to do anything // since it'll just mark the event as not being needed. @@ -137,7 +139,7 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty // TODO: Handle redacted events ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false) if err != nil { - return nil, err + return nil, nil, err } needSet := stateNeeded[ev.RoomID()] if needSet == nil { // make set if required @@ -154,56 +156,13 @@ func (s *outputRoomEventsStatements) StateBetween(txn *sql.Tx, oldPos, newPos ty eventIDToEvent[ev.EventID()] = streamEvent{ev, types.StreamPosition(streamPos)} } - return s.fetchStateEvents(txn, stateNeeded, eventIDToEvent) -} - -// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. -// Returns a map of room ID to list of events. -func (s *outputRoomEventsStatements) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) { - stateBetween := make(map[string][]streamEvent) - missingEvents := make(map[string][]string) - for roomID, ids := range roomIDToEventIDSet { - events := stateBetween[roomID] - for id, need := range ids { - if !need { - continue // deleted state - } - e, ok := eventIDToEvent[id] - if ok { - events = append(events, e) - } else { - m := missingEvents[roomID] - m = append(m, id) - missingEvents[roomID] = m - } - } - stateBetween[roomID] = events - } - - if len(missingEvents) > 0 { - // This happens when add_state_ids has an event ID which is not in the provided range. - // We need to explicitly fetch them. - allMissingEventIDs := []string{} - for _, missingEvIDs := range missingEvents { - allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) - } - evs, err := s.Events(txn, allMissingEventIDs) - if err != nil { - return nil, err - } - // we know we got them all otherwise an error would've been returned, so just loop the events - for _, ev := range evs { - roomID := ev.RoomID() - stateBetween[roomID] = append(stateBetween[roomID], ev) - } - } - return stateBetween, nil + return stateNeeded, eventIDToEvent, nil } // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. -func (s *outputRoomEventsStatements) MaxID(txn *sql.Tx) (id int64, err error) { +func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) { stmt := s.selectMaxIDStmt if txn != nil { stmt = txn.Stmt(stmt) @@ -218,7 +177,7 @@ func (s *outputRoomEventsStatements) MaxID(txn *sql.Tx) (id int64, err error) { // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. -func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { +func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { err = txn.Stmt(s.insertEventStmt).QueryRow( event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), ).Scan(&streamPos) @@ -226,7 +185,9 @@ func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixser } // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]streamEvent, error) { +func (s *outputRoomEventsStatements) selectRecentEvents( + txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int, +) ([]streamEvent, error) { rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit) if err != nil { return nil, err @@ -243,7 +204,7 @@ func (s *outputRoomEventsStatements) RecentEventsInRoom(txn *sql.Tx, roomID stri // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // from the database. -func (s *outputRoomEventsStatements) Events(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { +func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { rows, err := txn.Stmt(s.selectEventsStmt).Query(pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index 326d12c4..dee0b51c 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -72,7 +72,7 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error) { - return d.roomstate.JoinedMemberLists() + return d.roomstate.selectJoinedUsers() } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -81,7 +81,7 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string) (streamPos types.StreamPosition, returnErr error) { returnErr = runTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.InsertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) + pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) if err != nil { return err } @@ -97,22 +97,49 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve // However, conflict resolution may result in there being different events being added, or even some removed. if len(removeStateEventIDs) == 0 && len(addStateEventIDs) == 1 && addStateEventIDs[0] == ev.EventID() { // common case - if err = d.roomstate.UpdateRoomState(txn, []gomatrixserverlib.Event{*ev}, nil); err != nil { - return err - } - return nil + return d.updateRoomState(txn, nil, []gomatrixserverlib.Event{*ev}) } // uncommon case: we need to fetch the full event for each event ID mentioned, then update room state - added, err := d.events.Events(txn, addStateEventIDs) + added, err := d.events.selectEvents(txn, addStateEventIDs) if err != nil { return err } - return d.roomstate.UpdateRoomState(txn, streamEventsToEvents(added), removeStateEventIDs) + + return d.updateRoomState(txn, removeStateEventIDs, streamEventsToEvents(added)) }) return } +func (d *SyncServerDatabase) updateRoomState(txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event) error { + // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. + for _, eventID := range removedEventIDs { + if err := d.roomstate.deleteRoomStateByEventID(txn, eventID); err != nil { + return err + } + } + + for _, event := range addedEvents { + if event.StateKey() == nil { + // ignore non state events + continue + } + var membership *string + if event.Type() == "m.room.member" { + var memberContent events.MemberContent + if err := json.Unmarshal(event.Content(), &memberContent); err != nil { + return err + } + membership = &memberContent.Membership + } + if err := d.roomstate.upsertRoomState(txn, event, membership); err != nil { + return err + } + } + + return nil +} + // PartitionOffsets implements common.PartitionStorer func (d *SyncServerDatabase) PartitionOffsets(topic string) ([]common.PartitionOffset, error) { return d.partitions.SelectPartitionOffsets(topic) @@ -125,7 +152,7 @@ func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, o // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) { - id, err := d.events.MaxID(nil) + id, err := d.events.selectMaxID(nil) if err != nil { return types.StreamPosition(0), err } @@ -156,7 +183,7 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types // This is all "okay" assuming history_visibility == "shared" which it is by default. endPos = delta.membershipPos } - recentStreamEvents, err := d.events.RecentEventsInRoom(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom) + recentStreamEvents, err := d.events.selectRecentEvents(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom) if err != nil { return err } @@ -197,14 +224,14 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom // but it's better to not hide the fact that this is being done in a transaction. returnErr = runTransaction(d.db, func(txn *sql.Tx) error { // Get the current stream position which we will base the sync response on. - id, err := d.events.MaxID(txn) + id, err := d.events.selectMaxID(txn) if err != nil { return err } pos := types.StreamPosition(id) // Extract room state and recent events for all rooms the user is joined to. - roomIDs, err := d.roomstate.SelectRoomIDsWithMembership(txn, userID, "join") + roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join") if err != nil { return err } @@ -212,13 +239,15 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom // Build up a /sync response. Add joined rooms. res = types.NewResponse(pos) for _, roomID := range roomIDs { - stateEvents, err := d.roomstate.CurrentState(txn, roomID) + stateEvents, err := d.roomstate.selectCurrentState(txn, roomID) if err != nil { return err } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, err := d.events.RecentEventsInRoom(txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom) + recentStreamEvents, err := d.events.selectRecentEvents( + txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, + ) if err != nil { return err } @@ -239,7 +268,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, res *types.Response) error { // Add invites - TODO: This will break over federation as they won't be in the current state table according to Mark. - roomIDs, err := d.roomstate.SelectRoomIDsWithMembership(txn, userID, "invite") + roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "invite") if err != nil { return err } @@ -251,6 +280,49 @@ func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, re return nil } +// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. +// Returns a map of room ID to list of events. +func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) { + stateBetween := make(map[string][]streamEvent) + missingEvents := make(map[string][]string) + for roomID, ids := range roomIDToEventIDSet { + events := stateBetween[roomID] + for id, need := range ids { + if !need { + continue // deleted state + } + e, ok := eventIDToEvent[id] + if ok { + events = append(events, e) + } else { + m := missingEvents[roomID] + m = append(m, id) + missingEvents[roomID] = m + } + } + stateBetween[roomID] = events + } + + if len(missingEvents) > 0 { + // This happens when add_state_ids has an event ID which is not in the provided range. + // We need to explicitly fetch them. + allMissingEventIDs := []string{} + for _, missingEvIDs := range missingEvents { + allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) + } + evs, err := d.events.selectEvents(txn, allMissingEventIDs) + if err != nil { + return nil, err + } + // we know we got them all otherwise an error would've been returned, so just loop the events + for _, ev := range evs { + roomID := ev.RoomID() + stateBetween[roomID] = append(stateBetween[roomID], ev) + } + } + return stateBetween, nil +} + func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string) ([]stateDelta, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -263,10 +335,15 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St var deltas []stateDelta // get all the state events ever between these two positions - state, err := d.events.StateBetween(txn, fromPos, toPos) + stateNeeded, eventMap, err := d.events.selectStateInRange(txn, fromPos, toPos) if err != nil { return nil, err } + state, err := d.fetchStateEvents(txn, stateNeeded, eventMap) + if err != nil { + return nil, err + } + for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. @@ -278,7 +355,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St if membership == "join" { // send full room state down instead of a delta var allState []gomatrixserverlib.Event - allState, err = d.roomstate.CurrentState(txn, roomID) + allState, err = d.roomstate.selectCurrentState(txn, roomID) if err != nil { return nil, err } @@ -302,7 +379,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St } // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.SelectRoomIDsWithMembership(txn, userID, "join") + joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join") if err != nil { return nil, err }