Component-wide TransactionWriters (#1290)
* Offset updates take place using TransactionWriter * Refactor TransactionWriter in current state server * Refactor TransactionWriter in federation sender * Refactor TransactionWriter in key server * Refactor TransactionWriter in media API * Refactor TransactionWriter in server key API * Refactor TransactionWriter in sync API * Refactor TransactionWriter in user API * Fix deadlocking Sync API tests * Un-deadlock device database * Fix appservice API * Rename TransactionWriters to Writers * Move writers up a layer in sync API * Document sqlutil.Writer interface * Add note to Writer documentationmain
parent
5aaf32bbed
commit
9d53351dc2
|
@ -32,6 +32,7 @@ type Database struct {
|
||||||
events eventsStatements
|
events eventsStatements
|
||||||
txnID txnStatements
|
txnID txnStatements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
|
@ -41,10 +42,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
result.writer = sqlutil.NewDummyWriter()
|
||||||
if err = result.prepare(); err != nil {
|
if err = result.prepare(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
|
if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &result, nil
|
return &result, nil
|
||||||
|
|
|
@ -67,7 +67,7 @@ const (
|
||||||
|
|
||||||
type eventsStatements struct {
|
type eventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
selectEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
countEventsByApplicationServiceIDStmt *sql.Stmt
|
countEventsByApplicationServiceIDStmt *sql.Stmt
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
|
@ -75,9 +75,9 @@ type eventsStatements struct {
|
||||||
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
|
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(appserviceEventsSchema)
|
_, err = db.Exec(appserviceEventsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -32,6 +32,7 @@ type Database struct {
|
||||||
events eventsStatements
|
events eventsStatements
|
||||||
txnID txnStatements
|
txnID txnStatements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
|
@ -41,21 +42,22 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
if result.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
result.writer = sqlutil.NewExclusiveWriter()
|
||||||
if err = result.prepare(); err != nil {
|
if err = result.prepare(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
|
if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) prepare() error {
|
func (d *Database) prepare() error {
|
||||||
if err := d.events.prepare(d.db); err != nil {
|
if err := d.events.prepare(d.db, d.writer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.txnID.prepare(d.db)
|
return d.txnID.prepare(d.db, d.writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
|
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
|
||||||
|
|
|
@ -38,13 +38,13 @@ const selectTxnIDSQL = `
|
||||||
|
|
||||||
type txnStatements struct {
|
type txnStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
selectTxnIDStmt *sql.Stmt
|
selectTxnIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *txnStatements) prepare(db *sql.DB) (err error) {
|
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(txnIDSchema)
|
_, err = db.Exec(txnIDSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +22,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
|
d.writer = sqlutil.NewDummyWriter()
|
||||||
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
currRoomState, err := NewPostgresCurrentRoomStateTable(d.db)
|
currRoomState, err := NewPostgresCurrentRoomStateTable(d.db)
|
||||||
|
@ -30,6 +32,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: d.writer,
|
||||||
CurrentRoomState: currRoomState,
|
CurrentRoomState: currRoomState,
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
CurrentRoomState tables.CurrentRoomState
|
CurrentRoomState tables.CurrentRoomState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,7 +60,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
|
||||||
|
|
||||||
func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent,
|
func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent,
|
||||||
removeStateEventIDs []string) error {
|
removeStateEventIDs []string) error {
|
||||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
||||||
for _, eventID := range removeStateEventIDs {
|
for _, eventID := range removeStateEventIDs {
|
||||||
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
||||||
|
|
|
@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
|
@ -96,7 +96,7 @@ type currentRoomStateStatements struct {
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
|
||||||
s := ¤tRoomStateStatements{
|
s := ¤tRoomStateStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
}
|
}
|
||||||
_, err := db.Exec(currentRoomStateSchema)
|
_, err := db.Exec(currentRoomStateSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +23,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
|
d.writer = sqlutil.NewExclusiveWriter()
|
||||||
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
|
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
|
||||||
|
@ -31,6 +33,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: d.writer,
|
||||||
CurrentRoomState: currRoomState,
|
CurrentRoomState: currRoomState,
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -28,6 +28,7 @@ type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
|
@ -37,6 +38,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.writer = sqlutil.NewDummyWriter()
|
||||||
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -63,6 +65,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: d.writer,
|
||||||
FederationSenderJoinedHosts: joinedHosts,
|
FederationSenderJoinedHosts: joinedHosts,
|
||||||
FederationSenderQueuePDUs: queuePDUs,
|
FederationSenderQueuePDUs: queuePDUs,
|
||||||
FederationSenderQueueEDUs: queueEDUs,
|
FederationSenderQueueEDUs: queueEDUs,
|
||||||
|
@ -70,7 +73,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
FederationSenderRooms: rooms,
|
FederationSenderRooms: rooms,
|
||||||
FederationSenderBlacklist: blacklist,
|
FederationSenderBlacklist: blacklist,
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil {
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
FederationSenderQueuePDUs tables.FederationSenderQueuePDUs
|
FederationSenderQueuePDUs tables.FederationSenderQueuePDUs
|
||||||
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
|
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
|
||||||
FederationSenderQueueJSON tables.FederationSenderQueueJSON
|
FederationSenderQueueJSON tables.FederationSenderQueueJSON
|
||||||
|
@ -64,7 +65,7 @@ func (d *Database) UpdateRoom(
|
||||||
addHosts []types.JoinedHost,
|
addHosts []types.JoinedHost,
|
||||||
removeHosts []string,
|
removeHosts []string,
|
||||||
) (joinedHosts []types.JoinedHost, err error) {
|
) (joinedHosts []types.JoinedHost, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID)
|
err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -133,7 +134,12 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string)
|
||||||
func (d *Database) StoreJSON(
|
func (d *Database) StoreJSON(
|
||||||
ctx context.Context, js string,
|
ctx context.Context, js string,
|
||||||
) (*Receipt, error) {
|
) (*Receipt, error) {
|
||||||
nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js)
|
var nid int64
|
||||||
|
var err error
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
|
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -143,11 +149,15 @@ func (d *Database) StoreJSON(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||||
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), nil, serverName)
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
|
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||||
return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), nil, serverName)
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
|
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
|
|
|
@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
|
||||||
|
|
||||||
type blacklistStatements struct {
|
type blacklistStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertBlacklistStmt *sql.Stmt
|
insertBlacklistStmt *sql.Stmt
|
||||||
selectBlacklistStmt *sql.Stmt
|
selectBlacklistStmt *sql.Stmt
|
||||||
deleteBlacklistStmt *sql.Stmt
|
deleteBlacklistStmt *sql.Stmt
|
||||||
|
@ -51,7 +50,6 @@ type blacklistStatements struct {
|
||||||
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||||
s = &blacklistStatements{
|
s = &blacklistStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(blacklistSchema)
|
_, err = db.Exec(blacklistSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -75,11 +73,9 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||||
func (s *blacklistStatements) InsertBlacklist(
|
func (s *blacklistStatements) InsertBlacklist(
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||||
_, err := stmt.ExecContext(ctx, serverName)
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||||
|
@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
|
||||||
func (s *blacklistStatements) DeleteBlacklist(
|
func (s *blacklistStatements) DeleteBlacklist(
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||||
_, err := stmt.ExecContext(ctx, serverName)
|
_, err := stmt.ExecContext(ctx, serverName)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,6 @@ const selectJoinedHostsForRoomsSQL = "" +
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
|
@ -76,7 +75,6 @@ type joinedHostsStatements struct {
|
||||||
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
s = &joinedHostsStatements{
|
s = &joinedHostsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(joinedHostsSchema)
|
_, err = db.Exec(joinedHostsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -103,17 +101,14 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
for _, eventID := range eventIDs {
|
for _, eventID := range eventIDs {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||||
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
||||||
|
@ -121,7 +116,6 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||||
|
|
|
@ -64,7 +64,6 @@ const selectQueueServerNamesSQL = "" +
|
||||||
|
|
||||||
type queueEDUsStatements struct {
|
type queueEDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertQueueEDUStmt *sql.Stmt
|
insertQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
|
@ -75,7 +74,6 @@ type queueEDUsStatements struct {
|
||||||
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
||||||
s = &queueEDUsStatements{
|
s = &queueEDUsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queueEDUsSchema)
|
_, err = db.Exec(queueEDUsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -106,7 +104,6 @@ func (s *queueEDUsStatements) InsertQueueEDU(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
nid int64,
|
nid int64,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
||||||
_, err := stmt.ExecContext(
|
_, err := stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -115,7 +112,6 @@ func (s *queueEDUsStatements) InsertQueueEDU(
|
||||||
nid, // JSON blob NID
|
nid, // JSON blob NID
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) DeleteQueueEDUs(
|
func (s *queueEDUsStatements) DeleteQueueEDUs(
|
||||||
|
@ -135,11 +131,9 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
|
||||||
params[k+1] = v
|
params[k+1] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||||
_, err := stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUs(
|
func (s *queueEDUsStatements) SelectQueueEDUs(
|
||||||
|
|
|
@ -50,7 +50,6 @@ const selectJSONSQL = "" +
|
||||||
|
|
||||||
type queueJSONStatements struct {
|
type queueJSONStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertJSONStmt *sql.Stmt
|
insertJSONStmt *sql.Stmt
|
||||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
@ -59,7 +58,6 @@ type queueJSONStatements struct {
|
||||||
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||||
s = &queueJSONStatements{
|
s = &queueJSONStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queueJSONSchema)
|
_, err = db.Exec(queueJSONSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -74,18 +72,15 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||||
func (s *queueJSONStatements) InsertQueueJSON(
|
func (s *queueJSONStatements) InsertQueueJSON(
|
||||||
ctx context.Context, txn *sql.Tx, json string,
|
ctx context.Context, txn *sql.Tx, json string,
|
||||||
) (lastid int64, err error) {
|
) (lastid int64, err error) {
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||||
res, err := stmt.ExecContext(ctx, json)
|
res, err := stmt.ExecContext(ctx, json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("stmt.QueryContext: %w", err)
|
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
|
||||||
}
|
}
|
||||||
lastid, err = res.LastInsertId()
|
lastid, err = res.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("res.LastInsertId: %w", err)
|
return 0, fmt.Errorf("res.LastInsertId: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,11 +98,9 @@ func (s *queueJSONStatements) DeleteQueueJSON(
|
||||||
iNIDs[k] = v
|
iNIDs[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||||
_, err = stmt.ExecContext(ctx, iNIDs...)
|
_, err = stmt.ExecContext(ctx, iNIDs...)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueJSONStatements) SelectQueueJSON(
|
func (s *queueJSONStatements) SelectQueueJSON(
|
||||||
|
|
|
@ -71,7 +71,6 @@ const selectQueuePDUsServerNamesSQL = "" +
|
||||||
|
|
||||||
type queuePDUsStatements struct {
|
type queuePDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertQueuePDUStmt *sql.Stmt
|
insertQueuePDUStmt *sql.Stmt
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsByTransactionStmt *sql.Stmt
|
selectQueuePDUsByTransactionStmt *sql.Stmt
|
||||||
|
@ -84,7 +83,6 @@ type queuePDUsStatements struct {
|
||||||
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
s = &queuePDUsStatements{
|
s = &queuePDUsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queuePDUsSchema)
|
_, err = db.Exec(queuePDUsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -121,7 +119,6 @@ func (s *queuePDUsStatements) InsertQueuePDU(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
nid int64,
|
nid int64,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
||||||
_, err := stmt.ExecContext(
|
_, err := stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -130,7 +127,6 @@ func (s *queuePDUsStatements) InsertQueuePDU(
|
||||||
nid, // JSON blob NID
|
nid, // JSON blob NID
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) DeleteQueuePDUs(
|
func (s *queuePDUsStatements) DeleteQueuePDUs(
|
||||||
|
@ -150,11 +146,9 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
|
||||||
params[k+1] = v
|
params[k+1] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||||
_, err := stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
|
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
|
||||||
|
|
|
@ -44,7 +44,6 @@ const updateRoomSQL = "" +
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertRoomStmt *sql.Stmt
|
insertRoomStmt *sql.Stmt
|
||||||
selectRoomForUpdateStmt *sql.Stmt
|
selectRoomForUpdateStmt *sql.Stmt
|
||||||
updateRoomStmt *sql.Stmt
|
updateRoomStmt *sql.Stmt
|
||||||
|
@ -53,7 +52,6 @@ type roomStatements struct {
|
||||||
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
||||||
s = &roomStatements{
|
s = &roomStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err = db.Exec(roomSchema)
|
_, err = db.Exec(roomSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -77,10 +75,8 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
||||||
func (s *roomStatements) InsertRoom(
|
func (s *roomStatements) InsertRoom(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||||
|
@ -103,9 +99,7 @@ func (s *roomStatements) SelectRoomForUpdate(
|
||||||
func (s *roomStatements) UpdateRoom(
|
func (s *roomStatements) UpdateRoom(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
|
||||||
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,7 @@ type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
|
@ -39,6 +40,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.writer = sqlutil.NewExclusiveWriter()
|
||||||
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -65,6 +67,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: d.writer,
|
||||||
FederationSenderJoinedHosts: joinedHosts,
|
FederationSenderJoinedHosts: joinedHosts,
|
||||||
FederationSenderQueuePDUs: queuePDUs,
|
FederationSenderQueuePDUs: queuePDUs,
|
||||||
FederationSenderQueueEDUs: queueEDUs,
|
FederationSenderQueueEDUs: queueEDUs,
|
||||||
|
@ -72,7 +75,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
FederationSenderRooms: rooms,
|
FederationSenderRooms: rooms,
|
||||||
FederationSenderBlacklist: blacklist,
|
FederationSenderBlacklist: blacklist,
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil {
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -53,6 +53,8 @@ const upsertPartitionOffsetsSQL = "" +
|
||||||
|
|
||||||
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
|
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
|
||||||
type PartitionOffsetStatements struct {
|
type PartitionOffsetStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer Writer
|
||||||
selectPartitionOffsetsStmt *sql.Stmt
|
selectPartitionOffsetsStmt *sql.Stmt
|
||||||
upsertPartitionOffsetStmt *sql.Stmt
|
upsertPartitionOffsetStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
@ -60,7 +62,9 @@ type PartitionOffsetStatements struct {
|
||||||
// Prepare converts the raw SQL statements into prepared statements.
|
// Prepare converts the raw SQL statements into prepared statements.
|
||||||
// Takes a prefix to prepend to the table name used to store the partition offsets.
|
// Takes a prefix to prepend to the table name used to store the partition offsets.
|
||||||
// This allows multiple components to share the same database schema.
|
// This allows multiple components to share the same database schema.
|
||||||
func (s *PartitionOffsetStatements) Prepare(db *sql.DB, prefix string) (err error) {
|
func (s *PartitionOffsetStatements) Prepare(db *sql.DB, writer Writer, prefix string) (err error) {
|
||||||
|
s.db = db
|
||||||
|
s.writer = writer
|
||||||
_, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1))
|
_, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -121,6 +125,9 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
|
||||||
func (s *PartitionOffsetStatements) upsertPartitionOffset(
|
func (s *PartitionOffsetStatements) upsertPartitionOffset(
|
||||||
ctx context.Context, topic string, partition int32, offset int64,
|
ctx context.Context, topic string, partition int32, offset int64,
|
||||||
) error {
|
) error {
|
||||||
_, err := s.upsertPartitionOffsetStmt.ExecContext(ctx, topic, partition, offset)
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
|
stmt := TxStmt(txn, s.upsertPartitionOffsetStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, topic, partition, offset)
|
||||||
return err
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,7 +103,3 @@ func SQLiteDriverName() string {
|
||||||
}
|
}
|
||||||
return "sqlite3"
|
return "sqlite3"
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionWriter interface {
|
|
||||||
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
package sqlutil
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
// The Writer interface is designed to solve the problem of how
|
||||||
|
// to handle database writes for database engines that don't allow
|
||||||
|
// concurrent writes, e.g. SQLite.
|
||||||
|
//
|
||||||
|
// The interface has a single Do function which takes an optional
|
||||||
|
// database parameter, an optional transaction parameter and a
|
||||||
|
// required function parameter. The Writer will call the function
|
||||||
|
// provided when it is safe to do so, optionally providing a
|
||||||
|
// transaction to use.
|
||||||
|
//
|
||||||
|
// Depending on the combination of parameters provided, the Writer
|
||||||
|
// will behave in one of three ways:
|
||||||
|
//
|
||||||
|
// 1. `db` provided, `txn` provided:
|
||||||
|
//
|
||||||
|
// The Writer will call f() when it is safe to do so. The supplied
|
||||||
|
// "txn" will ALWAYS be passed through to f(). Use this when you
|
||||||
|
// already have a transaction open.
|
||||||
|
//
|
||||||
|
// 2. `db` provided, `txn` not provided (nil):
|
||||||
|
//
|
||||||
|
// The Writer will open a new transaction on the provided database
|
||||||
|
// and then will call f() when it is safe to do so. The new
|
||||||
|
// transaction will ALWAYS be passed through to f(). Use this if
|
||||||
|
// you plan to perform more than one SQL query within f().
|
||||||
|
//
|
||||||
|
// 3. `db` not provided (nil), `txn` not provided (nil):
|
||||||
|
//
|
||||||
|
// The Writer will call f() when it is safe to do so, but will
|
||||||
|
// not make any attempt to open a new database transaction or to
|
||||||
|
// pass through an existing one. The "txn" parameter within f()
|
||||||
|
// will ALWAYS be nil in this mode. This is useful if you just
|
||||||
|
// want to perform a single query on an already-prepared statement
|
||||||
|
// without the overhead of opening a new transaction to do it in.
|
||||||
|
//
|
||||||
|
// You MUST take particular care not to call Do() from within f()
|
||||||
|
// on the same Writer, or it will likely result in a deadlock.
|
||||||
|
type Writer interface {
|
||||||
|
// Queue up one or more database write operations within the
|
||||||
|
// provided function to be executed when it is safe to do so.
|
||||||
|
Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
|
||||||
|
}
|
|
@ -4,15 +4,21 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyTransactionWriter struct {
|
// DummyWriter implements sqlutil.Writer.
|
||||||
|
// The DummyWriter is designed to allow reuse of the sqlutil.Writer
|
||||||
|
// interface but, unlike ExclusiveWriter, it will not guarantee
|
||||||
|
// writer exclusivity. This is fine in PostgreSQL where overlapping
|
||||||
|
// transactions and writes are acceptable.
|
||||||
|
type DummyWriter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDummyTransactionWriter() TransactionWriter {
|
// NewDummyWriter returns a new dummy writer.
|
||||||
return &DummyTransactionWriter{}
|
func NewDummyWriter() Writer {
|
||||||
|
return &DummyWriter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
func (w *DummyWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||||
if txn == nil {
|
if db != nil && txn == nil {
|
||||||
return WithTransaction(db, func(txn *sql.Tx) error {
|
return WithTransaction(db, func(txn *sql.Tx) error {
|
||||||
return f(txn)
|
return f(txn)
|
||||||
})
|
})
|
||||||
|
|
|
@ -7,16 +7,17 @@ import (
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExclusiveTransactionWriter allows queuing database writes so that you don't
|
// ExclusiveWriter implements sqlutil.Writer.
|
||||||
|
// ExclusiveWriter allows queuing database writes so that you don't
|
||||||
// contend on database locks in, e.g. SQLite. Only one task will run
|
// contend on database locks in, e.g. SQLite. Only one task will run
|
||||||
// at a time on a given ExclusiveTransactionWriter.
|
// at a time on a given ExclusiveWriter.
|
||||||
type ExclusiveTransactionWriter struct {
|
type ExclusiveWriter struct {
|
||||||
running atomic.Bool
|
running atomic.Bool
|
||||||
todo chan transactionWriterTask
|
todo chan transactionWriterTask
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransactionWriter() TransactionWriter {
|
func NewExclusiveWriter() Writer {
|
||||||
return &ExclusiveTransactionWriter{
|
return &ExclusiveWriter{
|
||||||
todo: make(chan transactionWriterTask),
|
todo: make(chan transactionWriterTask),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,7 +35,7 @@ type transactionWriterTask struct {
|
||||||
// txn parameter if one is supplied, and if not, will take out a
|
// txn parameter if one is supplied, and if not, will take out a
|
||||||
// new transaction from the database supplied in the database
|
// new transaction from the database supplied in the database
|
||||||
// parameter. Either way, this will block until the task is done.
|
// parameter. Either way, this will block until the task is done.
|
||||||
func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
func (w *ExclusiveWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||||
if w.todo == nil {
|
if w.todo == nil {
|
||||||
return errors.New("not initialised")
|
return errors.New("not initialised")
|
||||||
}
|
}
|
||||||
|
@ -55,20 +56,20 @@ func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql
|
||||||
// of these goroutines will run at a time. A transaction will be
|
// of these goroutines will run at a time. A transaction will be
|
||||||
// opened using the database object from the task and then this will
|
// opened using the database object from the task and then this will
|
||||||
// be passed as a parameter to the task function.
|
// be passed as a parameter to the task function.
|
||||||
func (w *ExclusiveTransactionWriter) run() {
|
func (w *ExclusiveWriter) run() {
|
||||||
if !w.running.CAS(false, true) {
|
if !w.running.CAS(false, true) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer w.running.Store(false)
|
defer w.running.Store(false)
|
||||||
for task := range w.todo {
|
for task := range w.todo {
|
||||||
if task.txn != nil {
|
if task.db != nil && task.txn != nil {
|
||||||
task.wait <- task.f(task.txn)
|
task.wait <- task.f(task.txn)
|
||||||
} else if task.db != nil {
|
} else if task.db != nil && task.txn == nil {
|
||||||
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||||
return task.f(txn)
|
return task.f(txn)
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
panic("expected database or transaction but got neither")
|
task.wait <- task.f(nil)
|
||||||
}
|
}
|
||||||
close(task.wait)
|
close(task.wait)
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
@ -71,10 +71,10 @@ type deviceKeysStatements struct {
|
||||||
deleteAllDeviceKeysStmt *sql.Stmt
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.DeviceKeys, error) {
|
||||||
s := &deviceKeysStatements{
|
s := &deviceKeysStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
writer: writer,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(deviceKeysSchema)
|
_, err := db.Exec(deviceKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -52,15 +52,15 @@ const selectKeyChangesSQL = "" +
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
upsertKeyChangeStmt *sql.Stmt
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
selectKeyChangesStmt *sql.Stmt
|
selectKeyChangesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChanges, error) {
|
||||||
s := &keyChangesStatements{
|
s := &keyChangesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
writer: writer,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(keyChangesSchema)
|
_, err := db.Exec(keyChangesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt *sql.Stmt
|
||||||
|
@ -68,10 +68,10 @@ type oneTimeKeysStatements struct {
|
||||||
deleteOneTimeKeyStmt *sql.Stmt
|
deleteOneTimeKeyStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.OneTimeKeys, error) {
|
||||||
s := &oneTimeKeysStatements{
|
s := &oneTimeKeysStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
writer: writer,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(oneTimeKeysSchema)
|
_, err := db.Exec(oneTimeKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" +
|
||||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||||
|
|
||||||
type staleDeviceListsStatements struct {
|
type staleDeviceListsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
upsertStaleDeviceListStmt *sql.Stmt
|
upsertStaleDeviceListStmt *sql.Stmt
|
||||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||||
selectStaleDeviceListsStmt *sql.Stmt
|
selectStaleDeviceListsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.Writer) (tables.StaleDeviceLists, error) {
|
||||||
s := &staleDeviceListsStatements{}
|
s := &staleDeviceListsStatements{
|
||||||
|
db: db,
|
||||||
|
writer: writer,
|
||||||
|
}
|
||||||
_, err := db.Exec(staleDeviceListsSchema)
|
_, err := db.Exec(staleDeviceListsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -77,8 +83,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||||
return err
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
|
|
@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
otk, err := NewSqliteOneTimeKeysTable(db)
|
writer := sqlutil.NewExclusiveWriter()
|
||||||
|
otk, err := NewSqliteOneTimeKeysTable(db, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk, err := NewSqliteDeviceKeysTable(db)
|
dk, err := NewSqliteDeviceKeysTable(db, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
kc, err := NewSqliteKeyChangesTable(db)
|
kc, err := NewSqliteKeyChangesTable(db, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
sdl, err := NewSqliteStaleDeviceListsTable(db, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,14 +62,14 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
|
||||||
|
|
||||||
type mediaStatements struct {
|
type mediaStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
insertMediaStmt *sql.Stmt
|
insertMediaStmt *sql.Stmt
|
||||||
selectMediaStmt *sql.Stmt
|
selectMediaStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
|
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
|
|
||||||
_, err = db.Exec(mediaSchema)
|
_, err = db.Exec(mediaSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -17,6 +17,8 @@ package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type statements struct {
|
type statements struct {
|
||||||
|
@ -24,11 +26,11 @@ type statements struct {
|
||||||
thumbnail thumbnailStatements
|
thumbnail thumbnailStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statements) prepare(db *sql.DB) (err error) {
|
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
if err = s.media.prepare(db); err != nil {
|
if err = s.media.prepare(db, writer); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = s.thumbnail.prepare(db); err != nil {
|
if err = s.thumbnail.prepare(db, writer); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,16 +31,19 @@ import (
|
||||||
type Database struct {
|
type Database struct {
|
||||||
statements statements
|
statements statements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open opens a postgres database.
|
// Open opens a postgres database.
|
||||||
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
var d Database
|
d := Database{
|
||||||
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
if err = d.statements.prepare(d.db, d.writer); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -57,16 +58,20 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method
|
||||||
`
|
`
|
||||||
|
|
||||||
type thumbnailStatements struct {
|
type thumbnailStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
insertThumbnailStmt *sql.Stmt
|
insertThumbnailStmt *sql.Stmt
|
||||||
selectThumbnailStmt *sql.Stmt
|
selectThumbnailStmt *sql.Stmt
|
||||||
selectThumbnailsStmt *sql.Stmt
|
selectThumbnailsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
|
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
_, err = db.Exec(thumbnailSchema)
|
_, err = db.Exec(thumbnailSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s.db = db
|
||||||
|
s.writer = writer
|
||||||
|
|
||||||
return statementList{
|
return statementList{
|
||||||
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
{&s.insertThumbnailStmt, insertThumbnailSQL},
|
||||||
|
@ -79,7 +84,9 @@ func (s *thumbnailStatements) insertThumbnail(
|
||||||
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
|
||||||
) error {
|
) error {
|
||||||
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
|
||||||
_, err := s.insertThumbnailStmt.ExecContext(
|
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
|
||||||
|
_, err := stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
thumbnailMetadata.MediaMetadata.MediaID,
|
thumbnailMetadata.MediaMetadata.MediaID,
|
||||||
thumbnailMetadata.MediaMetadata.Origin,
|
thumbnailMetadata.MediaMetadata.Origin,
|
||||||
|
@ -91,6 +98,7 @@ func (s *thumbnailStatements) insertThumbnail(
|
||||||
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
thumbnailMetadata.ThumbnailSize.ResizeMethod,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *thumbnailStatements) selectThumbnail(
|
func (s *thumbnailStatements) selectThumbnail(
|
||||||
|
|
|
@ -98,7 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Writer: sqlutil.NewDummyTransactionWriter(),
|
Writer: sqlutil.NewDummyWriter(),
|
||||||
EventTypesTable: eventTypes,
|
EventTypesTable: eventTypes,
|
||||||
EventStateKeysTable: eventStateKeys,
|
EventStateKeysTable: eventStateKeys,
|
||||||
EventJSONTable: eventJSON,
|
EventJSONTable: eventJSON,
|
||||||
|
|
|
@ -27,7 +27,7 @@ const redactionsArePermanent = false
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Writer sqlutil.TransactionWriter
|
Writer sqlutil.Writer
|
||||||
EventsTable tables.Events
|
EventsTable tables.Events
|
||||||
EventJSONTable tables.EventJSON
|
EventJSONTable tables.EventJSON
|
||||||
EventTypesTable tables.EventTypes
|
EventTypesTable tables.EventTypes
|
||||||
|
|
|
@ -41,7 +41,7 @@ type Database struct {
|
||||||
invites tables.Invites
|
invites tables.Invites
|
||||||
membership tables.Membership
|
membership tables.Membership
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open a sqlite database.
|
// Open a sqlite database.
|
||||||
|
@ -52,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d.writer = sqlutil.NewTransactionWriter()
|
d.writer = sqlutil.NewExclusiveWriter()
|
||||||
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
||||||
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
||||||
|
|
||||||
|
@ -120,7 +120,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
Writer: sqlutil.NewTransactionWriter(),
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
EventsTable: d.events,
|
EventsTable: d.events,
|
||||||
EventTypesTable: d.eventTypes,
|
EventTypesTable: d.eventTypes,
|
||||||
EventStateKeysTable: d.eventStateKeys,
|
EventStateKeysTable: d.eventStateKeys,
|
||||||
|
|
|
@ -30,6 +30,7 @@ import (
|
||||||
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
||||||
// the public keys for other matrix servers.
|
// the public keys for other matrix servers.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
|
writer sqlutil.Writer
|
||||||
statements serverKeyStatements
|
statements serverKeyStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,8 +48,10 @@ func NewDatabase(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &Database{}
|
d := &Database{
|
||||||
err = d.statements.prepare(db)
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
}
|
||||||
|
err = d.statements.prepare(db, d.writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,14 +63,14 @@ const upsertServerKeysSQL = "" +
|
||||||
|
|
||||||
type serverKeyStatements struct {
|
type serverKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
bulkSelectServerKeysStmt *sql.Stmt
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
upsertServerKeysStmt *sql.Stmt
|
upsertServerKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(serverKeysSchema)
|
_, err = db.Exec(serverKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
type SyncServerDatasource struct {
|
type SyncServerDatasource struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
|
d.writer = sqlutil.NewDummyWriter()
|
||||||
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
accountData, err := NewPostgresAccountDataTable(d.db)
|
accountData, err := NewPostgresAccountDataTable(d.db)
|
||||||
|
@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: sqlutil.NewDummyWriter(),
|
||||||
Invites: invites,
|
Invites: invites,
|
||||||
AccountData: accountData,
|
AccountData: accountData,
|
||||||
OutputEvents: events,
|
OutputEvents: events,
|
||||||
|
@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
BackwardExtremities: backwardExtremities,
|
BackwardExtremities: backwardExtremities,
|
||||||
Filter: filter,
|
Filter: filter,
|
||||||
SendToDevice: sendToDevice,
|
SendToDevice: sendToDevice,
|
||||||
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
|
|
||||||
EDUCache: cache.New(),
|
EDUCache: cache.New(),
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
|
|
|
@ -37,6 +37,7 @@ import (
|
||||||
// For now this contains the shared functions
|
// For now this contains the shared functions
|
||||||
type Database struct {
|
type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
Invites tables.Invites
|
Invites tables.Invites
|
||||||
AccountData tables.AccountData
|
AccountData tables.AccountData
|
||||||
OutputEvents tables.Events
|
OutputEvents tables.Events
|
||||||
|
@ -45,7 +46,6 @@ type Database struct {
|
||||||
BackwardExtremities tables.BackwardsExtremities
|
BackwardExtremities tables.BackwardsExtremities
|
||||||
SendToDevice tables.SendToDevice
|
SendToDevice tables.SendToDevice
|
||||||
Filter tables.Filter
|
Filter tables.Filter
|
||||||
SendToDeviceWriter sqlutil.TransactionWriter
|
|
||||||
EDUCache *cache.EDUCache
|
EDUCache *cache.EDUCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,10 +129,7 @@ func (d *Database) GetStateEvent(
|
||||||
func (d *Database) GetStateEventsForRoom(
|
func (d *Database) GetStateEventsForRoom(
|
||||||
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
|
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
|
||||||
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
|
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter)
|
||||||
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
|
||||||
func (d *Database) AddInviteEvent(
|
func (d *Database) AddInviteEvent(
|
||||||
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
|
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
|
||||||
) (sp types.StreamPosition, err error) {
|
) (sp types.StreamPosition, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
_ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
|
sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent)
|
||||||
return err
|
return nil
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent(
|
||||||
// Returns an error if there was a problem communicating with the database.
|
// Returns an error if there was a problem communicating with the database.
|
||||||
func (d *Database) RetireInviteEvent(
|
func (d *Database) RetireInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) (types.StreamPosition, error) {
|
) (sp types.StreamPosition, err error) {
|
||||||
return d.Invites.DeleteInviteEvent(ctx, inviteEventID)
|
_ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
|
sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountDataInRange returns all account data for a given user inserted or
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
||||||
|
@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange(
|
||||||
func (d *Database) UpsertAccountData(
|
func (d *Database) UpsertAccountData(
|
||||||
ctx context.Context, userID, roomID, dataType string,
|
ctx context.Context, userID, roomID, dataType string,
|
||||||
) (sp types.StreamPosition, err error) {
|
) (sp types.StreamPosition, err error) {
|
||||||
err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
|
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
|
||||||
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
|
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
|
||||||
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
|
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
|
||||||
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
|
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
|
||||||
|
// This function should always be called within a sqlutil.Writer for safety in SQLite.
|
||||||
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
|
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
|
||||||
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
|
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -275,7 +277,7 @@ func (d *Database) WriteEvent(
|
||||||
addStateEventIDs, removeStateEventIDs []string,
|
addStateEventIDs, removeStateEventIDs []string,
|
||||||
transactionID *api.TransactionID, excludeFromSync bool,
|
transactionID *api.TransactionID, excludeFromSync bool,
|
||||||
) (pduPosition types.StreamPosition, returnErr error) {
|
) (pduPosition types.StreamPosition, returnErr error) {
|
||||||
returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
pos, err := d.OutputEvents.InsertEvent(
|
pos, err := d.OutputEvents.InsertEvent(
|
||||||
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
|
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
|
||||||
|
@ -304,6 +306,7 @@ func (d *Database) WriteEvent(
|
||||||
return pduPosition, returnErr
|
return pduPosition, returnErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function should always be called within a sqlutil.Writer for safety in SQLite.
|
||||||
func (d *Database) updateRoomState(
|
func (d *Database) updateRoomState(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
removedEventIDs []string,
|
removedEventIDs []string,
|
||||||
|
@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
}
|
}
|
||||||
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
||||||
// that we don't lock the table for writes in more than one place.
|
// that we don't lock the table for writes in more than one place.
|
||||||
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.AddSendToDeviceEvent(
|
return d.AddSendToDeviceEvent(
|
||||||
ctx, txn, userID, deviceID, string(j),
|
ctx, txn, userID, deviceID, string(j),
|
||||||
)
|
)
|
||||||
|
@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates(
|
||||||
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
||||||
// do that for us. It'll guarantee that we don't lock the table for writes in
|
// do that for us. It'll guarantee that we don't lock the table for writes in
|
||||||
// more than one place.
|
// more than one place.
|
||||||
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
// Delete any send-to-device messages marked for deletion.
|
// Delete any send-to-device messages marked for deletion.
|
||||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
||||||
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" +
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectMaxAccountDataIDStmt *sql.Stmt
|
selectMaxAccountDataIDStmt *sql.Stmt
|
||||||
|
@ -61,7 +59,6 @@ type accountDataStatements struct {
|
||||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||||
s := &accountDataStatements{
|
s := &accountDataStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(accountDataSchema)
|
_, err := db.Exec(accountDataSchema)
|
||||||
|
@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
userID, roomID, dataType string,
|
userID, roomID, dataType string,
|
||||||
) (pos types.StreamPosition, err error) {
|
) (pos types.StreamPosition, err error) {
|
||||||
return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||||
return err
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountDataInRange(
|
func (s *accountDataStatements) SelectAccountDataInRange(
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" +
|
||||||
|
|
||||||
type backwardExtremitiesStatements struct {
|
type backwardExtremitiesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertBackwardExtremityStmt *sql.Stmt
|
insertBackwardExtremityStmt *sql.Stmt
|
||||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||||
deleteBackwardExtremityStmt *sql.Stmt
|
deleteBackwardExtremityStmt *sql.Stmt
|
||||||
|
@ -58,7 +56,6 @@ type backwardExtremitiesStatements struct {
|
||||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||||
s := &backwardExtremitiesStatements{
|
s := &backwardExtremitiesStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(backwardExtremitiesSchema)
|
_, err := db.Exec(backwardExtremitiesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
||||||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||||
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
|
@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||||
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
|
@ -98,7 +97,6 @@ type currentRoomStateStatements struct {
|
||||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||||
s := ¤tRoomStateStatements{
|
s := ¤tRoomStateStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(currentRoomStateSchema)
|
_, err := db.Exec(currentRoomStateSchema)
|
||||||
|
@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
||||||
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
||||||
ctx context.Context, txn *sql.Tx, eventID string,
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||||
_, err := stmt.ExecContext(ctx, eventID)
|
_, err := stmt.ExecContext(ctx, eventID)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) UpsertRoomState(
|
func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
|
@ -225,9 +221,8 @@ func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
}
|
}
|
||||||
|
|
||||||
// upsert state event
|
// upsert state event
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||||
_, err := stmt.ExecContext(
|
_, err = stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
event.RoomID(),
|
event.RoomID(),
|
||||||
event.EventID(),
|
event.EventID(),
|
||||||
|
@ -240,7 +235,6 @@ func (s *currentRoomStateStatements) UpsertRoomState(
|
||||||
addedAt,
|
addedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func minOfInts(a, b int) int {
|
func minOfInts(a, b int) int {
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -52,7 +51,6 @@ const insertFilterSQL = "" +
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
selectFilterIDByContentStmt *sql.Stmt
|
selectFilterIDByContentStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
|
@ -65,7 +63,6 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||||
}
|
}
|
||||||
s := &filterStatements{
|
s := &filterStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -114,7 +111,6 @@ func (s *filterStatements) InsertFilter(
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
// Check if filter already exists in the database using its localpart and content
|
// Check if filter already exists in the database using its localpart and content
|
||||||
//
|
//
|
||||||
// This can result in a race condition when two clients try to insert the
|
// This can result in a race condition when two clients try to insert the
|
||||||
|
@ -123,24 +119,22 @@ func (s *filterStatements) InsertFilter(
|
||||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||||
localpart, filterJSON).Scan(&existingFilterID)
|
localpart, filterJSON).Scan(&existingFilterID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
// If it does, return the existing ID
|
// If it does, return the existing ID
|
||||||
if existingFilterID != "" {
|
if existingFilterID != "" {
|
||||||
return nil
|
return existingFilterID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise insert the filter and return the new ID
|
// Otherwise insert the filter and return the new ID
|
||||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
rowid, err := res.LastInsertId()
|
rowid, err := res.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
filterID = fmt.Sprintf("%d", rowid)
|
filterID = fmt.Sprintf("%d", rowid)
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
|
||||||
|
|
||||||
type inviteEventsStatements struct {
|
type inviteEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertInviteEventStmt *sql.Stmt
|
insertInviteEventStmt *sql.Stmt
|
||||||
selectInviteEventsInRangeStmt *sql.Stmt
|
selectInviteEventsInRangeStmt *sql.Stmt
|
||||||
|
@ -70,7 +69,6 @@ type inviteEventsStatements struct {
|
||||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||||
s := &inviteEventsStatements{
|
s := &inviteEventsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(inviteEventsSchema)
|
_, err := db.Exec(inviteEventsSchema)
|
||||||
|
@ -95,20 +93,19 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
|
||||||
func (s *inviteEventsStatements) InsertInviteEvent(
|
func (s *inviteEventsStatements) InsertInviteEvent(
|
||||||
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
|
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
|
||||||
) (streamPos types.StreamPosition, err error) {
|
) (streamPos types.StreamPosition, err error) {
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var headeredJSON []byte
|
var headeredJSON []byte
|
||||||
headeredJSON, err = json.Marshal(inviteEvent)
|
headeredJSON, err = json.Marshal(inviteEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
|
||||||
|
_, err = stmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
streamPos,
|
streamPos,
|
||||||
inviteEvent.RoomID(),
|
inviteEvent.RoomID(),
|
||||||
|
@ -116,24 +113,17 @@ func (s *inviteEventsStatements) InsertInviteEvent(
|
||||||
*inviteEvent.StateKey(),
|
*inviteEvent.StateKey(),
|
||||||
headeredJSON,
|
headeredJSON,
|
||||||
)
|
)
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||||
ctx context.Context, inviteEventID string,
|
ctx context.Context, inviteEventID string,
|
||||||
) (types.StreamPosition, error) {
|
) (types.StreamPosition, error) {
|
||||||
var streamPos types.StreamPosition
|
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
|
||||||
err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return streamPos, err
|
||||||
}
|
}
|
||||||
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||||
return err
|
|
||||||
})
|
|
||||||
return streamPos, err
|
return streamPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
|
||||||
|
|
||||||
type outputRoomEventsStatements struct {
|
type outputRoomEventsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
streamIDStatements *streamIDStatements
|
streamIDStatements *streamIDStatements
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventsStmt *sql.Stmt
|
selectEventsStmt *sql.Stmt
|
||||||
|
@ -120,7 +119,6 @@ type outputRoomEventsStatements struct {
|
||||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||||
s := &outputRoomEventsStatements{
|
s := &outputRoomEventsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
streamIDStatements: streamID,
|
streamIDStatements: streamID,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(outputRoomEventsSchema)
|
_, err := db.Exec(outputRoomEventsSchema)
|
||||||
|
@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
|
||||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
||||||
|
@ -304,15 +300,12 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var streamPos types.StreamPosition
|
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
|
||||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||||
_, ierr := insertStmt.ExecContext(
|
_, err = insertStmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
streamPos,
|
streamPos,
|
||||||
event.RoomID(),
|
event.RoomID(),
|
||||||
|
@ -328,8 +321,6 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||||
excludeFromSync,
|
excludeFromSync,
|
||||||
excludeFromSync,
|
excludeFromSync,
|
||||||
)
|
)
|
||||||
return ierr
|
|
||||||
})
|
|
||||||
return streamPos, err
|
return streamPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
|
||||||
|
|
||||||
type outputRoomEventsTopologyStatements struct {
|
type outputRoomEventsTopologyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertEventInTopologyStmt *sql.Stmt
|
insertEventInTopologyStmt *sql.Stmt
|
||||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||||
|
@ -78,7 +77,6 @@ type outputRoomEventsTopologyStatements struct {
|
||||||
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||||
s := &outputRoomEventsTopologyStatements{
|
s := &outputRoomEventsTopologyStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(outputRoomEventsTopologySchema)
|
_, err := db.Exec(outputRoomEventsTopologySchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||||
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||||
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
||||||
_, err := stmt.ExecContext(
|
_, err = stmt.ExecContext(
|
||||||
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
||||||
)
|
)
|
||||||
return err
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||||
|
|
|
@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = `
|
||||||
|
|
||||||
type sendToDeviceStatements struct {
|
type sendToDeviceStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
insertSendToDeviceMessageStmt *sql.Stmt
|
insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||||
countSendToDeviceMessagesStmt *sql.Stmt
|
countSendToDeviceMessagesStmt *sql.Stmt
|
||||||
|
@ -82,7 +81,6 @@ type sendToDeviceStatements struct {
|
||||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
s := &sendToDeviceStatements{
|
s := &sendToDeviceStatements{
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewTransactionWriter(),
|
|
||||||
}
|
}
|
||||||
_, err := db.Exec(sendToDeviceSchema)
|
_, err := db.Exec(sendToDeviceSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
return
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||||
|
@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
||||||
for k, v := range nids {
|
for k, v := range nids {
|
||||||
params[k+1] = v
|
params[k+1] = v
|
||||||
}
|
}
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err = txn.ExecContext(ctx, query, params...)
|
||||||
_, err := txn.ExecContext(ctx, query, params...)
|
return
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
|
@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
for k, v := range nids {
|
for k, v := range nids {
|
||||||
params[k] = v
|
params[k] = v
|
||||||
}
|
}
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
_, err = txn.ExecContext(ctx, query, params...)
|
||||||
_, err := txn.ExecContext(ctx, query, params...)
|
return
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
|
||||||
|
|
||||||
type streamIDStatements struct {
|
type streamIDStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
|
||||||
increaseStreamIDStmt *sql.Stmt
|
increaseStreamIDStmt *sql.Stmt
|
||||||
selectStreamIDStmt *sql.Stmt
|
selectStreamIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
|
||||||
_, err = db.Exec(streamIDTableSchema)
|
_, err = db.Exec(streamIDTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
|
||||||
if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
|
return
|
||||||
return ierr
|
}
|
||||||
}
|
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||||
if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
|
|
||||||
return serr
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ import (
|
||||||
type SyncServerDatasource struct {
|
type SyncServerDatasource struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
streamID streamIDStatements
|
streamID streamIDStatements
|
||||||
}
|
}
|
||||||
|
@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.writer = sqlutil.NewExclusiveWriter()
|
||||||
if err = d.prepare(); err != nil {
|
if err = d.prepare(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatasource) prepare() (err error) {
|
func (d *SyncServerDatasource) prepare() (err error) {
|
||||||
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
|
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = d.streamID.prepare(d.db); err != nil {
|
if err = d.streamID.prepare(d.db); err != nil {
|
||||||
|
@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
|
||||||
}
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
Invites: invites,
|
Invites: invites,
|
||||||
AccountData: accountData,
|
AccountData: accountData,
|
||||||
OutputEvents: events,
|
OutputEvents: events,
|
||||||
|
@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
|
||||||
Topology: topology,
|
Topology: topology,
|
||||||
Filter: filter,
|
Filter: filter,
|
||||||
SendToDevice: sendToDevice,
|
SendToDevice: sendToDevice,
|
||||||
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
|
|
||||||
EDUCache: cache.New(),
|
EDUCache: cache.New(),
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -35,6 +35,7 @@ import (
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
|
@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
partitions := sqlutil.PartitionOffsetStatements{}
|
d := &Database{
|
||||||
if err = partitions.Prepare(db, "account"); err != nil {
|
serverName: serverName,
|
||||||
|
db: db,
|
||||||
|
writer: sqlutil.NewDummyWriter(),
|
||||||
|
}
|
||||||
|
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
a := accountsStatements{}
|
if err = d.accounts.prepare(db, serverName); err != nil {
|
||||||
if err = a.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p := profilesStatements{}
|
if err = d.profiles.prepare(db); err != nil {
|
||||||
if err = p.prepare(db); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ac := accountDataStatements{}
|
if err = d.accountDatas.prepare(db); err != nil {
|
||||||
if err = ac.prepare(db); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t := threepidStatements{}
|
if err = d.threepids.prepare(db); err != nil {
|
||||||
if err = t.prepare(db); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, partitions, a, p, ac, t, serverName}, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
|
|
@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" +
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataStmt *sql.Stmt
|
selectAccountDataStmt *sql.Stmt
|
||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(accountDataSchema)
|
_, err = db.Exec(accountDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
selectAccountByLocalpartStmt *sql.Stmt
|
selectAccountByLocalpartStmt *sql.Stmt
|
||||||
selectPasswordHashStmt *sql.Stmt
|
selectPasswordHashStmt *sql.Stmt
|
||||||
|
@ -67,9 +67,9 @@ type accountsStatements struct {
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(accountsSchema)
|
_, err = db.Exec(accountsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
insertProfileStmt *sql.Stmt
|
insertProfileStmt *sql.Stmt
|
||||||
selectProfileByLocalpartStmt *sql.Stmt
|
selectProfileByLocalpartStmt *sql.Stmt
|
||||||
setAvatarURLStmt *sql.Stmt
|
setAvatarURLStmt *sql.Stmt
|
||||||
|
@ -61,9 +61,9 @@ type profilesStatements struct {
|
||||||
selectProfilesBySearchStmt *sql.Stmt
|
selectProfilesBySearchStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(profilesSchema)
|
_, err = db.Exec(profilesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -34,6 +34,8 @@ import (
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
|
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
|
@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
partitions := sqlutil.PartitionOffsetStatements{}
|
d := &Database{
|
||||||
if err = partitions.Prepare(db, "account"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
a := accountsStatements{}
|
|
||||||
if err = a.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p := profilesStatements{}
|
|
||||||
if err = p.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ac := accountDataStatements{}
|
|
||||||
if err = ac.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := threepidStatements{}
|
|
||||||
if err = t.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Database{
|
|
||||||
db: db,
|
|
||||||
PartitionOffsetStatements: partitions,
|
|
||||||
accounts: a,
|
|
||||||
profiles: p,
|
|
||||||
accountDatas: ac,
|
|
||||||
threepids: t,
|
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
}, nil
|
db: db,
|
||||||
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
}
|
||||||
|
partitions := sqlutil.PartitionOffsetStatements{}
|
||||||
|
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.profiles.prepare(db, d.writer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.accountDatas.prepare(db, d.writer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.threepids.prepare(db, d.writer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
|
|
@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" +
|
||||||
|
|
||||||
type threepidStatements struct {
|
type threepidStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
selectLocalpartForThreePIDStmt *sql.Stmt
|
selectLocalpartForThreePIDStmt *sql.Stmt
|
||||||
selectThreePIDsForLocalpartStmt *sql.Stmt
|
selectThreePIDsForLocalpartStmt *sql.Stmt
|
||||||
insertThreePIDStmt *sql.Stmt
|
insertThreePIDStmt *sql.Stmt
|
||||||
deleteThreePIDStmt *sql.Stmt
|
deleteThreePIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(threepidSchema)
|
_, err = db.Exec(threepidSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.TransactionWriter
|
writer sqlutil.Writer
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
@ -91,9 +91,9 @@ type devicesStatements struct {
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = sqlutil.NewTransactionWriter()
|
s.writer = writer
|
||||||
_, err = db.Exec(devicesSchema)
|
_, err = db.Exec(devicesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||||
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
sessionID++
|
sessionID++
|
||||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
|
@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice(
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) deleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
|
@ -179,7 +171,6 @@ func (s *devicesStatements) deleteDevices(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, prep)
|
stmt := sqlutil.TxStmt(txn, prep)
|
||||||
params := make([]interface{}, len(devices)+1)
|
params := make([]interface{}, len(devices)+1)
|
||||||
params[0] = localpart
|
params[0] = localpart
|
||||||
|
@ -188,27 +179,22 @@ func (s *devicesStatements) deleteDevices(
|
||||||
}
|
}
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart)
|
_, err := stmt.ExecContext(ctx, localpart)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) updateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||||
return err
|
return err
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
|
|
|
@ -34,6 +34,7 @@ var deviceIDByteLength = 6
|
||||||
// Database represents a device database.
|
// Database represents a device database.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer sqlutil.Writer
|
||||||
devices devicesStatements
|
devices devicesStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
writer := sqlutil.NewExclusiveWriter()
|
||||||
d := devicesStatements{}
|
d := devicesStatements{}
|
||||||
if err = d.prepare(db, serverName); err != nil {
|
if err = d.prepare(db, writer, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, d}, nil
|
return &Database{db, writer, d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
|
||||||
displayName *string,
|
displayName *string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing tokens for this device
|
// Revoke existing tokens for this device
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||||
return err
|
return err
|
||||||
|
@ -138,7 +140,7 @@ func generateDeviceID() (string, error) {
|
||||||
func (d *Database) UpdateDevice(
|
func (d *Database) UpdateDevice(
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -150,7 +152,7 @@ func (d *Database) UpdateDevice(
|
||||||
func (d *Database) RemoveDevice(
|
func (d *Database) RemoveDevice(
|
||||||
ctx context.Context, deviceID, localpart string,
|
ctx context.Context, deviceID, localpart string,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -165,7 +167,7 @@ func (d *Database) RemoveDevice(
|
||||||
func (d *Database) RemoveDevices(
|
func (d *Database) RemoveDevices(
|
||||||
ctx context.Context, localpart string, devices []string,
|
ctx context.Context, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -179,7 +181,7 @@ func (d *Database) RemoveDevices(
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue