diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index bc638d00..7b0385bc 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -3,7 +3,6 @@ package gobind import ( "context" "crypto/tls" - "encoding/hex" "fmt" "net" "net/http" @@ -25,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "github.com/yggdrasil-network/yggdrasil-go/src/crypto" "go.uber.org/atomic" ) @@ -169,33 +167,6 @@ func (m *DendriteMonolith) Start() { base.UseHTTPAPIs, ) - ygg.NewSession = func(serverName gomatrixserverlib.ServerName) { - logrus.Infof("Found new session %q", serverName) - time.Sleep(time.Second * 3) - req := &api.PerformServersAliveRequest{ - Servers: []gomatrixserverlib.ServerName{serverName}, - } - res := &api.PerformServersAliveResponse{} - if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil { - logrus.WithError(err).Warn("Failed to notify server alive due to new session") - } - } - - ygg.NotifyLinkNew(func(_ crypto.BoxPubKey, sigPubKey crypto.SigPubKey, linkType, remote string) { - serverName := hex.EncodeToString(sigPubKey[:]) - logrus.Infof("Found new peer %q", serverName) - time.Sleep(time.Second * 3) - req := &api.PerformServersAliveRequest{ - Servers: []gomatrixserverlib.ServerName{ - gomatrixserverlib.ServerName(serverName), - }, - } - res := &api.PerformServersAliveResponse{} - if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil { - logrus.WithError(err).Warn("Failed to notify server alive due to new session") - } - }) - // Build both ends of a HTTP multiplex. m.httpServer = &http.Server{ Addr: ":0", @@ -209,28 +180,22 @@ func (m *DendriteMonolith) Start() { Handler: base.BaseMux, } - m.Resume() -} - -func (m *DendriteMonolith) Resume() { - logrus.Info("Resuming monolith") - if listener, err := net.Listen("tcp", "localhost:65432"); err == nil { - m.listener = listener - } - if m.yggListening.CAS(false, true) { - go func() { - m.logger.Info("Listening on ", m.YggdrasilNode.DerivedServerName()) - m.logger.Fatal(m.httpServer.Serve(m.YggdrasilNode)) - m.yggListening.Store(false) - }() - } - if m.httpListening.CAS(false, true) { - go func() { - m.logger.Info("Listening on ", m.BaseURL()) - m.logger.Fatal(m.httpServer.Serve(m.listener)) - m.httpListening.Store(false) - }() - } + go func() { + logger.Info("Listening on ", ygg.DerivedServerName()) + logger.Fatal(httpServer.Serve(ygg)) + }() + go func() { + logger.Info("Listening on ", m.BaseURL()) + logger.Fatal(httpServer.Serve(m.listener)) + }() + go func() { + logrus.Info("Sending wake-up message to known nodes") + req := &api.PerformBroadcastEDURequest{} + res := &api.PerformBroadcastEDUResponse{} + if err := fsAPI.PerformBroadcastEDU(context.TODO(), req, res); err != nil { + logrus.WithError(err).Error("Failed to send wake-up message to known nodes") + } + }() } func (m *DendriteMonolith) Suspend() { diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 33bcd102..7ab90000 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -17,7 +17,6 @@ package main import ( "context" "crypto/tls" - "encoding/hex" "flag" "fmt" "net" @@ -42,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - "github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/sirupsen/logrus" ) @@ -158,31 +156,6 @@ func main() { base.UseHTTPAPIs, ) - ygg.NewSession = func(serverName gomatrixserverlib.ServerName) { - logrus.Infof("Found new session %q", serverName) - req := &api.PerformServersAliveRequest{ - Servers: []gomatrixserverlib.ServerName{serverName}, - } - res := &api.PerformServersAliveResponse{} - if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil { - logrus.WithError(err).Warn("Failed to notify server alive due to new session") - } - } - - ygg.NotifyLinkNew(func(_ crypto.BoxPubKey, sigPubKey crypto.SigPubKey, linkType, remote string) { - serverName := hex.EncodeToString(sigPubKey[:]) - logrus.Infof("Found new peer %q", serverName) - req := &api.PerformServersAliveRequest{ - Servers: []gomatrixserverlib.ServerName{ - gomatrixserverlib.ServerName(serverName), - }, - } - res := &api.PerformServersAliveResponse{} - if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil { - logrus.WithError(err).Warn("Failed to notify server alive due to new session") - } - }) - // Build both ends of a HTTP multiplex. httpServer := &http.Server{ Addr: ":0", @@ -205,6 +178,14 @@ func main() { logrus.Info("Listening on ", httpBindAddr) logrus.Fatal(http.ListenAndServe(httpBindAddr, base.BaseMux)) }() + go func() { + logrus.Info("Sending wake-up message to known nodes") + req := &api.PerformBroadcastEDURequest{} + res := &api.PerformBroadcastEDUResponse{} + if err := fsAPI.PerformBroadcastEDU(context.TODO(), req, res); err != nil { + logrus.WithError(err).Error("Failed to send wake-up message to known nodes") + } + }() select {} } diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index b74468db..c5b3eb72 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -1,35 +1,13 @@ package yggconn import ( - "context" - "crypto/ed25519" - "encoding/hex" - "fmt" - "net" "net/http" - "strings" "time" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/gomatrixserverlib" ) -func (n *Node) yggdialer(_, address string) (net.Conn, error) { - tokens := strings.Split(address, ":") - raw, err := hex.DecodeString(tokens[0]) - if err != nil { - return nil, fmt.Errorf("hex.DecodeString: %w", err) - } - converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) - convhex := hex.EncodeToString(converted) - return n.Dial("curve25519", convhex) -} - -func (n *Node) yggdialerctx(ctx context.Context, network, address string) (net.Conn, error) { - return n.yggdialer(network, address) -} - type yggroundtripper struct { inner *http.Transport } @@ -49,7 +27,7 @@ func (n *Node) CreateClient( TLSHandshakeTimeout: 20 * time.Second, ResponseHeaderTimeout: 10 * time.Second, IdleConnTimeout: 60 * time.Second, - DialContext: n.yggdialerctx, + DialContext: n.DialerContext, }, }, ) @@ -66,7 +44,8 @@ func (n *Node) CreateFederationClient( TLSHandshakeTimeout: 20 * time.Second, ResponseHeaderTimeout: 10 * time.Second, IdleConnTimeout: 60 * time.Second, - DialContext: n.yggdialerctx, + DialContext: n.DialerContext, + TLSClientConfig: n.tlsConfig, }, }, ) diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go index 2bc300c8..67aec050 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "encoding/hex" "encoding/json" + "errors" "fmt" "io/ioutil" "log" @@ -33,9 +34,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" "github.com/matrix-org/gomatrixserverlib" - yggdrasiladmin "github.com/yggdrasil-network/yggdrasil-go/src/admin" yggdrasilconfig "github.com/yggdrasil-network/yggdrasil-go/src/config" - "github.com/yggdrasil-network/yggdrasil-go/src/crypto" yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast" "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" @@ -46,10 +45,8 @@ type Node struct { core *yggdrasil.Core config *yggdrasilconfig.NodeConfig state *yggdrasilconfig.NodeState - admin *yggdrasiladmin.AdminSocket multicast *yggdrasilmulticast.Multicast log *gologme.Logger - packetConn *yggdrasil.PacketConn listener quic.Listener tlsConfig *tls.Config quicConfig *quic.Config @@ -58,15 +55,10 @@ type Node struct { NewSession func(remote gomatrixserverlib.ServerName) } -func (n *Node) BuildName() string { - return "dendrite" -} - -func (n *Node) BuildVersion() string { - return "dev" -} - func (n *Node) Dialer(_, address string) (net.Conn, error) { + if len(n.core.GetSwitchPeers()) == 0 { + return nil, errors.New("no peer connections available") + } tokens := strings.Split(address, ":") raw, err := hex.DecodeString(tokens[0]) if err != nil { @@ -86,12 +78,10 @@ func Setup(instanceName, storageDirectory string) (*Node, error) { n := &Node{ core: &yggdrasil.Core{}, config: yggdrasilconfig.GenerateConfig(), - admin: &yggdrasiladmin.AdminSocket{}, multicast: &yggdrasilmulticast.Multicast{}, log: gologme.New(os.Stdout, "YGG ", log.Flags()), incoming: make(chan QUICStream), } - n.core.SetBuildInfo(n) yggfile := fmt.Sprintf("%s/%s-yggdrasil.conf", storageDirectory, instanceName) if _, err := os.Stat(yggfile); !os.IsNotExist(err) { @@ -132,20 +122,22 @@ func Setup(instanceName, storageDirectory string) (*Node, error) { panic(err) } - n.packetConn = n.core.PacketConn() n.tlsConfig = n.generateTLSConfig() n.quicConfig = &quic.Config{ MaxIncomingStreams: 0, MaxIncomingUniStreams: 0, KeepAlive: true, - MaxIdleTimeout: time.Minute * 15, - HandshakeTimeout: time.Second * 15, + MaxIdleTimeout: time.Minute * 30, + HandshakeTimeout: time.Second * 30, } n.log.Println("Public curve25519:", n.core.EncryptionPublicKey()) n.log.Println("Public ed25519:", n.core.SigningPublicKey()) - go n.listenFromYgg() + go func() { + time.Sleep(time.Second) + n.listenFromYgg() + }() return n, nil } @@ -193,9 +185,11 @@ func (n *Node) KnownNodes() []gomatrixserverlib.ServerName { nodemap := map[string]struct{}{ "b5ae50589e50991dd9dd7d59c5c5f7a4521e8da5b603b7f57076272abc58b374": struct{}{}, } - for _, peer := range n.core.GetSwitchPeers() { - nodemap[hex.EncodeToString(peer.SigningKey[:])] = struct{}{} - } + /* + for _, peer := range n.core.GetSwitchPeers() { + nodemap[hex.EncodeToString(peer.SigningKey[:])] = struct{}{} + } + */ n.sessions.Range(func(_, v interface{}) bool { session, ok := v.(quic.Session) if !ok { @@ -266,11 +260,3 @@ func (n *Node) SetStaticPeer(uri string) error { } return nil } - -func (n *Node) NotifyLinkNew(f func(boxPubKey crypto.BoxPubKey, sigPubKey crypto.SigPubKey, linkType, remote string)) { - n.core.NotifyLinkNew(f) -} - -func (n *Node) NotifyLinkGone(f func(boxPubKey crypto.BoxPubKey, sigPubKey crypto.SigPubKey, linkType, remote string)) { - n.core.NotifyLinkGone(f) -} diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/session.go b/cmd/dendrite-demo-yggdrasil/yggconn/session.go index ff77e64f..0d231f6d 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/session.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/session.go @@ -24,19 +24,19 @@ import ( "encoding/hex" "encoding/pem" "errors" + "fmt" "math/big" "net" "time" "github.com/lucas-clemente/quic-go" - "github.com/matrix-org/gomatrixserverlib" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" ) func (n *Node) listenFromYgg() { var err error n.listener, err = quic.Listen( - n.packetConn, // yggdrasil.PacketConn + n.core, // yggdrasil.PacketConn n.tlsConfig, // TLS config n.quicConfig, // QUIC config ) @@ -45,24 +45,25 @@ func (n *Node) listenFromYgg() { } for { + n.log.Infoln("Waiting to accept QUIC sessions") session, err := n.listener.Accept(context.TODO()) if err != nil { n.log.Println("n.listener.Accept:", err) return } - go n.listenFromQUIC(session) + if len(session.ConnectionState().PeerCertificates) != 1 { + _ = session.CloseWithError(0, "expected a peer certificate") + continue + } + address := session.ConnectionState().PeerCertificates[0].Subject.CommonName + n.log.Infoln("Accepted connection from", address) + go n.listenFromQUIC(session, address) } } -func (n *Node) listenFromQUIC(session quic.Session) { - n.sessions.Store(session.RemoteAddr().String(), session) - defer n.sessions.Delete(session.RemoteAddr()) - if n.NewSession != nil { - if len(session.ConnectionState().PeerCertificates) == 1 { - subjectName := session.ConnectionState().PeerCertificates[0].Subject.CommonName - go n.NewSession(gomatrixserverlib.ServerName(subjectName)) - } - } +func (n *Node) listenFromQUIC(session quic.Session, address string) { + n.sessions.Store(address, session) + defer n.sessions.Delete(address) for { st, err := session.AcceptStream(context.TODO()) if err != nil { @@ -107,10 +108,23 @@ func (n *Node) DialContext(ctx context.Context, network, address string) (net.Co } var pubKey crypto.BoxPubKey copy(pubKey[:], dest) + nodeID := crypto.GetNodeID(&pubKey) + nodeMask := &crypto.NodeID{} + for i := range nodeMask { + nodeMask[i] = 0xFF + } + + fmt.Println("Resolving coords") + coords, err := n.core.Resolve(nodeID, nodeMask) + if err != nil { + return nil, fmt.Errorf("n.core.Resolve: %w", err) + } + fmt.Println("Found coords:", coords) + fmt.Println("Dialling") session, err = quic.Dial( - n.packetConn, // yggdrasil.PacketConn - &pubKey, // dial address + n.core, // yggdrasil.PacketConn + coords, // dial address address, // dial SNI n.tlsConfig, // TLS config n.quicConfig, // QUIC config @@ -119,7 +133,8 @@ func (n *Node) DialContext(ctx context.Context, network, address string) (net.Co n.log.Println("n.dialer.DialContext:", err) return nil, err } - go n.listenFromQUIC(session) + fmt.Println("Dial OK") + go n.listenFromQUIC(session, address) } st, err := session.OpenStream() if err != nil { @@ -157,5 +172,9 @@ func (n *Node) generateTLSConfig() *tls.Config { Certificates: []tls.Certificate{tlsCert}, NextProtos: []string{"quic-matrix-ygg"}, InsecureSkipVerify: true, + ClientAuth: tls.RequireAnyClientCert, + GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &tlsCert, nil + }, } } diff --git a/federationsender/api/api.go b/federationsender/api/api.go index d90ffd29..b87af0eb 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -42,6 +42,12 @@ type FederationSenderInternalAPI interface { request *PerformServersAliveRequest, response *PerformServersAliveResponse, ) error + // Broadcasts an EDU to all servers in rooms we are joined to. + PerformBroadcastEDU( + ctx context.Context, + request *PerformBroadcastEDURequest, + response *PerformBroadcastEDUResponse, + ) error } type PerformDirectoryLookupRequest struct { @@ -91,3 +97,9 @@ type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomResponse struct { ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } + +type PerformBroadcastEDURequest struct { +} + +type PerformBroadcastEDUResponse struct { +} diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 96b1149d..d9a4b963 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -308,3 +308,25 @@ func (r *FederationSenderInternalAPI) PerformServersAlive( return nil } + +// PerformServersAlive implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformBroadcastEDU( + ctx context.Context, + request *api.PerformBroadcastEDURequest, + response *api.PerformBroadcastEDUResponse, +) (err error) { + destinations, err := r.db.GetAllJoinedHosts(ctx) + if err != nil { + return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err) + } + + edu := &gomatrixserverlib.EDU{ + Type: "org.matrix.dendrite.wakeup", + Origin: string(r.cfg.Matrix.ServerName), + } + if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil { + return fmt.Errorf("r.queues.SendEDU: %w", err) + } + + return nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 25de99cc..4d968919 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -19,6 +19,7 @@ const ( FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest" FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest" FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" + FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -105,3 +106,16 @@ func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup( apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. +func (h *httpFederationSenderInternalAPI) PerformBroadcastEDU( + ctx context.Context, + request *api.PerformBroadcastEDURequest, + response *api.PerformBroadcastEDUResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformBroadcastEDUPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index a4f3d63d..ee05cf95 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -76,4 +76,17 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(FederationSenderPerformBroadcastEDUPath, + httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { + var request api.PerformBroadcastEDURequest + var response api.PerformBroadcastEDUResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 4bf36c24..6fff3518 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -26,6 +26,7 @@ type Database interface { internal.PartitionStorer UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) + GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (int64, error) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index c0f9a7d5..2612e7e0 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -57,10 +57,14 @@ const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" +const selectAllJoinedHostsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + type joinedHostsStatements struct { - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt } func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { @@ -77,6 +81,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { return } + if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + return + } return } @@ -112,6 +119,27 @@ func (s *joinedHostsStatements) selectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) selectAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 80686e09..1535ebdf 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -134,6 +134,13 @@ func (d *Database) GetJoinedHosts( return d.selectJoinedHosts(ctx, roomID) } +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.selectAllJoinedHosts(ctx) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index d9824658..fd9ffedc 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -56,10 +56,14 @@ const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" +const selectAllJoinedHostsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + type joinedHostsStatements struct { - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt } func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { @@ -76,6 +80,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { return } + if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + return + } return } @@ -115,6 +122,27 @@ func (s *joinedHostsStatements) selectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) selectAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 1a4715bf..b23a2dbe 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -145,6 +145,13 @@ func (d *Database) GetJoinedHosts( return d.selectJoinedHosts(ctx, roomID) } +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.selectAllJoinedHosts(ctx) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/go.mod b/go.mod index 2a60e3c5..d0e643fa 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/uber-go/atomic v1.3.0 // indirect github.com/uber/jaeger-client-go v2.15.0+incompatible github.com/uber/jaeger-lib v1.5.0 - github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200713083728-5a765b33d55b + github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3 go.uber.org/atomic v1.4.0 golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 gopkg.in/h2non/bimg.v1 v1.0.18 diff --git a/go.sum b/go.sum index b6b60e9f..c41d0181 100644 --- a/go.sum +++ b/go.sum @@ -652,8 +652,8 @@ github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhe github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yggdrasil-network/yggdrasil-extras v0.0.0-20200525205615-6c8a4a2e8855/go.mod h1:xQdsh08Io6nV4WRnOVTe6gI8/2iTvfLDQ0CYa5aMt+I= -github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200713083728-5a765b33d55b h1:py36vWqSnHIQ2DQ9gC0jbkiFd9OCTQX01PdYJ7KmaQE= -github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200713083728-5a765b33d55b/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3 h1:teLoIJgPHysREs8P6GlcS/PgaU9W9+GQndikFCQ1lY0= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA=