Improved state handling in /send (#1521)

* Capture errors

* Don't request only state key tuples needed for auth (we end up discarding room state this way)

* QueryStateAfterEvent returns all state when no tuples supplied

* Resolve state

* Comments
main
Neil Alexander 2020-10-14 12:39:37 +01:00 committed by GitHub
parent 20aec70ead
commit 7a1fd123de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 18 deletions

View File

@ -508,13 +508,12 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
var states []*gomatrixserverlib.RespState var states []*gomatrixserverlib.RespState
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
for _, prevEventID := range backwardsExtremity.PrevEventIDs() { for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
// Look up what the state is after the backward extremity. This will either // Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will // come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not. // come from a remote server via /state_ids if not.
var prevState *gomatrixserverlib.RespState var prevState *gomatrixserverlib.RespState
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
return err return err
@ -573,9 +572,9 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
// added into the mix. // added into the mix.
func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, error) {
// try doing all this locally before we resort to querying federation // try doing all this locally before we resort to querying federation
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed) respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
if respState != nil { if respState != nil {
return respState, nil return respState, nil
} }
@ -619,12 +618,11 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
return respState, nil return respState, nil
} }
func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {
var res api.QueryStateAfterEventsResponse var res api.QueryStateAfterEventsResponse
err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
RoomID: roomID, RoomID: roomID,
PrevEventIDs: []string{eventID}, PrevEventIDs: []string{eventID},
StateToFetch: needed,
}, &res) }, &res)
if err != nil || !res.PrevEventsExist { if err != nil || !res.PrevEventsExist {
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)

View File

@ -63,7 +63,8 @@ type QueryStateAfterEventsRequest struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
// The list of previous events to return the events after. // The list of previous events to return the events after.
PrevEventIDs []string `json:"prev_event_ids"` PrevEventIDs []string `json:"prev_event_ids"`
// The state key tuples to fetch from the state // The state key tuples to fetch from the state. If none are specified then
// the entire resolved room state will be returned.
StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
} }

View File

@ -133,8 +133,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// If the event has already been written to the output log then we // If the event has already been written to the output log then we
// don't need to do anything, as we've handled it already. // don't need to do anything, as we've handled it already.
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID); err != nil {
if err != nil {
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
} else if hasBeenSent { } else if hasBeenSent {
return nil return nil
@ -142,17 +141,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Work out what the latest events are. This will include the new // Work out what the latest events are. This will include the new
// event if it is not already referenced. // event if it is not already referenced.
u.calculateLatest( if err := u.calculateLatest(
oldLatest, oldLatest,
types.StateAtEventAndReference{ types.StateAtEventAndReference{
EventReference: u.event.EventReference(), EventReference: u.event.EventReference(),
StateAtEvent: u.stateAtEvent, StateAtEvent: u.stateAtEvent,
}, },
) ); err != nil {
return fmt.Errorf("u.calculateLatest: %w", err)
}
// Now that we know what the latest events are, it's time to get the // Now that we know what the latest events are, it's time to get the
// latest state. // latest state.
if err = u.latestState(); err != nil { if err := u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err) return fmt.Errorf("u.latestState: %w", err)
} }
@ -261,7 +262,7 @@ func (u *latestEventsUpdater) latestState() error {
func (u *latestEventsUpdater) calculateLatest( func (u *latestEventsUpdater) calculateLatest(
oldLatest []types.StateAtEventAndReference, oldLatest []types.StateAtEventAndReference,
newEvent types.StateAtEventAndReference, newEvent types.StateAtEventAndReference,
) { ) error {
var newLatest []types.StateAtEventAndReference var newLatest []types.StateAtEventAndReference
// First of all, let's see if any of the existing forward extremities // First of all, let's see if any of the existing forward extremities
@ -271,6 +272,7 @@ func (u *latestEventsUpdater) calculateLatest(
referenced, err := u.updater.IsReferenced(l.EventReference) referenced, err := u.updater.IsReferenced(l.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID) logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID)
return fmt.Errorf("u.updater.IsReferenced (old): %w", err)
} else if !referenced { } else if !referenced {
newLatest = append(newLatest, l) newLatest = append(newLatest, l)
} }
@ -285,7 +287,7 @@ func (u *latestEventsUpdater) calculateLatest(
// We've already referenced this new event so we can just return // We've already referenced this new event so we can just return
// the newly completed extremities at this point. // the newly completed extremities at this point.
u.latest = newLatest u.latest = newLatest
return return nil
} }
} }
@ -296,11 +298,13 @@ func (u *latestEventsUpdater) calculateLatest(
referenced, err := u.updater.IsReferenced(newEvent.EventReference) referenced, err := u.updater.IsReferenced(newEvent.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID) logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID)
return fmt.Errorf("u.updater.IsReferenced (new): %w", err)
} else if !referenced || len(newLatest) == 0 { } else if !referenced || len(newLatest) == 0 {
newLatest = append(newLatest, newEvent) newLatest = append(newLatest, newEvent)
} }
u.latest = newLatest u.latest = newLatest
return nil
} }
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {

View File

@ -49,6 +49,7 @@ func (r *Queryer) QueryLatestEventsAndState(
} }
// QueryStateAfterEvents implements api.RoomserverInternalAPI // QueryStateAfterEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *Queryer) QueryStateAfterEvents( func (r *Queryer) QueryStateAfterEvents(
ctx context.Context, ctx context.Context,
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
@ -78,10 +79,18 @@ func (r *Queryer) QueryStateAfterEvents(
} }
response.PrevEventsExist = true response.PrevEventsExist = true
// Look up the currrent state for the requested tuples. var stateEntries []types.StateEntry
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( if len(request.StateToFetch) == 0 {
ctx, prevStates, request.StateToFetch, // Look up all of the current room state.
) stateEntries, err = roomState.LoadCombinedStateAfterEvents(
ctx, prevStates,
)
} else {
// Look up the current state for the requested tuples.
stateEntries, err = roomState.LoadStateAfterEventsForStringTuples(
ctx, prevStates, request.StateToFetch,
)
}
if err != nil { if err != nil {
return err return err
} }
@ -91,6 +100,24 @@ func (r *Queryer) QueryStateAfterEvents(
return err return err
} }
if len(request.PrevEventIDs) > 1 && len(request.StateToFetch) == 0 {
var authEventIDs []string
for _, e := range stateEvents {
authEventIDs = append(authEventIDs, e.AuthEventIDs()...)
}
authEventIDs = util.UniqueStrings(authEventIDs)
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
stateEvents, err = state.ResolveConflictsAdhoc(info.RoomVersion, stateEvents, authEvents)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
}
}
for _, event := range stateEvents { for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
} }