package helpers import ( "context" "database/sql" "errors" "fmt" "strings" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // TODO: temporary package which has helper functions used by both internal/perform packages. // Move these to a more sensible place. func UpdateToInviteMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, roomVersion gomatrixserverlib.RoomVersion, ) ([]api.OutputEvent, error) { // We may have already sent the invite to the user, either because we are // reprocessing this event, or because the we received this invite from a // remote server via the federation invite API. In those cases we don't need // to send the event. needsSending, err := mu.SetToInvite(*add) if err != nil { return nil, err } if needsSending { // We notify the consumers using a special event even though we will // notify them about the change in current state as part of the normal // room event stream. This ensures that the consumers only have to // consider a single stream of events when determining whether a user // is invited, rather than having to combine multiple streams themselves. onie := api.OutputNewInviteEvent{ Event: add.Headered(roomVersion), RoomVersion: roomVersion, } updates = append(updates, api.OutputEvent{ Type: api.OutputTypeNewInviteEvent, NewInviteEvent: &onie, }) } return updates, nil } // IsServerCurrentlyInRoom checks if a server is in a given room, based on the room // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err } if info == nil { return false, fmt.Errorf("unknown room %s", roomID) } if serverName == "" { return db.GetLocalServerInRoom(ctx, info.RoomNID) } eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) if err != nil { return false, err } events, err := db.Events(ctx, eventNIDs) if err != nil { return false, err } gmslEvents := make([]*gomatrixserverlib.Event, len(events)) for i := range events { gmslEvents[i] = events[i].Event } return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil } func IsInvitePending( ctx context.Context, db storage.Database, roomID, userID string, ) (bool, string, string, error) { // Look up the room NID for the supplied room ID. info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) } if info == nil { return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) } // Look up the state key NID for the supplied user ID. targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) if err != nil { return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) } targetUserNID, targetUserFound := targetUserNIDs[userID] if !targetUserFound { return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) } // Let's see if we have an event active for the user in the room. If // we do then it will contain a server name that we can direct the // send_leave to. senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) if err != nil { return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) } if len(senderUserNIDs) == 0 { return false, "", "", nil } userNIDToEventID := make(map[types.EventStateKeyNID]string) for i, nid := range senderUserNIDs { userNIDToEventID[nid] = eventIDs[i] } // Look up the user ID from the NID. senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs) if err != nil { return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) } if len(senderUsers) == 0 { return false, "", "", fmt.Errorf("no senderUsers") } senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] if !senderUserFound { return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) } return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil } // GetMembershipsAtState filters the state events to // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { // Filter the events to retrieve to only keep the membership events if entry.EventTypeNID == types.MRoomMemberNID { eventNIDs = append(eventNIDs, entry.EventNID) } } // Get all of the events in this state stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { return nil, err } if !joinedOnly { return stateEvents, nil } // Filter the events to only keep the "join" membership events var events []types.Event for _, event := range stateEvents { membership, err := event.Membership() if err != nil { return nil, err } if membership == gomatrixserverlib.Join { events = append(events, event) } } return events, nil } func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { return nil, err } eventIDs := []string{eIDs[eventNID]} prevState, err := db.StateAtEventIDs(ctx, eventIDs) if err != nil { return nil, err } // Fetch the state as it was when this event was fired return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } func LoadEvents( ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { return nil, err } result := make([]*gomatrixserverlib.Event, len(stateEvents)) for i := range stateEvents { result[i] = stateEvents[i].Event } return result, nil } func LoadStateEvents( ctx context.Context, db storage.Database, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } return LoadEvents(ctx, db, eventNIDs) } func CheckServerAllowedToSeeEvent( ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil } return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) } // Extract all of the event state key NIDs from the room state. var stateKeyNIDs []types.EventStateKeyNID for _, entry := range stateEntries { stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) } // Then request those state key NIDs from the database. stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) if err != nil { return false, fmt.Errorf("db.EventStateKeys: %w", err) } // If the event state key doesn't match the given servername // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. for nid, key := range stateKeys { if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { delete(stateKeys, nid) } } // Now filter through all of the state events for the room. // If the state key NID appears in the list of valid state // keys then we'll add it to the list of filtered entries. var filteredEntries []types.StateEntry for _, entry := range stateEntries { if _, ok := stateKeys[entry.EventStateKeyNID]; ok { filteredEntries = append(filteredEntries, entry) } } if len(filteredEntries) == 0 { return false, nil } stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) if err != nil { return false, err } return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil } // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID var err error var allowed bool var events []types.Event var next []string var pre string // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in // duplicate events being sent in response to /backfill requests. initialIgnoreList := make(map[string]bool, len(visited)) for k, v := range visited { initialIgnoreList[k] = v } resultNIDs = make([]types.EventNID, 0, limit) var checkedServerInRoom bool var isServerInRoom bool // Loop through the event IDs to retrieve the requested events and go // through the whole tree (up to the provided limit) using the events' // "prev_event" key. BFSLoop: for len(front) > 0 { // Prevent unnecessary allocations: reset the slice only when not empty. if len(next) > 0 { next = make([]string, 0) } // Retrieve the events to process from the database. events, err = db.EventsFromIDs(ctx, front) if err != nil { return resultNIDs, err } if !checkedServerInRoom && len(events) > 0 { // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } checkedServerInRoom = true } for _, ev := range events { // Break out of the loop if the provided limit is reached. if len(resultNIDs) == limit { break BFSLoop } if !initialIgnoreList[ev.EventID()] { // Update the list of events to retrieve. resultNIDs = append(resultNIDs, ev.EventNID) } // Loop through the event's parents. for _, pre = range ev.PrevEventIDs() { // Only add an event to the list of next events to process if it // hasn't been seen before. if !visited[pre] { visited[pre] = true allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", ) // drop the error, as we will often error at the DB level if we don't have the prev_event itself. Let's // just return what we have. return resultNIDs, nil } // If the event hasn't been seen before and the HS // requesting to retrieve it is allowed to do so, add it to // the list of events to retrieve. if allowed { next = append(next, pre) } else { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") } } } } // Repeat the same process with the parent events we just processed. front = next } return resultNIDs, err } func QueryLatestEventsAndState( ctx context.Context, db storage.Database, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { roomInfo, err := db.RoomInfo(ctx, request.RoomID) if err != nil { return err } if roomInfo == nil || roomInfo.IsStub { response.RoomExists = false return nil } roomState := state.NewStateResolution(db, *roomInfo) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = db.LatestEventIDs(ctx, roomInfo.RoomNID) if err != nil { return err } var stateEntries []types.StateEntry if len(request.StateToFetch) == 0 { // Look up all room state. stateEntries, err = roomState.LoadStateAtSnapshot( ctx, currentStateSnapshotNID, ) } else { // Look up the current state for the requested tuples. stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples( ctx, currentStateSnapshotNID, request.StateToFetch, ) } if err != nil { return err } stateEvents, err := LoadStateEvents(ctx, db, stateEntries) if err != nil { return err } for _, event := range stateEvents { response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) } return nil }