Don't prematurely stop trying to join using servers (#1041)

* Don't prematurely stop trying to join using servers

* Factor out performJoinUsingServer
main
Neil Alexander 2020-05-15 13:55:14 +01:00 committed by GitHub
parent f4f032381b
commit 5f6f8adaa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 104 additions and 81 deletions

View File

@ -52,13 +52,46 @@ func (r *FederationSenderInternalAPI) PerformJoin(
// Try each server that we were provided until we land on one that // Try each server that we were provided until we land on one that
// successfully completes the make-join send-join dance. // successfully completes the make-join send-join dance.
for _, serverName := range request.ServerNames { for _, serverName := range request.ServerNames {
if err := r.performJoinUsingServer(
ctx,
request.RoomID,
request.UserID,
request.Content,
serverName,
supportedVersions,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to join room through server")
continue
}
// We're all good.
return nil
}
// If we reach here then we didn't complete a join for some reason.
return fmt.Errorf(
"failed to join user %q to room %q through %d server(s)",
request.UserID, request.RoomID, len(request.ServerNames),
)
}
func (r *FederationSenderInternalAPI) performJoinUsingServer(
ctx context.Context,
roomID, userID string,
content map[string]interface{},
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
// Try to perform a make_join using the information supplied in the // Try to perform a make_join using the information supplied in the
// request. // request.
respMakeJoin, err := r.federation.MakeJoin( respMakeJoin, err := r.federation.MakeJoin(
ctx, ctx,
serverName, serverName,
request.RoomID, roomID,
request.UserID, userID,
supportedVersions, supportedVersions,
) )
if err != nil { if err != nil {
@ -66,19 +99,20 @@ func (r *FederationSenderInternalAPI) PerformJoin(
r.statistics.ForServer(serverName).Failure() r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err) return fmt.Errorf("r.federation.MakeJoin: %w", err)
} }
r.statistics.ForServer(serverName).Success()
// Set all the fields to be what they should be, this should be a no-op // Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd" // but it's possible that the remote server returned us something "odd"
respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember
respMakeJoin.JoinEvent.Sender = request.UserID respMakeJoin.JoinEvent.Sender = userID
respMakeJoin.JoinEvent.StateKey = &request.UserID respMakeJoin.JoinEvent.StateKey = &userID
respMakeJoin.JoinEvent.RoomID = request.RoomID respMakeJoin.JoinEvent.RoomID = roomID
respMakeJoin.JoinEvent.Redacts = "" respMakeJoin.JoinEvent.Redacts = ""
if request.Content == nil { if content == nil {
request.Content = map[string]interface{}{} content = map[string]interface{}{}
} }
request.Content["membership"] = "join" content["membership"] = "join"
if err = respMakeJoin.JoinEvent.SetContent(request.Content); err != nil { if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err) return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
} }
if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil { if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil {
@ -114,18 +148,17 @@ func (r *FederationSenderInternalAPI) PerformJoin(
respMakeJoin.RoomVersion, respMakeJoin.RoomVersion,
) )
if err != nil { if err != nil {
logrus.WithError(err).Warnf("r.federation.SendJoin failed")
r.statistics.ForServer(serverName).Failure() r.statistics.ForServer(serverName).Failure()
continue return fmt.Errorf("r.federation.SendJoin: %w", err)
} }
r.statistics.ForServer(serverName).Success()
// Check that the send_join response was valid. // Check that the send_join response was valid.
joinCtx := perform.JoinContext(r.federation, r.keyRing) joinCtx := perform.JoinContext(r.federation, r.keyRing)
if err = joinCtx.CheckSendJoinResponse( if err = joinCtx.CheckSendJoinResponse(
ctx, event, serverName, respMakeJoin, respSendJoin, ctx, event, serverName, respMakeJoin, respSendJoin,
); err != nil { ); err != nil {
logrus.WithError(err).Warnf("joinCtx.CheckSendJoinResponse failed") return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err)
continue
} }
// If we successfully performed a send_join above then the other // If we successfully performed a send_join above then the other
@ -136,22 +169,12 @@ func (r *FederationSenderInternalAPI) PerformJoin(
respSendJoin.ToRespState(), respSendJoin.ToRespState(),
event.Headered(respMakeJoin.RoomVersion), event.Headered(respMakeJoin.RoomVersion),
); err != nil { ); err != nil {
logrus.WithError(err).Warnf("r.producer.SendEventWithState failed") return fmt.Errorf("r.producer.SendEventWithState: %w", err)
continue
} }
// We're all good.
r.statistics.ForServer(serverName).Success()
return nil return nil
} }
// If we reach here then we didn't complete a join for some reason.
return fmt.Errorf(
"failed to join user %q to room %q through %d server(s)",
request.UserID, request.RoomID, len(request.ServerNames),
)
}
// PerformLeaveRequest implements api.FederationSenderInternalAPI // PerformLeaveRequest implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformLeave( func (r *FederationSenderInternalAPI) PerformLeave(
ctx context.Context, ctx context.Context,