Improved state handling in /send (#1521)
* Capture errors * Don't request only state key tuples needed for auth (we end up discarding room state this way) * QueryStateAfterEvent returns all state when no tuples supplied * Resolve state * Commentsmain
parent
20aec70ead
commit
7a1fd123de
|
@ -508,13 +508,12 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
|
||||||
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
|
// Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query
|
||||||
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
|
// the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event.
|
||||||
var states []*gomatrixserverlib.RespState
|
var states []*gomatrixserverlib.RespState
|
||||||
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples()
|
|
||||||
for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
|
for _, prevEventID := range backwardsExtremity.PrevEventIDs() {
|
||||||
// Look up what the state is after the backward extremity. This will either
|
// Look up what the state is after the backward extremity. This will either
|
||||||
// come from the roomserver, if we know all the required events, or it will
|
// come from the roomserver, if we know all the required events, or it will
|
||||||
// come from a remote server via /state_ids if not.
|
// come from a remote server via /state_ids if not.
|
||||||
var prevState *gomatrixserverlib.RespState
|
var prevState *gomatrixserverlib.RespState
|
||||||
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed)
|
prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
|
util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
|
||||||
return err
|
return err
|
||||||
|
@ -573,9 +572,9 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser
|
||||||
|
|
||||||
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
|
// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event)
|
||||||
// added into the mix.
|
// added into the mix.
|
||||||
func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) {
|
func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, error) {
|
||||||
// try doing all this locally before we resort to querying federation
|
// try doing all this locally before we resort to querying federation
|
||||||
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed)
|
respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID)
|
||||||
if respState != nil {
|
if respState != nil {
|
||||||
return respState, nil
|
return respState, nil
|
||||||
}
|
}
|
||||||
|
@ -619,12 +618,11 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
|
||||||
return respState, nil
|
return respState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState {
|
func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {
|
||||||
var res api.QueryStateAfterEventsResponse
|
var res api.QueryStateAfterEventsResponse
|
||||||
err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
|
err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
PrevEventIDs: []string{eventID},
|
PrevEventIDs: []string{eventID},
|
||||||
StateToFetch: needed,
|
|
||||||
}, &res)
|
}, &res)
|
||||||
if err != nil || !res.PrevEventsExist {
|
if err != nil || !res.PrevEventsExist {
|
||||||
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
|
||||||
|
|
|
@ -63,7 +63,8 @@ type QueryStateAfterEventsRequest struct {
|
||||||
RoomID string `json:"room_id"`
|
RoomID string `json:"room_id"`
|
||||||
// The list of previous events to return the events after.
|
// The list of previous events to return the events after.
|
||||||
PrevEventIDs []string `json:"prev_event_ids"`
|
PrevEventIDs []string `json:"prev_event_ids"`
|
||||||
// The state key tuples to fetch from the state
|
// The state key tuples to fetch from the state. If none are specified then
|
||||||
|
// the entire resolved room state will be returned.
|
||||||
StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
|
StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -133,8 +133,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
|
|
||||||
// If the event has already been written to the output log then we
|
// If the event has already been written to the output log then we
|
||||||
// don't need to do anything, as we've handled it already.
|
// don't need to do anything, as we've handled it already.
|
||||||
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
|
if hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
|
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
|
||||||
} else if hasBeenSent {
|
} else if hasBeenSent {
|
||||||
return nil
|
return nil
|
||||||
|
@ -142,17 +141,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
|
|
||||||
// Work out what the latest events are. This will include the new
|
// Work out what the latest events are. This will include the new
|
||||||
// event if it is not already referenced.
|
// event if it is not already referenced.
|
||||||
u.calculateLatest(
|
if err := u.calculateLatest(
|
||||||
oldLatest,
|
oldLatest,
|
||||||
types.StateAtEventAndReference{
|
types.StateAtEventAndReference{
|
||||||
EventReference: u.event.EventReference(),
|
EventReference: u.event.EventReference(),
|
||||||
StateAtEvent: u.stateAtEvent,
|
StateAtEvent: u.stateAtEvent,
|
||||||
},
|
},
|
||||||
)
|
); err != nil {
|
||||||
|
return fmt.Errorf("u.calculateLatest: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Now that we know what the latest events are, it's time to get the
|
// Now that we know what the latest events are, it's time to get the
|
||||||
// latest state.
|
// latest state.
|
||||||
if err = u.latestState(); err != nil {
|
if err := u.latestState(); err != nil {
|
||||||
return fmt.Errorf("u.latestState: %w", err)
|
return fmt.Errorf("u.latestState: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,7 +262,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
func (u *latestEventsUpdater) calculateLatest(
|
func (u *latestEventsUpdater) calculateLatest(
|
||||||
oldLatest []types.StateAtEventAndReference,
|
oldLatest []types.StateAtEventAndReference,
|
||||||
newEvent types.StateAtEventAndReference,
|
newEvent types.StateAtEventAndReference,
|
||||||
) {
|
) error {
|
||||||
var newLatest []types.StateAtEventAndReference
|
var newLatest []types.StateAtEventAndReference
|
||||||
|
|
||||||
// First of all, let's see if any of the existing forward extremities
|
// First of all, let's see if any of the existing forward extremities
|
||||||
|
@ -271,6 +272,7 @@ func (u *latestEventsUpdater) calculateLatest(
|
||||||
referenced, err := u.updater.IsReferenced(l.EventReference)
|
referenced, err := u.updater.IsReferenced(l.EventReference)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID)
|
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID)
|
||||||
|
return fmt.Errorf("u.updater.IsReferenced (old): %w", err)
|
||||||
} else if !referenced {
|
} else if !referenced {
|
||||||
newLatest = append(newLatest, l)
|
newLatest = append(newLatest, l)
|
||||||
}
|
}
|
||||||
|
@ -285,7 +287,7 @@ func (u *latestEventsUpdater) calculateLatest(
|
||||||
// We've already referenced this new event so we can just return
|
// We've already referenced this new event so we can just return
|
||||||
// the newly completed extremities at this point.
|
// the newly completed extremities at this point.
|
||||||
u.latest = newLatest
|
u.latest = newLatest
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,11 +298,13 @@ func (u *latestEventsUpdater) calculateLatest(
|
||||||
referenced, err := u.updater.IsReferenced(newEvent.EventReference)
|
referenced, err := u.updater.IsReferenced(newEvent.EventReference)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID)
|
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID)
|
||||||
|
return fmt.Errorf("u.updater.IsReferenced (new): %w", err)
|
||||||
} else if !referenced || len(newLatest) == 0 {
|
} else if !referenced || len(newLatest) == 0 {
|
||||||
newLatest = append(newLatest, newEvent)
|
newLatest = append(newLatest, newEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.latest = newLatest
|
u.latest = newLatest
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
|
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
|
||||||
|
|
|
@ -49,6 +49,7 @@ func (r *Queryer) QueryLatestEventsAndState(
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
||||||
|
// nolint:gocyclo
|
||||||
func (r *Queryer) QueryStateAfterEvents(
|
func (r *Queryer) QueryStateAfterEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryStateAfterEventsRequest,
|
request *api.QueryStateAfterEventsRequest,
|
||||||
|
@ -78,10 +79,18 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
}
|
}
|
||||||
response.PrevEventsExist = true
|
response.PrevEventsExist = true
|
||||||
|
|
||||||
// Look up the currrent state for the requested tuples.
|
var stateEntries []types.StateEntry
|
||||||
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
|
if len(request.StateToFetch) == 0 {
|
||||||
ctx, prevStates, request.StateToFetch,
|
// Look up all of the current room state.
|
||||||
)
|
stateEntries, err = roomState.LoadCombinedStateAfterEvents(
|
||||||
|
ctx, prevStates,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Look up the current state for the requested tuples.
|
||||||
|
stateEntries, err = roomState.LoadStateAfterEventsForStringTuples(
|
||||||
|
ctx, prevStates, request.StateToFetch,
|
||||||
|
)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -91,6 +100,24 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(request.PrevEventIDs) > 1 && len(request.StateToFetch) == 0 {
|
||||||
|
var authEventIDs []string
|
||||||
|
for _, e := range stateEvents {
|
||||||
|
authEventIDs = append(authEventIDs, e.AuthEventIDs()...)
|
||||||
|
}
|
||||||
|
authEventIDs = util.UniqueStrings(authEventIDs)
|
||||||
|
|
||||||
|
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getAuthChain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEvents, err = state.ResolveConflictsAdhoc(info.RoomVersion, stateEvents, authEvents)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, event := range stateEvents {
|
for _, event := range stateEvents {
|
||||||
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
|
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue