diff --git a/clientapi/producers/roomserver.go b/clientapi/producers/roomserver.go index 7eee83f5..f0733db9 100644 --- a/clientapi/producers/roomserver.go +++ b/clientapi/producers/roomserver.go @@ -52,9 +52,10 @@ func (c *RoomserverProducer) SendEvents( } // SendEventWithState writes an event with KindNew to the roomserver input log -// with the state at the event as KindOutlier before it. +// with the state at the event as KindOutlier before it. Will not send any event that is +// marked as `true` in haveEventIDs func (c *RoomserverProducer) SendEventWithState( - ctx context.Context, state *gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent, + ctx context.Context, state *gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool, ) error { outliers, err := state.Events() if err != nil { @@ -63,6 +64,9 @@ func (c *RoomserverProducer) SendEventWithState( var ires []api.InputRoomEvent for _, outlier := range outliers { + if haveEventIDs[outlier.EventID()] { + continue + } ires = append(ires, api.InputRoomEvent{ Kind: api.KindOutlier, Event: outlier.Headered(event.RoomVersion), diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 10210db6..e6f91d94 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -309,9 +309,7 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // TODO: Attempt to fill in the gap using /get_missing_events // Attempt to fetch the missing state using /state_ids and /events - var respState *gomatrixserverlib.RespState - var err error - respState, err = t.lookupMissingStateViaStateIDs(e, roomVersion) + respState, haveEventIDs, err := t.lookupMissingStateViaStateIDs(e, roomVersion) if err != nil { // Fallback to /state util.GetLogger(t.context).WithError(err).Warn("processEventWithMissingState failed to /state_ids, falling back to /state") @@ -343,8 +341,9 @@ retryAllowedState: return err } - // pass the event along with the state to the roomserver - return t.producer.SendEventWithState(t.context, respState, e.Headered(roomVersion)) + // pass the event along with the state to the roomserver using a background context so we don't + // needlessly expire + return t.producer.SendEventWithState(context.Background(), respState, e.Headered(roomVersion), haveEventIDs) } func (t *txnReq) lookupMissingStateViaState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) ( @@ -361,28 +360,30 @@ func (t *txnReq) lookupMissingStateViaState(e gomatrixserverlib.Event, roomVersi } func (t *txnReq) lookupMissingStateViaStateIDs(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { + *gomatrixserverlib.RespState, map[string]bool, error) { // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, e.RoomID(), e.EventID()) if err != nil { - return nil, err + return nil, nil, err } // fetch as many as we can from the roomserver, do them as 2 calls rather than // 1 to try to reduce the number of parameters in the bulk query this will use haveEventMap := make(map[string]*gomatrixserverlib.HeaderedEvent, len(stateIDs.StateEventIDs)) + haveEventIDs := make(map[string]bool) for _, eventList := range [][]string{stateIDs.StateEventIDs, stateIDs.AuthEventIDs} { queryReq := api.QueryEventsByIDRequest{ EventIDs: eventList, } var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { - return nil, err + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil, nil, err } // allow indexing of current state by event ID for i := range queryRes.Events { haveEventMap[queryRes.Events[i].EventID()] = &queryRes.Events[i] + haveEventIDs[queryRes.Events[i].EventID()] = true } } @@ -404,26 +405,29 @@ func (t *txnReq) lookupMissingStateViaStateIDs(e gomatrixserverlib.Event, roomVe }).Info("Fetching missing state at event") for missingEventID := range missing { - txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID) + var txn gomatrixserverlib.Transaction + txn, err = t.federation.GetEvent(t.context, t.Origin, missingEventID) if err != nil { util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") - return nil, err + return nil, nil, err } for _, pdu := range txn.PDUs { - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + var event gomatrixserverlib.Event + event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) - return nil, unmarshalError{err} + return nil, nil, unmarshalError{err} } - if err := gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { + if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) - return nil, verifySigError{event.EventID(), err} + return nil, nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) haveEventMap[event.EventID()] = &h } } - return t.createRespStateFromStateIDs(stateIDs, haveEventMap) + resp, err := t.createRespStateFromStateIDs(stateIDs, haveEventMap) + return resp, haveEventIDs, err } func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs, haveEventMap map[string]*gomatrixserverlib.HeaderedEvent) ( diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 5e8e503a..89d28aa1 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sort" "testing" "time" @@ -79,6 +81,7 @@ func (p *testEDUProducer) InputTypingEvent( type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse } func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} @@ -138,6 +141,8 @@ func (t *testRoomserverAPI) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { + res := t.queryEventsByID(request) + response.Events = res.Events return nil } @@ -270,21 +275,43 @@ func (t *testRoomserverAPI) RemoveRoomAlias( return nil } -type txnFedClient struct{} +type txnFedClient struct { + state map[string]gomatrixserverlib.RespState // event_id to response + stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response + getEvent map[string]gomatrixserverlib.Transaction // event_id to response +} func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { + r, ok := c.state[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) + return + } + res = r return } func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { + r, ok := c.stateIDs[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) + return + } + res = r return } func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { + r, ok := c.getEvent[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) + return + } + res = r return } -func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage, edus []gomatrixserverlib.EDU) *txnReq { +func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { t := &txnReq{ context: context.Background(), rsAPI: rsAPI, @@ -294,7 +321,6 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat federation: fedClient, } t.PDUs = pdus - t.EDUs = edus t.Origin = testOrigin t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) t.Destination = testDestination @@ -368,7 +394,7 @@ func TestBasicTransaction(t *testing.T) { pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus, nil) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } @@ -390,10 +416,136 @@ func TestTransactionFailAuthChecks(t *testing.T) { pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus, nil) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, []string{ // expect the event to have an error testEvents[len(testEvents)-1].EventID(), }) assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver } + +// The purpose of this test is to check that when there are missing prev_events that state is fetched via /state_ids +// and /event and not /state. It works by setting PrevEventsExist=false in the roomserver query response, resulting in +// a call to /state_ids which returns the whole room state. It should attempt to fetch as many of these events from the +// roomserver FIRST, resulting in a call to QueryEventsByID. However, this will be missing the m.room.power_levels event which +// should then be requested via /event. The net result is that the transaction should succeed and there should be 2 +// new events, first the m.room.power_levels event we were missing, then the transaction PDU. +func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { + missingStateEvent := testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomPowerLevels, + StateKey: "", + }] + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + // setting this to false should trigger a call to /state_ids + PrevEventsExist: false, + RoomExists: true, + StateEvents: nil, + } + }, + queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { + var res api.QueryEventsByIDResponse + for _, wantEventID := range req.EventIDs { + for _, ev := range testStateEvents { + // roomserver is missing the power levels event + if wantEventID == missingStateEvent.EventID() { + continue + } + if ev.EventID() == wantEventID { + res.Events = append(res.Events, ev) + } + } + } + res.QueryEventsByIDRequest = *req + return res + }, + } + inputEvent := testEvents[len(testEvents)-1] + var stateEventIDs []string + for _, ev := range testStateEvents { + stateEventIDs = append(stateEventIDs, ev.EventID()) + } + cli := &txnFedClient{ + // /state_ids returns all the state events + stateIDs: map[string]gomatrixserverlib.RespStateIDs{ + inputEvent.EventID(): gomatrixserverlib.RespStateIDs{ + StateEventIDs: stateEventIDs, + AuthEventIDs: stateEventIDs, + }, + }, + // /event for the missing state event returns it + getEvent: map[string]gomatrixserverlib.Transaction{ + missingStateEvent.EventID(): gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{ + missingStateEvent.JSON(), + }, + }, + }, + } + + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{missingStateEvent, inputEvent}) +} + +// The purpose of this test is to check that when there are missing prev_events and /state_ids fails, that we fallback to +// calling /state which returns the entire room state at that event. It works by setting PrevEventsExist=false in the +// roomserver query response, resulting in a call to /state_ids which fails (unset). It should then fetch via /state. +func TestTransactionFetchMissingStateByFallbackState(t *testing.T) { + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + // setting this to false should trigger a call to /state_ids + PrevEventsExist: false, + RoomExists: true, + StateEvents: nil, + } + }, + } + inputEvent := testEvents[len(testEvents)-1] + // first 5 events are the state events, in auth event order. + stateEvents := testEvents[:5] + + cli := &txnFedClient{ + // /state_ids purposefully unset + stateIDs: nil, + // /state returns the state at that event (which is the current state) + state: map[string]gomatrixserverlib.RespState{ + inputEvent.EventID(): gomatrixserverlib.RespState{ + AuthEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + }, + }, + } + + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + // the roomserver should get all state events and the new input event + // TODO: it should really be only giving the missing ones + got := rsAPI.inputRoomEvents + if len(got) != len(stateEvents)+1 { + t.Fatalf("wrong number of InputRoomEvents: got %d want %d", len(got), len(stateEvents)+1) + } + last := got[len(got)-1] + if last.Event.EventID() != inputEvent.EventID() { + t.Errorf("last event should be the input event but it wasn't. got %s want %s", last.Event.EventID(), inputEvent.EventID()) + } + gots := make([]string, len(stateEvents)) + wants := make([]string, len(stateEvents)) + for i := range stateEvents { + gots[i] = got[i].Event.EventID() + wants[i] = stateEvents[i].EventID() + } + sort.Strings(gots) + sort.Strings(wants) + if !reflect.DeepEqual(gots, wants) { + t.Errorf("state events returned mismatch, got (sorted): %+v want %+v", gots, wants) + } +}