diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index b48d6c0b..fee6d565 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -620,7 +620,9 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv
 	return gomatrixserverlib.Allowed(e, &authUsingState)
 }
 
-func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error {
+func (t *txnReq) processEventWithMissingState(
+	ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
+) error {
 	// Do this with a fresh context, so that we keep working even if the
 	// original request times out. With any luck, by the time the remote
 	// side retries, we'll have fetched the missing state.
@@ -784,7 +786,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
 		default:
 			return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
 		}
-		t.haveEvents[h.EventID()] = h
+		t.cacheAndReturn(h)
 		if h.StateKey() != nil {
 			addedToState := false
 			for i := range respState.StateEvents {
@@ -803,6 +805,14 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
 	return respState, false, nil
 }
 
+func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent {
+	if cached, exists := t.haveEvents[ev.EventID()]; exists {
+		return cached
+	}
+	t.haveEvents[ev.EventID()] = ev
+	return ev
+}
+
 func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState {
 	var res api.QueryStateAfterEventsResponse
 	err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
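Reviewer note: cacheAndReturn is the heart of this change — the first pointer stored for an event ID becomes the canonical one, so every prev_event branch of a transaction shares a single copy of each event. A minimal, runnable sketch of that contract; event and txn here are illustrative stand-ins for gomatrixserverlib.HeaderedEvent and txnReq, not identifiers from this diff:

    package main

    import "fmt"

    // event and txn are hypothetical stand-ins: only the first pointer
    // stored for a given ID is ever handed back out.
    type event struct{ id string }

    type txn struct{ haveEvents map[string]*event }

    func (t *txn) cacheAndReturn(ev *event) *event {
        if cached, exists := t.haveEvents[ev.id]; exists {
            return cached
        }
        t.haveEvents[ev.id] = ev
        return ev
    }

    func main() {
        t := &txn{haveEvents: map[string]*event{}}
        a := &event{id: "$abc"}
        b := &event{id: "$abc"} // same event arriving via a second prev_event branch
        // true: the second allocation is dropped and can be GCed
        fmt.Println(t.cacheAndReturn(a) == t.cacheAndReturn(b))
    }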
@@ -810,15 +820,21 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 		PrevEventIDs: []string{eventID},
 	}, &res)
 	if err != nil || !res.PrevEventsExist {
-		util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID)
+		util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist)
 		return nil
 	}
+	stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents))
 	for i, ev := range res.StateEvents {
-		t.haveEvents[ev.EventID()] = res.StateEvents[i]
+		// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
+		// processEvent request, which is better for memory.
+		stateEvents[i] = t.cacheAndReturn(ev)
 	}
+	// we should never access res.StateEvents again, so nil it out here to let the GC reclaim it sooner
+	res.StateEvents = nil
+
 	var authEvents []*gomatrixserverlib.Event
 	missingAuthEvents := map[string]bool{}
-	for _, ev := range res.StateEvents {
+	for _, ev := range stateEvents {
 		for _, ae := range ev.AuthEventIDs() {
 			if aev, ok := t.haveEvents[ae]; ok {
 				authEvents = append(authEvents, aev.Unwrap())
@@ -843,14 +859,13 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 			return nil
 		}
 		for i := range queryRes.Events {
-			evID := queryRes.Events[i].EventID()
-			t.haveEvents[evID] = queryRes.Events[i]
-			authEvents = append(authEvents, queryRes.Events[i].Unwrap())
+			authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
 		}
+		queryRes.Events = nil
 	}
 
 	return &gomatrixserverlib.RespState{
-		StateEvents: gomatrixserverlib.UnwrapEventHeaders(res.StateEvents),
+		StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents),
 		AuthEvents:  authEvents,
 	}
 }
@@ -860,8 +875,6 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
 func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (
 	*gomatrixserverlib.RespState, error) {
-	util.GetLogger(ctx).Infof("lookupStateBeforeEvent %s", eventID)
-
 	// Attempt to fetch the missing state using /state_ids and /events
 	return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion)
 }
 
@@ -992,7 +1005,6 @@ Event:
 		}
 	}
 
-	// we processed everything!
 	return newEvents, nil
 }
 
@@ -1011,7 +1023,7 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID
 func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
 	*gomatrixserverlib.RespState, error) {
-	util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID)
+	util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID)
 	// fetch the state event IDs at the time of the event
 	stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID)
 	if err != nil {
@@ -1040,14 +1052,16 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 	}
 	for i := range queryRes.Events {
 		evID := queryRes.Events[i].EventID()
-		t.haveEvents[evID] = queryRes.Events[i]
+		t.cacheAndReturn(queryRes.Events[i])
 		if missing[evID] {
 			delete(missing, evID)
 		}
 	}
+	queryRes.Events = nil // allow it to be GCed
 
 	concurrentRequests := 8
 	missingCount := len(missing)
+	util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs))
 
 	// If over 50% of the auth/state events from /state_ids are missing
 	// then we'll just call /state instead, otherwise we'll just end up
@@ -1112,7 +1126,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
 			return
 		}
 		haveEventsMutex.Lock()
-		t.haveEvents[h.EventID()] = h
+		t.cacheAndReturn(h)
 		haveEventsMutex.Unlock()
 	}
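Reviewer note: the last hunk above is the concurrent half of the same cache — up to concurrentRequests (8) fetch goroutines write into t.haveEvents, which is why the writes stay behind haveEventsMutex. A hedged, self-contained sketch of the pattern; fetchEvent and the string IDs are hypothetical, not from this diff:

    package main

    import (
        "fmt"
        "sync"
    )

    type event struct{ id string }

    // fetchEvent stands in for the federation /event fetch.
    func fetchEvent(id string) *event { return &event{id: id} }

    func main() {
        haveEvents := map[string]*event{}
        var haveEventsMutex sync.Mutex
        var wg sync.WaitGroup
        for _, id := range []string{"$a", "$b", "$c"} {
            wg.Add(1)
            go func(id string) {
                defer wg.Done()
                h := fetchEvent(id)
                haveEventsMutex.Lock() // concurrent map writes are racy without this
                if _, exists := haveEvents[h.id]; !exists {
                    haveEvents[h.id] = h // first pointer wins, as in cacheAndReturn
                }
                haveEventsMutex.Unlock()
            }(id)
        }
        wg.Wait()
        fmt.Println(len(haveEvents)) // 3
    }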
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index bd563ef3..47bfb72c 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -72,10 +72,8 @@ func init() {
 // we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
 // (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
 // than being stuck behind foo.bar
-// In the event that the query fails, the worker spins up a short-lived goroutine whose sole purpose is to inject the server
-// name back into the channel after a certain amount of time. If in the interim the device lists have been updated, then
-// the database query will return no stale lists. Reinjection into the channel continues until success or the server terminates,
-// when it will be reloaded on startup.
+// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
+// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
 type DeviceListUpdater struct {
 	// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
 	// request to the remote server and race.
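Reviewer note: "periodically probes" in the new comment is a once-per-second poll, as the worker hunk below shows. A stripped-down, runnable sketch of just the probe step; string keys stand in for gomatrixserverlib.ServerName, and the deletion is inlined where the real code uses a second loop (same effect):

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // probeRetries collects the servers whose retry deadline has passed and
    // removes them from the map; the caller re-injects them into the channel.
    func probeRetries(mu *sync.Mutex, retries map[string]time.Time) []string {
        mu.Lock()
        defer mu.Unlock()
        var ready []string
        now := time.Now()
        for srv, retryAt := range retries {
            if now.After(retryAt) {
                ready = append(ready, srv)
                delete(retries, srv) // safe to delete during range in Go
            }
        }
        return ready
    }

    func main() {
        var mu sync.Mutex
        retries := map[string]time.Time{
            "ready.example":   time.Now().Add(-time.Second), // deadline already passed
            "waiting.example": time.Now().Add(time.Hour),
        }
        fmt.Println(probeRetries(&mu, retries)) // [ready.example]
    }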
@@ -297,42 +295,38 @@ func (u *DeviceListUpdater) clearChannel(userID string) {
 }
 
 func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
-	// It's possible to get many of the same server name in the channel, so in order
-	// to prevent processing the same server over and over we keep track of when we
-	// last made a request to the server. If we get the server name during the cooloff
-	// period, we'll ignore the poke.
-	lastProcessed := make(map[gomatrixserverlib.ServerName]time.Time)
-	// this can't be too long else sytest will give up trying to do a test
-	cooloffPeriod := 500 * time.Millisecond
-	shouldProcess := func(srv gomatrixserverlib.ServerName) bool {
-		// we should process requests when now is after the last process time + cooloff
-		return time.Now().After(lastProcessed[srv].Add(cooloffPeriod))
-	}
-
-	// on failure, spin up a short-lived goroutine to inject the server name again.
-	scheduledRetries := make(map[gomatrixserverlib.ServerName]time.Time)
-	inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
-		time.Sleep(duration)
-		ch <- srv
-	}
-
-	for serverName := range ch {
-		if !shouldProcess(serverName) {
-			if time.Now().Before(scheduledRetries[serverName]) {
-				// do not inject into the channel as we know there will be a sleeping goroutine
-				// which will do it after the cooloff period expires
-				continue
-			} else {
-				scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
-				go inject(serverName, cooloffPeriod)
-				continue
+	retries := make(map[gomatrixserverlib.ServerName]time.Time)
+	retriesMu := &sync.Mutex{}
+	// restarter goroutine which will inject failed servers into ch when it is time
+	go func() {
+		for {
+			var serversToRetry []gomatrixserverlib.ServerName
+			time.Sleep(time.Second)
+			retriesMu.Lock()
+			now := time.Now()
+			for srv, retryAt := range retries {
+				if now.After(retryAt) {
+					serversToRetry = append(serversToRetry, srv)
+				}
+			}
+			for _, srv := range serversToRetry {
+				delete(retries, srv)
+			}
+			retriesMu.Unlock()
+			for _, srv := range serversToRetry {
+				ch <- srv
 			}
 		}
-		lastProcessed[serverName] = time.Now()
+	}()
+	for serverName := range ch {
 		waitTime, shouldRetry := u.processServer(serverName)
 		if shouldRetry {
-			scheduledRetries[serverName] = time.Now().Add(waitTime)
-			go inject(serverName, waitTime)
+			retriesMu.Lock()
+			_, exists := retries[serverName]
+			if !exists {
+				retries[serverName] = time.Now().Add(waitTime)
+			}
+			retriesMu.Unlock()
 		}
 	}
 }
@@ -380,7 +374,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
 		}
 	}
 	if failCount > 0 {
-		logger.WithField("total", len(userIDs)).WithField("failed", failCount).Error("failed to query device keys for some users")
+		logger.WithField("total", len(userIDs)).WithField("failed", failCount).WithField("wait", waitTime).Error("failed to query device keys for some users")
 	}
 	for _, userID := range userIDs {
 		// always clear the channel to unblock Update calls regardless of success/failure
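Reviewer note on two details of the new worker: the restarter copies the ready servers out of the map and releases retriesMu before sending on ch, so a full channel can only block the restarter itself and cannot deadlock against the worker loop, which takes the same mutex after processServer returns. And the exists check when recording a failure means a server that is already scheduled keeps its original deadline — repeated failures while it waits cannot push the retry ever further into the future.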