diff --git a/roomserver/query/query.go b/roomserver/query/query.go index 52b678ac..2de8e0d0 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -637,12 +637,6 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) - if err != nil { - return err - } - response.QueryStateAndAuthChainRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -653,31 +647,21 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } response.RoomExists = true - prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) + stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { - switch err.(type) { - case types.MissingEventError: - return nil - default: - return err - } + return err } response.PrevEventsExist = true - // Look up the currrent state for the requested tuples. - stateEntries, err := roomState.LoadCombinedStateAfterEvents( - ctx, prevStates, - ) - if err != nil { - return err + // add the auth event IDs for the current state events too + var authEventIDs []string + authEventIDs = append(authEventIDs, request.AuthEventIDs...) + for _, se := range stateEvents { + authEventIDs = append(authEventIDs, se.AuthEventIDs()...) } + authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - stateEvents, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return err - } - - authEvents, err := getAuthChain(ctx, r.DB, request.AuthEventIDs) + authEvents, err := getAuthChain(ctx, r.DB, authEventIDs) if err != nil { return err } @@ -699,6 +683,34 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( return err } +func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { + // TODO: get the correct room version + roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + if err != nil { + return nil, err + } + + prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) + if err != nil { + switch err.(type) { + case types.MissingEventError: + return nil, nil + default: + return nil, err + } + } + + // Look up the currrent state for the requested tuples. + stateEntries, err := roomState.LoadCombinedStateAfterEvents( + ctx, prevStates, + ) + if err != nil { + return nil, err + } + + return r.loadStateEvents(ctx, stateEntries) +} + // getAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and // all their auth_events, recursively. The returned set of events contain the