diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 2a70f7ed..49c53755 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -2,6 +2,7 @@ package internal import ( "context" + "sync" "time" "github.com/matrix-org/dendrite/federationsender/api" @@ -23,6 +24,7 @@ type FederationSenderInternalAPI struct { federation *gomatrixserverlib.FederationClient keyRing *gomatrixserverlib.KeyRing queues *queue.OutgoingQueues + joins sync.Map // joins currently in progress } func NewFederationSenderInternalAPI( diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index a0abf7ff..6aea296b 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -37,12 +37,32 @@ func (r *FederationSenderInternalAPI) PerformDirectoryLookup( return nil } +type federatedJoin struct { + UserID string + RoomID string +} + // PerformJoinRequest implements api.FederationSenderInternalAPI func (r *FederationSenderInternalAPI) PerformJoin( ctx context.Context, request *api.PerformJoinRequest, response *api.PerformJoinResponse, ) { + // Check that a join isn't already in progress for this user/room. + j := federatedJoin{request.UserID, request.RoomID} + if _, found := r.joins.Load(j); found { + response.LastError = &gomatrix.HTTPError{ + Code: 429, + Message: `{ + "errcode": "M_LIMIT_EXCEEDED", + "error": "There is already a federated join to this room in progress. Please wait for it to finish." + }`, // TODO: Why do none of our error types play nicely with each other? + } + return + } + r.joins.Store(j, nil) + defer r.joins.Delete(j) + // Look up the supported room versions. var supportedVersions []gomatrixserverlib.RoomVersion for version := range version.SupportedRoomVersions() { @@ -186,27 +206,47 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( } r.statistics.ForServer(serverName).Success() - // Check that the send_join response was valid. - joinCtx := perform.JoinContext(r.federation, r.keyRing) - respState, err := joinCtx.CheckSendJoinResponse( - ctx, event, serverName, respMakeJoin, respSendJoin, - ) - if err != nil { - return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) - } + // Process the join response in a goroutine. The idea here is + // that we'll try and wait for as long as possible for the work + // to complete, but if the client does give up waiting, we'll + // still continue to process the join anyway so that we don't + // waste the effort. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(context.Background()) + go func() { + defer cancel() - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithRewrite( - ctx, r.rsAPI, - respState, - event.Headered(respMakeJoin.RoomVersion), - nil, - ); err != nil { - return fmt.Errorf("r.producer.SendEventWithState: %w", err) - } + // Check that the send_join response was valid. + joinCtx := perform.JoinContext(r.federation, r.keyRing) + respState, err := joinCtx.CheckSendJoinResponse( + ctx, event, serverName, respMakeJoin, respSendJoin, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "user_id": userID, + }).WithError(err).Error("Failed to process room join response") + return + } + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + if err = roomserverAPI.SendEventWithRewrite( + ctx, r.rsAPI, + respState, + event.Headered(respMakeJoin.RoomVersion), + nil, + ); err != nil { + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "user_id": userID, + }).WithError(err).Error("Failed to send room join response to roomserver") + return + } + }() + + <-ctx.Done() return nil }