diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go index 19b7d805..a0029a28 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input_membership.go @@ -231,8 +231,7 @@ func updateToLeaveMembership( return updates, nil } -// membershipChanges pairs up the membership state changes from a sorted list -// of state removed and a sorted list of state added. +// membershipChanges pairs up the membership state changes. func membershipChanges(removed, added []types.StateEntry) []stateChange { changes := pairUpChanges(removed, added) var result []stateChange @@ -251,64 +250,39 @@ type stateChange struct { } // pairUpChanges pairs up the state events added and removed for each type, -// state key tuple. Assumes that removed and added are sorted. +// state key tuple. func pairUpChanges(removed, added []types.StateEntry) []stateChange { - var ai int - var ri int - var result []stateChange - for { - switch { - case ai == len(added): - // We've reached the end of the added entries. - // The rest of the removed list are events that were removed without - // an event with the same state key being added. - for _, s := range removed[ri:] { - result = append(result, stateChange{ - StateKeyTuple: s.StateKeyTuple, - removedEventNID: s.EventNID, - }) - } - return result - case ri == len(removed): - // We've reached the end of the removed entries. - // The rest of the added list are events that were added without - // an event with the same state key being removed. - for _, s := range added[ai:] { - result = append(result, stateChange{ - StateKeyTuple: s.StateKeyTuple, - addedEventNID: s.EventNID, - }) - } - return result - case added[ai].StateKeyTuple == removed[ri].StateKeyTuple: - // The tuple is in both lists so an event with that key is being - // removed and another event with the same key is being added. - result = append(result, stateChange{ - StateKeyTuple: added[ai].StateKeyTuple, - removedEventNID: removed[ri].EventNID, - addedEventNID: added[ai].EventNID, - }) - ai++ - ri++ - case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple): - // The lists are sorted so the added entry being less than the - // removed entry means that the added event was added without an - // event with the same key being removed. - result = append(result, stateChange{ - StateKeyTuple: added[ai].StateKeyTuple, - addedEventNID: added[ai].EventNID, - }) - ai++ - default: - // Reaching the default case implies that the removed entry is less - // than the added entry. Since the lists are sorted this means that - // the removed event was removed without an event with the same - // key being added. - result = append(result, stateChange{ - StateKeyTuple: removed[ai].StateKeyTuple, - removedEventNID: removed[ri].EventNID, - }) - ri++ + tuples := make(map[types.StateKeyTuple]stateChange) + changes := []stateChange{} + + // First, go through the newly added state entries. + for _, add := range added { + if change, ok := tuples[add.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.addedEventNID = add.EventNID + tuples[add.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID} } } + + // Now go through the removed state entries. + for _, remove := range removed { + if change, ok := tuples[remove.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.removedEventNID = remove.EventNID + tuples[remove.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0} + } + } + + // Now return the changes as an array. + for _, change := range tuples { + changes = append(changes, change) + } + + return changes } diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform_join.go index 99e10d97..8f2f84e0 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform_join.go @@ -121,6 +121,22 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( return fmt.Errorf("eb.SetContent: %w", err) } + // First work out if this is in response to an existing invite. + // If it is then we avoid the situation where we might think we + // know about a room in the following section but don't know the + // latest state as all of our users have left. + isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) + if err == nil && isInvitePending { + // Add the server of the person who invited us to the server list, + // as they should be a fairly good bet. + if _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender); ierr == nil { + req.ServerNames = append(req.ServerNames, inviterDomain) + } + + // Perform a federated room join. + return r.performFederatedJoinRoomByID(ctx, req, res) + } + // Try to construct an actual join event from the template. // If this succeeds then it is a sign that the room already exists // locally on the homeserver. @@ -178,21 +194,32 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( return fmt.Errorf("Room ID %q does not exist", req.RoomIDOrAlias) } - // Try joining by all of the supplied server names. - fedReq := fsAPI.PerformJoinRequest{ - RoomID: req.RoomIDOrAlias, // the room ID to try and join - UserID: req.UserID, // the user ID joining the room - ServerNames: req.ServerNames, // the server to try joining with - Content: req.Content, // the membership event content - } - fedRes := fsAPI.PerformJoinResponse{} - err = r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes) - if err != nil { - return fmt.Errorf("Error joining federated room: %q", err) - } + // Perform a federated room join. + return r.performFederatedJoinRoomByID(ctx, req, res) default: - return fmt.Errorf("Error joining room %q: %w", req.RoomIDOrAlias, err) + // Something else went wrong. + return fmt.Errorf("Error joining local room: %q", err) + } + + return nil +} + +func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( + ctx context.Context, + req *api.PerformJoinRequest, + res *api.PerformJoinResponse, // nolint:unparam +) error { + // Try joining by all of the supplied server names. + fedReq := fsAPI.PerformJoinRequest{ + RoomID: req.RoomIDOrAlias, // the room ID to try and join + UserID: req.UserID, // the user ID joining the room + ServerNames: req.ServerNames, // the server to try joining with + Content: req.Content, // the membership event content + } + fedRes := fsAPI.PerformJoinResponse{} + if err := r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes); err != nil { + return fmt.Errorf("Error joining federated room: %q", err) } return nil diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform_leave.go index 422748e6..5d9b251c 100644 --- a/roomserver/internal/perform_leave.go +++ b/roomserver/internal/perform_leave.go @@ -38,7 +38,7 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( ) error { // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, err := r.isInvitePending(ctx, req, res) + isInvitePending, senderUser, err := r.isInvitePending(ctx, req.RoomID, req.UserID) if err == nil && isInvitePending { return r.performRejectInvite(ctx, req, res, senderUser) } @@ -160,23 +160,22 @@ func (r *RoomserverInternalAPI) performRejectInvite( func (r *RoomserverInternalAPI) isInvitePending( ctx context.Context, - req *api.PerformLeaveRequest, - res *api.PerformLeaveResponse, // nolint:unparam + roomID, userID string, ) (bool, string, error) { // Look up the room NID for the supplied room ID. - roomNID, err := r.DB.RoomNID(ctx, req.RoomID) + roomNID, err := r.DB.RoomNID(ctx, roomID) if err != nil { return false, "", fmt.Errorf("r.DB.RoomNID: %w", err) } // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{req.UserID}) + targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID}) if err != nil { return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } - targetUserNID, targetUserFound := targetUserNIDs[req.UserID] + targetUserNID, targetUserFound := targetUserNIDs[userID] if !targetUserFound { - return false, "", fmt.Errorf("missing NID for user %q (%+v)", req.UserID, targetUserNIDs) + return false, "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If