diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 57c8cff6..b7582bf9 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/dendrite/federationsender/storage" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrix" @@ -31,8 +32,11 @@ import ( "go.uber.org/atomic" ) -const maxPDUsPerTransaction = 50 -const queueIdleTimeout = time.Second * 30 +const ( + maxPDUsPerTransaction = 50 + maxEDUsPerTransaction = 50 + queueIdleTimeout = time.Second * 30 +) // destinationQueue is a queue of events for a single destination. // It is responsible for sending the events to the destination and @@ -49,20 +53,19 @@ type destinationQueue struct { backingOff atomic.Bool // true if we're backing off statistics *types.ServerStatistics // statistics about this remote server incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send - incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send transactionIDMutex sync.Mutex // protects transactionID transactionID gomatrixserverlib.TransactionID // last transaction ID transactionCount atomic.Int32 // how many events in this transaction so far - pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend notifyPDUs chan bool // interrupts idle wait for PDUs + notifyEDUs chan bool // interrupts idle wait for EDUs interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(nid int64) { +func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { if oq.statistics.Blacklisted() { // If the destination is blacklisted then drop the event. log.Infof("%s is blacklisted; dropping event", oq.destination) @@ -86,9 +89,9 @@ func (oq *destinationQueue) sendEvent(nid int64) { context.TODO(), oq.transactionID, // the current transaction ID oq.destination, // the destination server name - []int64{nid}, // NID from federationsender_queue_json table + receipt, // NIDs from federationsender_queue_json table ); err != nil { - log.WithError(err).Errorf("failed to associate PDU NID %d with destination %q", nid, oq.destination) + log.WithError(err).Errorf("failed to associate PDU receipt %q with destination %q", receipt.String(), oq.destination) return } // We've successfully added a PDU to the transaction so increase @@ -107,13 +110,34 @@ func (oq *destinationQueue) sendEvent(nid int64) { // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(ev *gomatrixserverlib.EDU) { +func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) { if oq.statistics.Blacklisted() { // If the destination is blacklisted then drop the event. + log.Infof("%s is blacklisted; dropping ephemeral event", oq.destination) return } + // Create a database entry that associates the given PDU NID with + // this destination queue. We'll then be able to retrieve the PDU + // later. + if err := oq.db.AssociateEDUWithDestination( + context.TODO(), + oq.destination, // the destination server name + receipt, // NIDs from federationsender_queue_json table + ); err != nil { + log.WithError(err).Errorf("failed to associate EDU receipt %q with destination %q", receipt.String(), oq.destination) + return + } + // We've successfully added an EDU to the transaction so increase + // the counter. + oq.transactionCount.Add(1) + // Wake up the queue if it's asleep. oq.wakeQueueIfNeeded() - oq.incomingEDUs <- ev + // If we're blocking on waiting PDUs then tell the queue that we + // have work to do. + select { + case oq.notifyEDUs <- true: + default: + } } // sendInvite adds the invite event to the pending queue for the @@ -166,6 +190,28 @@ func (oq *destinationQueue) waitForPDUs() chan bool { return oq.notifyPDUs } +// waitForEDUs returns a channel for pending EDUs, which will be +// used in backgroundSend select. It returns a closed channel if +// there is something pending right now, or an open channel if +// we're waiting for something. +func (oq *destinationQueue) waitForEDUs() chan bool { + pendingEDUs, err := oq.db.GetPendingEDUCount(context.TODO(), oq.destination) + if err != nil { + log.WithError(err).Errorf("Failed to get pending EDU count on queue %q", oq.destination) + } + // If there are EDUs pending right now then we'll return a closed + // channel. This will mean that the backgroundSend will not block. + if pendingEDUs > 0 { + ch := make(chan bool, 1) + close(ch) + return ch + } + // If there are no EDUs pending right now then instead we'll return + // the notify channel, so that backgroundSend can pick up normal + // notifications from sendEvent. + return oq.notifyEDUs +} + // backgroundSend is the worker goroutine for sending events. // nolint:gocyclo func (oq *destinationQueue) backgroundSend() { @@ -177,7 +223,7 @@ func (oq *destinationQueue) backgroundSend() { defer oq.running.Store(false) for { - pendingPDUs := false + pendingPDUs, pendingEDUs := false, false // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. @@ -186,18 +232,10 @@ func (oq *destinationQueue) backgroundSend() { // We were woken up because there are new PDUs waiting in the // database. pendingPDUs = true - case edu := <-oq.incomingEDUs: - // EDUs are handled in-memory for now. We will try to keep - // the ordering intact. - // TODO: Certain EDU types need persistence, e.g. send-to-device - oq.pendingEDUs = append(oq.pendingEDUs, edu) - // If there are any more things waiting in the channel queue - // then read them. This is safe because we guarantee only - // having one goroutine per destination queue, so the channel - // isn't being consumed anywhere else. - for len(oq.incomingEDUs) > 0 { - oq.pendingEDUs = append(oq.pendingEDUs, <-oq.incomingEDUs) - } + case <-oq.waitForEDUs(): + // We were woken up because there are new PDUs waiting in the + // database. + pendingEDUs = true case invite := <-oq.incomingInvites: // There's no strict ordering requirement for invites like // there is for transactions, so we put the invite onto the @@ -238,16 +276,15 @@ func (oq *destinationQueue) backgroundSend() { } // If we have pending PDUs or EDUs then construct a transaction. - if pendingPDUs || len(oq.pendingEDUs) > 0 { + if pendingPDUs || pendingEDUs { // Try sending the next transaction and see what happens. - transaction, terr := oq.nextTransaction(oq.pendingEDUs) + transaction, terr := oq.nextTransaction() if terr != nil { // We failed to send the transaction. if giveUp := oq.statistics.Failure(); giveUp { // It's been suggested that we should give up because the backoff // has exceeded a maximum allowable value. Clean up the in-memory // buffers at this point. The PDU clean-up is already on a defer. - oq.cleanPendingEDUs() oq.cleanPendingInvites() log.Infof("Blacklisting %q due to errors", oq.destination) return @@ -265,8 +302,6 @@ func (oq *destinationQueue) backgroundSend() { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. oq.statistics.Success() - // Clean up the in-memory buffers. - oq.cleanPendingEDUs() } } @@ -294,15 +329,6 @@ func (oq *destinationQueue) backgroundSend() { } } -// cleanPendingEDUs cleans out the pending EDU buffer, removing -// all references so that the underlying objects can be GC'd. -func (oq *destinationQueue) cleanPendingEDUs() { - for i := 0; i < len(oq.pendingEDUs); i++ { - oq.pendingEDUs[i] = nil - } - oq.pendingEDUs = []*gomatrixserverlib.EDU{} -} - // cleanPendingInvites cleans out the pending invite buffer, // removing all references so that the underlying objects can // be GC'd. @@ -316,9 +342,8 @@ func (oq *destinationQueue) cleanPendingInvites() { // nextTransaction creates a new transaction from the pending event // queue and sends it. Returns true if a transaction was sent or // false otherwise. -func (oq *destinationQueue) nextTransaction( - pendingEDUs []*gomatrixserverlib.EDU, -) (bool, error) { +// nolint:gocyclo +func (oq *destinationQueue) nextTransaction() (bool, error) { // Before we do anything, we need to roll over the transaction // ID that is being used to coalesce events into the next TX. // Otherwise it's possible that we'll pick up an incomplete @@ -343,7 +368,7 @@ func (oq *destinationQueue) nextTransaction( // actually retrieve that many events. ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - txid, pdus, err := oq.db.GetNextTransactionPDUs( + txid, pdus, pduReceipt, err := oq.db.GetNextTransactionPDUs( ctx, // context oq.destination, // server name maxPDUsPerTransaction, // max events to retrieve @@ -353,9 +378,19 @@ func (oq *destinationQueue) nextTransaction( return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err) } + edus, eduReceipt, err := oq.db.GetNextTransactionEDUs( + ctx, // context + oq.destination, // server name + maxEDUsPerTransaction, // max events to retrieve + ) + if err != nil { + log.WithError(err).Errorf("failed to get next transaction EDUs for server %q", oq.destination) + return false, fmt.Errorf("oq.db.GetNextTransactionEDUs: %w", err) + } + // If we didn't get anything from the database and there are no // pending EDUs then there's nothing to do - stop here. - if len(pdus) == 0 && len(pendingEDUs) == 0 { + if len(pdus) == 0 && len(edus) == 0 { return false, nil } @@ -377,7 +412,7 @@ func (oq *destinationQueue) nextTransaction( } // Do the same for pending EDUS in the queue. - for _, edu := range pendingEDUs { + for _, edu := range edus { t.EDUs = append(t.EDUs, *edu) } @@ -393,12 +428,17 @@ func (oq *destinationQueue) nextTransaction( switch err.(type) { case nil: // Clean up the transaction in the database. - if err = oq.db.CleanTransactionPDUs( - context.Background(), - t.Destination, - t.TransactionID, - ); err != nil { - log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination) + if pduReceipt != nil { + //logrus.Infof("Cleaning PDUs %q", pduReceipt.String()) + if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipt); err != nil { + log.WithError(err).Errorf("failed to clean PDUs %q for server %q", pduReceipt.String(), t.Destination) + } + } + if eduReceipt != nil { + //logrus.Infof("Cleaning EDUs %q", eduReceipt.String()) + if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipt); err != nil { + log.WithError(err).Errorf("failed to clean EDUs %q for server %q", eduReceipt.String(), t.Destination) + } } return true, nil case gomatrix.HTTPError: diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 46c9fddb..e488a34a 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -61,12 +61,23 @@ func NewOutgoingQueues( queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } // Look up which servers we have pending items for and then rehydrate those queues. - if serverNames, err := db.GetPendingServerNames(context.Background()); err == nil { - for _, serverName := range serverNames { - queues.getQueue(serverName).wakeQueueIfNeeded() + serverNames := map[gomatrixserverlib.ServerName]struct{}{} + if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil { + for _, serverName := range names { + serverNames[serverName] = struct{}{} } } else { - log.WithError(err).Error("Failed to get server names for destination queue hydration") + log.WithError(err).Error("Failed to get PDU server names for destination queue hydration") + } + if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil { + for _, serverName := range names { + serverNames[serverName] = struct{}{} + } + } else { + log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") + } + for serverName := range serverNames { + queues.getQueue(serverName).wakeQueueIfNeeded() } return queues } @@ -91,9 +102,9 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d destination: destination, client: oqs.client, statistics: oqs.statistics.ForServer(destination), - incomingEDUs: make(chan *gomatrixserverlib.EDU, 128), incomingInvites: make(chan *gomatrixserverlib.InviteV2Request, 128), notifyPDUs: make(chan bool, 1), + notifyEDUs: make(chan bool, 1), interruptBackoff: make(chan bool), signing: oqs.signing, } @@ -196,8 +207,18 @@ func (oqs *OutgoingQueues) SendEDU( }).Info("Sending EDU event") } + ephemeralJSON, err := json.Marshal(e) + if err != nil { + return fmt.Errorf("json.Marshal: %w", err) + } + + nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON)) + if err != nil { + return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) + } + for _, destination := range destinations { - oqs.getQueue(destination).sendEDU(e) + oqs.getQueue(destination).sendEDU(nid) } return nil diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 6fff3518..1bea83e2 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -17,6 +17,7 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" @@ -24,13 +25,26 @@ import ( type Database interface { internal.PartitionStorer + UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) + GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (int64, error) - AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error - GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) - CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error + + StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + + AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error + AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error + + GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, *shared.Receipt, error) + GetNextTransactionEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) ([]*gomatrixserverlib.EDU, *shared.Receipt, error) + + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error + GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + + GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) } diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index 2612e7e0..af0a5258 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -61,33 +61,37 @@ const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" type joinedHostsStatements struct { + db *sql.DB insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt } -func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(joinedHostsSchema) +func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { + s = &joinedHostsStatements{ + db: db, + } + _, err = s.db.Exec(joinedHostsSchema) if err != nil { return } - if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { + if s.insertJoinedHostsStmt, err = s.db.Prepare(insertJoinedHostsSQL); err != nil { return } - if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { + if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil { return } - if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { + if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil { return } - if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil { return } return } -func (s *joinedHostsStatements) insertJoinedHosts( +func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, @@ -98,7 +102,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( return err } -func (s *joinedHostsStatements) deleteJoinedHosts( +func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) @@ -106,20 +110,20 @@ func (s *joinedHostsStatements) deleteJoinedHosts( return err } -func (s *joinedHostsStatements) selectJoinedHostsWithTx( +func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } -func (s *joinedHostsStatements) selectJoinedHosts( +func (s *joinedHostsStatements) SelectJoinedHosts( ctx context.Context, roomID string, ) ([]types.JoinedHost, error) { return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } -func (s *joinedHostsStatements) selectAllJoinedHosts( +func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) diff --git a/federationsender/storage/postgres/queue_edus_table.go b/federationsender/storage/postgres/queue_edus_table.go new file mode 100644 index 00000000..6cac489b --- /dev/null +++ b/federationsender/storage/postgres/queue_edus_table.go @@ -0,0 +1,198 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueEDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + -- The type of the event (informational). + edu_type TEXT NOT NULL, + -- The domain part of the user ID the EDU event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_edus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +` + +const insertQueueEDUSQL = "" + + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEDUSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEDUSQL = "" + + "SELECT json_nid FROM federationsender_queue_edus" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueueEDUReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE json_nid = $1" + +const selectQueueEDUCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_edus" + +type queueEDUsStatements struct { + db *sql.DB + insertQueueEDUStmt *sql.Stmt + deleteQueueEDUStmt *sql.Stmt + selectQueueEDUStmt *sql.Stmt + selectQueueEDUReferenceJSONCountStmt *sql.Stmt + selectQueueEDUCountStmt *sql.Stmt + selectQueueEDUServerNamesStmt *sql.Stmt +} + +func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { + s = &queueEDUsStatements{ + db: db, + } + _, err = s.db.Exec(queueEDUsSchema) + if err != nil { + return + } + if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil { + return + } + if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil { + return + } + if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } + return +} + +func (s *queueEDUsStatements) InsertQueueEDU( + ctx context.Context, + txn *sql.Tx, + eduType string, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + eduType, // the EDU type + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueEDUsStatements) DeleteQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEDUStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *queueEDUsStatements) SelectQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, nil +} + +func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/postgres/queue_json_table.go b/federationsender/storage/postgres/queue_json_table.go index eac2ea98..85307374 100644 --- a/federationsender/storage/postgres/queue_json_table.go +++ b/federationsender/storage/postgres/queue_json_table.go @@ -48,29 +48,33 @@ const selectJSONSQL = "" + " WHERE json_nid = ANY($1)" type queueJSONStatements struct { + db *sql.DB insertJSONStmt *sql.Stmt deleteJSONStmt *sql.Stmt selectJSONStmt *sql.Stmt } -func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queueJSONSchema) +func NewPostgresQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { + s = &queueJSONStatements{ + db: db, + } + _, err = s.db.Exec(queueJSONSchema) if err != nil { return } - if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { + if s.insertJSONStmt, err = s.db.Prepare(insertJSONSQL); err != nil { return } - if s.deleteJSONStmt, err = db.Prepare(deleteJSONSQL); err != nil { + if s.deleteJSONStmt, err = s.db.Prepare(deleteJSONSQL); err != nil { return } - if s.selectJSONStmt, err = db.Prepare(selectJSONSQL); err != nil { + if s.selectJSONStmt, err = s.db.Prepare(selectJSONSQL); err != nil { return } return } -func (s *queueJSONStatements) insertQueueJSON( +func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) @@ -81,7 +85,7 @@ func (s *queueJSONStatements) insertQueueJSON( return lastid, nil } -func (s *queueJSONStatements) deleteQueueJSON( +func (s *queueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) @@ -89,7 +93,7 @@ func (s *queueJSONStatements) deleteQueueJSON( return err } -func (s *queueJSONStatements) selectQueueJSON( +func (s *queueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { blobs := map[int64][]byte{} diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index dab6003e..95a3b9ee 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -41,10 +42,10 @@ const insertQueuePDUSQL = "" + "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + " VALUES ($1, $2, $3)" -const deleteQueueTransactionPDUsSQL = "" + - "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND transaction_id = $2" +const deleteQueuePDUSQL = "" + + "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)" -const selectQueueNextTransactionIDSQL = "" + +const selectQueuePDUNextTransactionIDSQL = "" + "SELECT transaction_id FROM federationsender_queue_pdus" + " WHERE server_name = $1" + " ORDER BY transaction_id ASC" + @@ -55,7 +56,7 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" -const selectQueueReferenceJSONCountSQL = "" + +const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" @@ -63,49 +64,53 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" -const selectQueueServerNamesSQL = "" + +const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" type queuePDUsStatements struct { - insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsByTransactionStmt *sql.Stmt - selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt - selectQueueServerNamesStmt *sql.Stmt + db *sql.DB + insertQueuePDUStmt *sql.Stmt + deleteQueuePDUsStmt *sql.Stmt + selectQueuePDUNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt + selectQueuePDUReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt + selectQueuePDUServerNamesStmt *sql.Stmt } -func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queuePDUsSchema) +func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + } + _, err = s.db.Exec(queuePDUsSchema) if err != nil { return } - if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { + if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil { return } - if s.deleteQueueTransactionPDUsStmt, err = db.Prepare(deleteQueueTransactionPDUsSQL); err != nil { + if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil { return } - if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { + if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil { return } - if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { + if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { + if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { return } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } return } -func (s *queuePDUsStatements) insertQueuePDU( +func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, @@ -122,21 +127,21 @@ func (s *queuePDUsStatements) insertQueuePDU( return err } -func (s *queuePDUsStatements) deleteQueueTransaction( +func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, + jsonNIDs []int64, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) - _, err := stmt.ExecContext(ctx, serverName, transactionID) + stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) return err } -func (s *queuePDUsStatements) selectQueueNextTransactionID( +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID - stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) if err == sql.ErrNoRows { return "", nil @@ -144,11 +149,11 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } -func (s *queuePDUsStatements) selectQueueReferenceJSONCount( +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) if err == sql.ErrNoRows { // It's acceptable for there to be no rows referencing a given @@ -159,7 +164,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUCount( +func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 @@ -174,7 +179,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUs( +func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -198,10 +203,10 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } -func (s *queuePDUsStatements) selectQueueServerNames( +func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err diff --git a/federationsender/storage/postgres/room_table.go b/federationsender/storage/postgres/room_table.go index e5266c63..8d3ed20f 100644 --- a/federationsender/storage/postgres/room_table.go +++ b/federationsender/storage/postgres/room_table.go @@ -43,24 +43,27 @@ const updateRoomSQL = "" + "UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1" type roomStatements struct { + db *sql.DB insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomSchema) +func NewPostgresRoomsTable(db *sql.DB) (s *roomStatements, err error) { + s = &roomStatements{ + db: db, + } + _, err = s.db.Exec(roomSchema) if err != nil { return } - - if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil { + if s.insertRoomStmt, err = s.db.Prepare(insertRoomSQL); err != nil { return } - if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil { + if s.selectRoomForUpdateStmt, err = s.db.Prepare(selectRoomForUpdateSQL); err != nil { return } - if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil { + if s.updateRoomStmt, err = s.db.Prepare(updateRoomSQL); err != nil { return } return @@ -68,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. -func (s *roomStatements) insertRoom( +func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) @@ -78,7 +81,7 @@ func (s *roomStatements) insertRoom( // selectRoomForUpdate locks the row for the room and returns the last_event_id. // The row must already exist in the table. Callers can ensure that the row // exists by calling insertRoom first. -func (s *roomStatements) selectRoomForUpdate( +func (s *roomStatements) SelectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string @@ -92,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. -func (s *roomStatements) updateRoom( +func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 1535ebdf..66388bfe 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -16,266 +16,56 @@ package postgres import ( - "context" "database/sql" - "encoding/json" - "fmt" - "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender type Database struct { - joinedHostsStatements - roomStatements - queuePDUsStatements - queueJSONStatements + shared.Database sqlutil.PartitionOffsetStatements db *sql.DB } // NewDatabase opens a new database func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { - var result Database + var d Database var err error - if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { + if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - var err error - - if err = d.joinedHostsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.roomStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queuePDUsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queueJSONStatements.prepare(d.db); err != nil { - return err - } - - return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") -} - -// UpdateRoom updates the joined hosts for a room and returns what the joined -// hosts were before the update, or nil if this was a duplicate message. -// This is called when we receive a message from kafka, so we pass in -// oldEventID and newEventID to check that we haven't missed any messages or -// this isn't a duplicate message. -func (d *Database) UpdateRoom( - ctx context.Context, - roomID, oldEventID, newEventID string, - addHosts []types.JoinedHost, - removeHosts []string, -) (joinedHosts []types.JoinedHost, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - err = d.insertRoom(ctx, txn, roomID) - if err != nil { - return err - } - - lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) - if err != nil { - return err - } - - if lastSentEventID == newEventID { - // We've handled this message before, so let's just ignore it. - // We can only get a duplicate for the last message we processed, - // so its enough just to compare the newEventID with lastSentEventID - return nil - } - - if lastSentEventID != "" && lastSentEventID != oldEventID { - return types.EventIDMismatchError{ - DatabaseID: lastSentEventID, RoomServerID: oldEventID, - } - } - - joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { - return err - } - } - if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err - } - return d.updateRoom(ctx, txn, roomID, newEventID) - }) - return -} - -// GetJoinedHosts returns the currently joined hosts for room, -// as known to federationserver. -// Returns an error if something goes wrong. -func (d *Database) GetJoinedHosts( - ctx context.Context, roomID string, -) ([]types.JoinedHost, error) { - return d.selectJoinedHosts(ctx, roomID) -} - -// GetAllJoinedHosts returns the currently joined hosts for -// all rooms known to the federation sender. -// Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - return d.selectAllJoinedHosts(ctx) -} - -// StoreJSON adds a JSON blob into the queue JSON table and returns -// a NID. The NID will then be used when inserting the per-destination -// metadata entries. -func (d *Database) StoreJSON( - ctx context.Context, js string, -) (int64, error) { - nid, err := d.insertQueueJSON(ctx, nil, js) + joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { - return 0, fmt.Errorf("d.insertQueueJSON: %w", err) + return nil, err } - return nid, nil -} - -// AssociatePDUWithDestination creates an association that the -// destination queues will use to determine which JSON blobs to send -// to which servers. -func (d *Database) AssociatePDUWithDestination( - ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, - nids []int64, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) - } - } - return nil - }) -} - -// GetNextTransactionPDUs retrieves events from the database for -// the next pending transaction, up to the limit specified. -func (d *Database) GetNextTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - limit int, -) ( - transactionID gomatrixserverlib.TransactionID, - events []*gomatrixserverlib.HeaderedEvent, - err error, -) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) - if err != nil { - return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return nil - } - - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectQueueJSON(ctx, txn, nids) - if err != nil { - return fmt.Errorf("d.selectJSON: %w", err) - } - - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - events = append(events, &event) - } - - return nil - }) - return -} - -// CleanTransactionPDUs cleans up all associated events for a -// given transaction. This is done when the transaction was sent -// successfully. -func (d *Database) CleanTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { - return fmt.Errorf("d.deleteQueueTransaction: %w", err) - } - - var count int64 - var deleteNIDs []int64 - for _, nid := range nids { - count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) - if err != nil { - return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) - } - if count == 0 { - deleteNIDs = append(deleteNIDs, nid) - } - } - - if len(deleteNIDs) > 0 { - if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { - return fmt.Errorf("d.deleteQueueJSON: %w", err) - } - } - - return nil - }) -} - -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.selectQueuePDUCount(ctx, nil, serverName) -} - -// GetPendingServerNames returns the server names that have PDUs -// waiting to be sent. -func (d *Database) GetPendingServerNames( - ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { - return d.selectQueueServerNames(ctx, nil) + queuePDUs, err := NewPostgresQueuePDUsTable(d.db) + if err != nil { + return nil, err + } + queueEDUs, err := NewPostgresQueueEDUsTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresQueueJSONTable(d.db) + if err != nil { + return nil, err + } + rooms, err := NewPostgresRoomsTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + FederationSenderJoinedHosts: joinedHosts, + FederationSenderQueuePDUs: queuePDUs, + FederationSenderQueueEDUs: queueEDUs, + FederationSenderQueueJSON: queueJSON, + FederationSenderRooms: rooms, + } + if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + return nil, err + } + return &d, nil } diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go new file mode 100644 index 00000000..75681ea3 --- /dev/null +++ b/federationsender/storage/shared/storage.go @@ -0,0 +1,138 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationsender/storage/tables" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + FederationSenderQueuePDUs tables.FederationSenderQueuePDUs + FederationSenderQueueEDUs tables.FederationSenderQueueEDUs + FederationSenderQueueJSON tables.FederationSenderQueueJSON + FederationSenderJoinedHosts tables.FederationSenderJoinedHosts + FederationSenderRooms tables.FederationSenderRooms +} + +// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. +type Receipt struct { + nids []int64 +} + +func (e *Receipt) Empty() bool { + return len(e.nids) == 0 +} + +func (e *Receipt) String() string { + j, _ := json.Marshal(e.nids) + return string(j) +} + +// UpdateRoom updates the joined hosts for a room and returns what the joined +// hosts were before the update, or nil if this was a duplicate message. +// This is called when we receive a message from kafka, so we pass in +// oldEventID and newEventID to check that we haven't missed any messages or +// this isn't a duplicate message. +func (d *Database) UpdateRoom( + ctx context.Context, + roomID, oldEventID, newEventID string, + addHosts []types.JoinedHost, + removeHosts []string, +) (joinedHosts []types.JoinedHost, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID) + if err != nil { + return err + } + + lastSentEventID, err := d.FederationSenderRooms.SelectRoomForUpdate(ctx, txn, roomID) + if err != nil { + return err + } + + if lastSentEventID == newEventID { + // We've handled this message before, so let's just ignore it. + // We can only get a duplicate for the last message we processed, + // so its enough just to compare the newEventID with lastSentEventID + return nil + } + + if lastSentEventID != "" && lastSentEventID != oldEventID { + return types.EventIDMismatchError{ + DatabaseID: lastSentEventID, RoomServerID: oldEventID, + } + } + + joinedHosts, err = d.FederationSenderJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) + if err != nil { + return err + } + + for _, add := range addHosts { + err = d.FederationSenderJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) + if err != nil { + return err + } + } + if err = d.FederationSenderJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil { + return err + } + return d.FederationSenderRooms.UpdateRoom(ctx, txn, roomID, newEventID) + }) + return +} + +// GetJoinedHosts returns the currently joined hosts for room, +// as known to federationserver. +// Returns an error if something goes wrong. +func (d *Database) GetJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return d.FederationSenderJoinedHosts.SelectJoinedHosts(ctx, roomID) +} + +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx) +} + +// StoreJSON adds a JSON blob into the queue JSON table and returns +// a NID. The NID will then be used when inserting the per-destination +// metadata entries. +func (d *Database) StoreJSON( + ctx context.Context, js string, +) (*Receipt, error) { + nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + return &Receipt{ + nids: []int64{nid}, + }, nil +} diff --git a/federationsender/storage/shared/storage_edus.go b/federationsender/storage/shared/storage_edus.go new file mode 100644 index 00000000..75a6dd51 --- /dev/null +++ b/federationsender/storage/shared/storage_edus.go @@ -0,0 +1,143 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +// AssociateEDUWithDestination creates an association that the +// destination queues will use to determine which JSON blobs to send +// to which servers. +func (d *Database) AssociateEDUWithDestination( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipt *Receipt, +) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for _, nid := range receipt.nids { + if err := d.FederationSenderQueueEDUs.InsertQueueEDU( + ctx, // context + txn, // SQL transaction + "", // TODO: EDU type for coalescing + serverName, // destination server name + nid, // NID from the federationsender_queue_json table + ); err != nil { + return fmt.Errorf("InsertQueueEDU: %w", err) + } + } + return nil + }) +} + +// GetNextTransactionEDUs retrieves events from the database for +// the next pending transaction, up to the limit specified. +func (d *Database) GetNextTransactionEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) ( + edus []*gomatrixserverlib.EDU, + receipt *Receipt, + err error, +) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) + if err != nil { + return fmt.Errorf("SelectQueueEDUs: %w", err) + } + + receipt = &Receipt{ + nids: nids, + } + + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) + if err != nil { + return fmt.Errorf("SelectQueueJSON: %w", err) + } + + for _, blob := range blobs { + var event gomatrixserverlib.EDU + if err := json.Unmarshal(blob, &event); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + edus = append(edus, &event) + } + + return nil + }) + return +} + +// CleanEDUs cleans up all specified EDUs. This is done when a +// transaction was sent successfully. +func (d *Database) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipt *Receipt, +) error { + if receipt == nil { + return errors.New("expected receipt") + } + + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, receipt.nids); err != nil { + return err + } + + var deleteNIDs []int64 + for _, nid := range receipt.nids { + count, err := d.FederationSenderQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid) + if err != nil { + return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } + } + + if len(deleteNIDs) > 0 { + if err := d.FederationSenderQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil { + return fmt.Errorf("DeleteQueueJSON: %w", err) + } + } + + return nil + }) +} + +// GetPendingEDUCount returns the number of EDUs waiting to be +// sent for a given servername. +func (d *Database) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return d.FederationSenderQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName) +} + +// GetPendingServerNames returns the server names that have EDUs +// waiting to be sent. +func (d *Database) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderQueueEDUs.SelectQueueEDUServerNames(ctx, nil) +} diff --git a/federationsender/storage/shared/storage_pdus.go b/federationsender/storage/shared/storage_pdus.go new file mode 100644 index 00000000..00588956 --- /dev/null +++ b/federationsender/storage/shared/storage_pdus.go @@ -0,0 +1,155 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +// AssociatePDUWithDestination creates an association that the +// destination queues will use to determine which JSON blobs to send +// to which servers. +func (d *Database) AssociatePDUWithDestination( + ctx context.Context, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + receipt *Receipt, +) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for _, nid := range receipt.nids { + if err := d.FederationSenderQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + transactionID, // transaction ID + serverName, // destination server name + nid, // NID from the federationsender_queue_json table + ); err != nil { + return fmt.Errorf("InsertQueuePDU: %w", err) + } + } + return nil + }) +} + +// GetNextTransactionPDUs retrieves events from the database for +// the next pending transaction, up to the limit specified. +func (d *Database) GetNextTransactionPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) ( + transactionID gomatrixserverlib.TransactionID, + events []*gomatrixserverlib.HeaderedEvent, + receipt *Receipt, + err error, +) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + transactionID, err = d.FederationSenderQueuePDUs.SelectQueuePDUNextTransactionID(ctx, txn, serverName) + if err != nil { + return fmt.Errorf("SelectQueuePDUNextTransactionID: %w", err) + } + + if transactionID == "" { + return nil + } + + nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, transactionID, limit) + if err != nil { + return fmt.Errorf("SelectQueuePDUs: %w", err) + } + + receipt = &Receipt{ + nids: nids, + } + + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) + if err != nil { + return fmt.Errorf("SelectQueueJSON: %w", err) + } + + for _, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + events = append(events, &event) + } + + return nil + }) + return +} + +// CleanTransactionPDUs cleans up all associated events for a +// given transaction. This is done when the transaction was sent +// successfully. +func (d *Database) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipt *Receipt, +) error { + if receipt == nil { + return errors.New("expected receipt") + } + + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, receipt.nids); err != nil { + return err + } + + var deleteNIDs []int64 + for _, nid := range receipt.nids { + count, err := d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid) + if err != nil { + return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } + } + + if len(deleteNIDs) > 0 { + if err := d.FederationSenderQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil { + return fmt.Errorf("DeleteQueueJSON: %w", err) + } + } + + return nil + }) +} + +// GetPendingPDUCount returns the number of PDUs waiting to be +// sent for a given servername. +func (d *Database) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return d.FederationSenderQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) +} + +// GetPendingServerNames returns the server names that have PDUs +// waiting to be sent. +func (d *Database) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderQueuePDUs.SelectQueuePDUServerNames(ctx, nil) +} diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index fd9ffedc..bd917c61 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -60,13 +60,19 @@ const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" type joinedHostsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt } -func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { + s = &joinedHostsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err = db.Exec(joinedHostsSchema) if err != nil { return @@ -86,43 +92,47 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { return } -func (s *joinedHostsStatements) insertJoinedHosts( +func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) - _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) + _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + return err + }) } -func (s *joinedHostsStatements) deleteJoinedHosts( +func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - for _, eventID := range eventIDs { - stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) - if _, err := stmt.ExecContext(ctx, eventID); err != nil { - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + for _, eventID := range eventIDs { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) + if _, err := stmt.ExecContext(ctx, eventID); err != nil { + return err + } } - } - return nil + return nil + }) } -func (s *joinedHostsStatements) selectJoinedHostsWithTx( +func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } -func (s *joinedHostsStatements) selectJoinedHosts( +func (s *joinedHostsStatements) SelectJoinedHosts( ctx context.Context, roomID string, ) ([]types.JoinedHost, error) { return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } -func (s *joinedHostsStatements) selectAllJoinedHosts( +func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go new file mode 100644 index 00000000..ed5d9ffa --- /dev/null +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -0,0 +1,214 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueEDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + -- The type of the event (informational). + edu_type TEXT NOT NULL, + -- The domain part of the user ID the EDU event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_edus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +` + +const insertQueueEDUSQL = "" + + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEDUSQL = "" + + "SELECT json_nid FROM federationsender_queue_edus" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueueEDUReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE json_nid = $1" + +const selectQueueEDUCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_edus" + +type queueEDUsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter + insertQueueEDUStmt *sql.Stmt + selectQueueEDUStmt *sql.Stmt + selectQueueEDUReferenceJSONCountStmt *sql.Stmt + selectQueueEDUCountStmt *sql.Stmt + selectQueueEDUServerNamesStmt *sql.Stmt +} + +func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { + s = &queueEDUsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } + _, err = db.Exec(queueEDUsSchema) + if err != nil { + return + } + if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { + return + } + if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } + return +} + +func (s *queueEDUsStatements) InsertQueueEDU( + ctx context.Context, + txn *sql.Tx, + eduType string, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + eduType, // the EDU type + serverName, // destination server name + nid, // JSON blob NID + ) + return err + }) +} + +func (s *queueEDUsStatements) DeleteQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + fmt.Println(deleteSQL) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err := stmt.ExecContext(ctx, params...) + return err + }) +} + +func (s *queueEDUsStatements) SelectQueueEDUs( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, nil +} + +func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 01b7160d..46dfd9ab 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -49,12 +49,18 @@ const selectJSONSQL = "" + " WHERE json_nid IN ($1)" type queueJSONStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic } -func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { + s = &queueJSONStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err = db.Exec(queueJSONSchema) if err != nil { return @@ -65,22 +71,25 @@ func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { return } -func (s *queueJSONStatements) insertQueueJSON( +func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, -) (int64, error) { - stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) - res, err := stmt.ExecContext(ctx, json) - if err != nil { - return 0, fmt.Errorf("stmt.QueryContext: %w", err) - } - lastid, err := res.LastInsertId() - if err != nil { - return 0, fmt.Errorf("res.LastInsertId: %w", err) - } - return lastid, nil +) (lastid int64, err error) { + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return fmt.Errorf("res.LastInsertId: %w", err) + } + return nil + }) + return } -func (s *queueJSONStatements) deleteQueueJSON( +func (s *queueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) @@ -94,12 +103,14 @@ func (s *queueJSONStatements) deleteQueueJSON( iNIDs[k] = v } - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, iNIDs...) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err + }) } -func (s *queueJSONStatements) selectQueueJSON( +func (s *queueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 33eef91e..9d8eaab6 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -42,8 +44,8 @@ const insertQueuePDUSQL = "" + "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + " VALUES ($1, $2, $3)" -const deleteQueueTransactionPDUsSQL = "" + - "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND transaction_id = $2" +const deleteQueuePDUsSQL = "" + + "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" const selectQueueNextTransactionIDSQL = "" + "SELECT transaction_id FROM federationsender_queue_pdus" + @@ -56,7 +58,7 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" -const selectQueueReferenceJSONCountSQL = "" + +const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" @@ -64,20 +66,26 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" -const selectQueueServerNamesSQL = "" + +const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" type queuePDUsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt + // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic } -func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err = db.Exec(queuePDUsSchema) if err != nil { return @@ -85,55 +93,72 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { return } - if s.deleteQueueTransactionPDUsStmt, err = db.Prepare(deleteQueueTransactionPDUsSQL); err != nil { - return - } + //if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil { + // return + //} if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { return } if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { return } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } return } -func (s *queuePDUsStatements) insertQueuePDU( +func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) - _, err := stmt.ExecContext( - ctx, - transactionID, // the transaction ID that we initially attempted - serverName, // destination server name - nid, // JSON blob NID - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err + }) } -func (s *queuePDUsStatements) deleteQueueTransaction( +func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, + jsonNIDs []int64, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) - _, err := stmt.ExecContext(ctx, serverName, transactionID) - return err + deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + fmt.Println(deleteSQL) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err := stmt.ExecContext(ctx, params...) + return err + }) } -func (s *queuePDUsStatements) selectQueueNextTransactionID( +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID @@ -145,7 +170,7 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } -func (s *queuePDUsStatements) selectQueueReferenceJSONCount( +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 @@ -157,7 +182,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUCount( +func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 @@ -172,7 +197,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUs( +func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -196,7 +221,7 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } -func (s *queuePDUsStatements) selectQueueServerNames( +func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index ca0c4d0b..51793874 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -43,12 +43,18 @@ const updateRoomSQL = "" + "UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1" type roomStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { + s = &roomStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err = db.Exec(roomSchema) if err != nil { return @@ -68,17 +74,19 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. -func (s *roomStatements) insertRoom( +func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + return err + }) } // selectRoomForUpdate locks the row for the room and returns the last_event_id. // The row must already exist in the table. Callers can ensure that the row // exists by calling insertRoom first. -func (s *roomStatements) selectRoomForUpdate( +func (s *roomStatements) SelectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string @@ -92,10 +100,12 @@ func (s *roomStatements) selectRoomForUpdate( // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. -func (s *roomStatements) updateRoom( +func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) - _, err := stmt.ExecContext(ctx, roomID, lastEventID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) + _, err := stmt.ExecContext(ctx, roomID, lastEventID) + return err + }) } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index b23a2dbe..545a229c 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -16,283 +16,62 @@ package sqlite3 import ( - "context" "database/sql" - "encoding/json" - "fmt" _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender type Database struct { - joinedHostsStatements - roomStatements - queuePDUsStatements - queueJSONStatements + shared.Database sqlutil.PartitionOffsetStatements - db *sql.DB - queuePDUsWriter *sqlutil.TransactionWriter - queueJSONWriter *sqlutil.TransactionWriter + db *sql.DB } // NewDatabase opens a new database func NewDatabase(dataSourceName string) (*Database, error) { - var result Database + var d Database var err error cs, err := sqlutil.ParseFileURI(dataSourceName) if err != nil { return nil, err } - if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - var err error - - if err = d.joinedHostsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.roomStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queuePDUsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queueJSONStatements.prepare(d.db); err != nil { - return err - } - - d.queuePDUsWriter = sqlutil.NewTransactionWriter() - d.queueJSONWriter = sqlutil.NewTransactionWriter() - - return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") -} - -// UpdateRoom updates the joined hosts for a room and returns what the joined -// hosts were before the update, or nil if this was a duplicate message. -// This is called when we receive a message from kafka, so we pass in -// oldEventID and newEventID to check that we haven't missed any messages or -// this isn't a duplicate message. -func (d *Database) UpdateRoom( - ctx context.Context, - roomID, oldEventID, newEventID string, - addHosts []types.JoinedHost, - removeHosts []string, -) (joinedHosts []types.JoinedHost, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - err = d.insertRoom(ctx, txn, roomID) - if err != nil { - return err - } - - lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) - if err != nil { - return err - } - - if lastSentEventID == newEventID { - // We've handled this message before, so let's just ignore it. - // We can only get a duplicate for the last message we processed, - // so its enough just to compare the newEventID with lastSentEventID - return nil - } - - if lastSentEventID != "" && lastSentEventID != oldEventID { - return types.EventIDMismatchError{ - DatabaseID: lastSentEventID, RoomServerID: oldEventID, - } - } - - joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { - return err - } - } - if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err - } - return d.updateRoom(ctx, txn, roomID, newEventID) - }) - return -} - -// GetJoinedHosts returns the currently joined hosts for room, -// as known to federationserver. -// Returns an error if something goes wrong. -func (d *Database) GetJoinedHosts( - ctx context.Context, roomID string, -) ([]types.JoinedHost, error) { - return d.selectJoinedHosts(ctx, roomID) -} - -// GetAllJoinedHosts returns the currently joined hosts for -// all rooms known to the federation sender. -// Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - return d.selectAllJoinedHosts(ctx) -} - -// StoreJSON adds a JSON blob into the queue JSON table and returns -// a NID. The NID will then be used when inserting the per-destination -// metadata entries. -func (d *Database) StoreJSON( - ctx context.Context, js string, -) (nid int64, err error) { - err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { - n, e := d.insertQueueJSON(ctx, nil, js) - if e != nil { - return fmt.Errorf("d.insertQueueJSON: %w", e) - } - nid = n - return nil - }) - return -} - -// AssociatePDUWithDestination creates an association that the -// destination queues will use to determine which JSON blobs to send -// to which servers. -func (d *Database) AssociatePDUWithDestination( - ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, - nids []int64, -) error { - return d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) - } - } - return nil - }) -} - -// GetNextTransactionPDUs retrieves events from the database for -// the next pending transaction, up to the limit specified. -func (d *Database) GetNextTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - limit int, -) ( - transactionID gomatrixserverlib.TransactionID, - events []*gomatrixserverlib.HeaderedEvent, - err error, -) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) - if err != nil { - return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return nil - } - - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectQueueJSON(ctx, txn, nids) - if err != nil { - return fmt.Errorf("d.selectJSON: %w", err) - } - - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - events = append(events, &event) - } - - return nil - }) - return -} - -// CleanTransactionPDUs cleans up all associated events for a -// given transaction. This is done when the transaction was sent -// successfully. -func (d *Database) CleanTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, -) error { - var deleteNIDs []int64 - nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) + joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) + return nil, err } - if err = d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { - if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { - return fmt.Errorf("d.deleteQueueTransaction: %w", err) - } - return nil - }); err != nil { - return err + rooms, err := NewSQLiteRoomsTable(d.db) + if err != nil { + return nil, err } - var count int64 - for _, nid := range nids { - count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid) - if err != nil { - return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) - } - if count == 0 { - deleteNIDs = append(deleteNIDs, nid) - } + queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) + if err != nil { + return nil, err } - if len(deleteNIDs) > 0 { - err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { - if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { - return fmt.Errorf("d.deleteQueueJSON: %w", err) - } - return nil - }) + queueEDUs, err := NewSQLiteQueueEDUsTable(d.db) + if err != nil { + return nil, err } - return err -} - -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.selectQueuePDUCount(ctx, nil, serverName) -} - -// GetPendingServerNames returns the server names that have PDUs -// waiting to be sent. -func (d *Database) GetPendingServerNames( - ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { - return d.selectQueueServerNames(ctx, nil) + queueJSON, err := NewSQLiteQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + FederationSenderJoinedHosts: joinedHosts, + FederationSenderQueuePDUs: queuePDUs, + FederationSenderQueueEDUs: queueEDUs, + FederationSenderQueueJSON: queueJSON, + FederationSenderRooms: rooms, + } + if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + return nil, err + } + return &d, nil } diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go new file mode 100644 index 00000000..55d9119f --- /dev/null +++ b/federationsender/storage/tables/interface.go @@ -0,0 +1,62 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type FederationSenderQueuePDUs interface { + InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueuePDUNextTransactionID(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (gomatrixserverlib.TransactionID, error) + SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) + SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, limit int) ([]int64, error) + SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderQueueEDUs interface { + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) + SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderQueueJSON interface { + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + +type FederationSenderJoinedHosts interface { + InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error + DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error + SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) + SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) + SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderRooms interface { + InsertRoom(ctx context.Context, txn *sql.Tx, roomID string) error + SelectRoomForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (string, error) + UpdateRoom(ctx context.Context, txn *sql.Tx, roomID, lastEventID string) error +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index a25a4a5b..2ec6ce29 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -131,14 +131,17 @@ func NewTransactionWriter() *TransactionWriter { // transactionWriterTask represents a specific task. type transactionWriterTask struct { db *sql.DB + txn *sql.Tx f func(txn *sql.Tx) error wait chan error } // Do queues a task to be run by a TransactionWriter. The function // provided will be ran within a transaction as supplied by the -// database parameter. This will block until the task is finished. -func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { +// txn parameter if one is supplied, and if not, will take out a +// new transaction from the database supplied in the database +// parameter. Either way, this will block until the task is done. +func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { if w.todo == nil { return errors.New("not initialised") } @@ -147,6 +150,7 @@ func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { } task := transactionWriterTask{ db: db, + txn: txn, f: f, wait: make(chan error, 1), } @@ -164,9 +168,15 @@ func (w *TransactionWriter) run() { } defer w.running.Store(false) for task := range w.todo { - task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { - return task.f(txn) - }) + if task.txn != nil { + task.wait <- task.f(task.txn) + } else if task.db != nil { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + } else { + panic("expected database or transaction but got neither") + } close(task.wait) } } diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go index 62cb7fcb..233e5d29 100644 --- a/keyserver/storage/storage_wasm.go +++ b/keyserver/storage/storage_wasm.go @@ -19,7 +19,7 @@ import ( "net/url" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" ) func NewDatabase( diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 38b503cd..32079291 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1114,7 +1114,7 @@ func (d *Database) StoreNewSendForDeviceMessage( } // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. - err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { return d.AddSendToDeviceEvent( ctx, txn, userID, deviceID, string(j), ) @@ -1179,7 +1179,7 @@ func (d *Database) CleanSendToDeviceUpdates( // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { // Delete any send-to-device messages marked for deletion. if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)