diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index b02686fe..5794d40a 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -58,7 +58,8 @@ func NewInternalAPI( } queues := queue.NewOutgoingQueues( - federationSenderDB, cfg.Matrix.ServerName, federation, rsAPI, stats, + federationSenderDB, cfg.Matrix.ServerName, federation, + rsAPI, stateAPI, stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 6d5403e8..6d856fe2 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -21,12 +21,13 @@ import ( "fmt" "sync" + stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" ) // OutgoingQueues is a collection of queues for sending transactions to other @@ -34,6 +35,7 @@ import ( type OutgoingQueues struct { db storage.Database rsAPI api.RoomserverInternalAPI + stateAPI stateapi.CurrentStateInternalAPI origin gomatrixserverlib.ServerName client *gomatrixserverlib.FederationClient statistics *statistics.Statistics @@ -48,12 +50,14 @@ func NewOutgoingQueues( origin gomatrixserverlib.ServerName, client *gomatrixserverlib.FederationClient, rsAPI api.RoomserverInternalAPI, + stateAPI stateapi.CurrentStateInternalAPI, statistics *statistics.Statistics, signing *SigningInfo, ) *OutgoingQueues { queues := &OutgoingQueues{ db: db, rsAPI: rsAPI, + stateAPI: stateAPI, origin: origin, client: client, statistics: statistics, @@ -128,14 +132,33 @@ func (oqs *OutgoingQueues) SendEvent( ) } - // Remove our own server from the list of destinations. - destinations = filterAndDedupeDests(oqs.origin, destinations) - if len(destinations) == 0 { + // Deduplicate destinations and remove the origin from the list of + // destinations just to be sure. + destmap := map[gomatrixserverlib.ServerName]struct{}{} + for _, d := range destinations { + destmap[d] = struct{}{} + } + delete(destmap, oqs.origin) + + // Check if any of the destinations are prohibited by server ACLs. + for destination := range destmap { + if stateapi.IsServerBannedFromRoom( + context.TODO(), + oqs.stateAPI, + ev.RoomID(), + destination, + ) { + delete(destmap, destination) + } + } + + // If there are no remaining destinations then give up. + if len(destmap) == 0 { return nil } log.WithFields(log.Fields{ - "destinations": destinations, "event": ev.EventID(), + "destinations": len(destmap), "event": ev.EventID(), }).Infof("Sending event") headeredJSON, err := json.Marshal(ev) @@ -148,7 +171,7 @@ func (oqs *OutgoingQueues) SendEvent( return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } - for _, destination := range destinations { + for destination := range destmap { oqs.getQueue(destination).sendEvent(nid) } @@ -164,7 +187,7 @@ func (oqs *OutgoingQueues) SendInvite( if stateKey == nil { log.WithFields(log.Fields{ "event_id": ev.EventID(), - }).Info("invite had no state key, dropping") + }).Info("Invite had no state key, dropping") return nil } @@ -173,7 +196,20 @@ func (oqs *OutgoingQueues) SendInvite( log.WithFields(log.Fields{ "event_id": ev.EventID(), "state_key": stateKey, - }).Info("failed to split destination from state key") + }).Info("Failed to split destination from state key") + return nil + } + + if stateapi.IsServerBannedFromRoom( + context.TODO(), + oqs.stateAPI, + ev.RoomID(), + destination, + ) { + log.WithFields(log.Fields{ + "room_id": ev.RoomID(), + "destination": destination, + }).Info("Dropping invite to server which is prohibited by ACLs") return nil } @@ -200,14 +236,40 @@ func (oqs *OutgoingQueues) SendEDU( ) } - // Remove our own server from the list of destinations. - destinations = filterAndDedupeDests(oqs.origin, destinations) - - if len(destinations) > 0 { - log.WithFields(log.Fields{ - "destinations": destinations, "edu_type": e.Type, - }).Info("Sending EDU event") + // Deduplicate destinations and remove the origin from the list of + // destinations just to be sure. + destmap := map[gomatrixserverlib.ServerName]struct{}{} + for _, d := range destinations { + destmap[d] = struct{}{} } + delete(destmap, oqs.origin) + + // There is absolutely no guarantee that the EDU will have a room_id + // field, as it is not required by the spec. However, if it *does* + // (e.g. typing notifications) then we should try to make sure we don't + // bother sending them to servers that are prohibited by the server + // ACLs. + if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() { + for destination := range destmap { + if stateapi.IsServerBannedFromRoom( + context.TODO(), + oqs.stateAPI, + result.Str, + destination, + ) { + delete(destmap, destination) + } + } + } + + // If there are no remaining destinations then give up. + if len(destmap) == 0 { + return nil + } + + log.WithFields(log.Fields{ + "destinations": len(destmap), "edu_type": e.Type, + }).Info("Sending EDU event") ephemeralJSON, err := json.Marshal(e) if err != nil { @@ -219,7 +281,7 @@ func (oqs *OutgoingQueues) SendEDU( return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } - for _, destination := range destinations { + for destination := range destmap { oqs.getQueue(destination).sendEDU(nid) } @@ -234,21 +296,3 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { } q.wakeQueueIfNeeded() } - -// filterAndDedupeDests removes our own server from the list of destinations -// and deduplicates any servers in the list that may appear more than once. -func filterAndDedupeDests(origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName) ( - result []gomatrixserverlib.ServerName, -) { - strs := make([]string, len(destinations)) - for i, d := range destinations { - strs[i] = string(d) - } - for _, destination := range util.UniqueStrings(strs) { - if gomatrixserverlib.ServerName(destination) == origin { - continue - } - result = append(result, gomatrixserverlib.ServerName(destination)) - } - return result -}