Add context to the federationsender database (#231)

main
Mark Haines 2017-09-18 14:15:17 +01:00 committed by GitHub
parent dc5dd4c5d2
commit 5ada8872bb
4 changed files with 48 additions and 21 deletions

View File

@ -123,8 +123,12 @@ func (s *OutputRoomEvent) processMessage(ore api.OutputNewRoomEvent) error {
// TODO: handle EventIDMismatchError and recover the current state by talking
// to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom(
ore.Event.RoomID(), ore.LastSentEventID, ore.Event.EventID(),
addsJoinedHosts, ore.RemovesStateEventIDs,
context.TODO(),
ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts,
ore.RemovesStateEventIDs,
)
if err != nil {
return err

View File

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
@ -78,20 +79,29 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
}
func (s *joinedHostsStatements) insertJoinedHosts(
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
func (s *joinedHostsStatements) deleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return err
}
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
func (s *joinedHostsStatements) selectJoinedHosts(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
@ -66,17 +67,22 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
func (s *roomStatements) insertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
func (s *roomStatements) selectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string
err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
@ -85,7 +91,10 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
func (s *roomStatements) updateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := common.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
}

View File

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
@ -73,35 +74,38 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update.
func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err = d.insertRoom(txn, roomID); err != nil {
if err = d.insertRoom(ctx, txn, roomID); err != nil {
return err
}
lastSentEventID, err := d.selectRoomForUpdate(txn, roomID)
lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
if err != nil {
return err
}
if lastSentEventID != oldEventID {
return types.EventIDMismatchError{lastSentEventID, oldEventID}
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
}
joinedHosts, err = d.selectJoinedHosts(txn, roomID)
joinedHosts, err = d.selectJoinedHosts(ctx, txn, roomID)
if err != nil {
return err
}
for _, add := range addHosts {
err = d.insertJoinedHosts(txn, roomID, add.MemberEventID, add.ServerName)
err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
if err != nil {
return err
}
}
if err = d.deleteJoinedHosts(txn, removeHosts); err != nil {
if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err
}
return d.updateRoom(txn, roomID, newEventID)
return d.updateRoom(ctx, txn, roomID, newEventID)
})
return
}