diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/clientapi/auth/storage/accounts/sqlite3/filter_table.go index 691ead77..7f1a0c24 100644 --- a/clientapi/auth/storage/accounts/sqlite3/filter_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/filter_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "github.com/matrix-org/gomatrixserverlib" ) @@ -47,14 +48,10 @@ const selectFilterIDByContentSQL = "" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" -const selectLastInsertedFilterIDSQL = "" + - "SELECT id FROM account_filter WHERE rowid = last_insert_rowid()" - type filterStatements struct { - selectFilterStmt *sql.Stmt - selectLastInsertedFilterIDStmt *sql.Stmt - selectFilterIDByContentStmt *sql.Stmt - insertFilterStmt *sql.Stmt + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -65,9 +62,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return } - if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil { - return - } if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { return } @@ -128,12 +122,14 @@ func (s *filterStatements) insertFilter( } // Otherwise insert the filter and return the new ID - if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil { + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { return "", err } - row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx) - if err := row.Scan(&filterID); err != nil { + rowid, err := res.LastInsertId() + if err != nil { return "", err } + filterID = fmt.Sprintf("%d", rowid) return } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 5a0f0bf7..641f8015 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -52,16 +52,18 @@ const selectInviteActiveForUserInRoomSQL = "" + // However the matrix protocol doesn't give us a way to reliably identify the // invites that were retired, so we are forced to retire all of them. const updateInviteRetiredSQL = ` - UPDATE roomserver_invites SET retired = TRUE - WHERE room_nid = $1 AND target_nid = $2 AND NOT retired; - SELECT invite_event_id FROM roomserver_invites - WHERE rowid = last_insert_rowid(); + UPDATE roomserver_invites SET retired = TRUE WHERE room_nid = $1 AND target_nid = $2 AND NOT retired +` + +const selectInvitesAboutToRetireSQL = ` +SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_nid = $2 AND NOT retired ` type inviteStatements struct { insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt + selectInvitesAboutToRetireStmt *sql.Stmt } func (s *inviteStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +76,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) { {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, }.prepare(db) } @@ -102,7 +105,8 @@ func (s *inviteStatements) updateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + // gather all the event IDs we will retire + stmt := txn.Stmt(s.selectInvitesAboutToRetireStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err @@ -110,11 +114,15 @@ func (s *inviteStatements) updateInviteRetired( defer (func() { err = rows.Close() })() for rows.Next() { var inviteEventID string - if err := rows.Scan(&inviteEventID); err != nil { + if err = rows.Scan(&inviteEventID); err != nil { return nil, err } eventIDs = append(eventIDs, inviteEventID) } + + // now retire the invites + stmt = txn.Stmt(s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index a5fd5449..71238b0e 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -103,6 +103,8 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } + defer rows.Close() // nolint: errcheck + for rows.Next() { var alias string if err = rows.Scan(&alias); err != nil { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index ac593546..d75abcee 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -30,7 +30,7 @@ import ( const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_block_nid INTEGER NOT NULL, event_type_nid INTEGER NOT NULL, event_state_key_nid INTEGER NOT NULL, event_nid INTEGER NOT NULL, @@ -43,10 +43,7 @@ const insertStateDataSQL = "" + " VALUES ($1, $2, $3, $4)" const selectNextStateBlockNIDSQL = ` - SELECT COALESCE(( - SELECT seq+1 AS state_block_nid FROM sqlite_sequence - WHERE name = 'roomserver_state_block'), 1 - ) AS state_block_nid +SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block ` // Bulk state lookup by numeric state block ID. @@ -98,11 +95,19 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { func (s *stateBlockStatements) bulkInsertStateData( ctx context.Context, txn *sql.Tx, - stateBlockNID types.StateBlockNID, entries []types.StateEntry, -) error { +) (types.StateBlockNID, error) { + if len(entries) == 0 { + return 0, nil + } + var stateBlockNID types.StateBlockNID + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { - _, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext( + _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( ctx, int64(stateBlockNID), int64(entry.EventTypeNID), @@ -110,20 +115,10 @@ func (s *stateBlockStatements) bulkInsertStateData( int64(entry.EventNID), ) if err != nil { - return err + return 0, err } } - return nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, - txn *sql.Tx, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt) - err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err + return stateBlockNID, nil } func (s *stateBlockStatements) bulkSelectStateBlockEntries( diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e20e8aed..aebb308c 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -54,7 +54,12 @@ func Open(dataSourceName string) (*Database, error) { } //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") - d.db.SetMaxOpenConns(2) + + // FIXME: We are leaking connections somewhere. Setting this to 2 will eventually + // cause the roomserver to be unresponsive to new events because something will + // acquire the global mutex and never unlock it because it is waiting for a connection + // which it will never obtain. + d.db.SetMaxOpenConns(20) if err = d.statements.prepare(d.db); err != nil { return nil, err } @@ -253,12 +258,13 @@ func (d *Database) Events( ) ([]types.Event, error) { var eventJSONs []eventJSONPair var err error - results := make([]types.Event, len(eventNIDs)) + var results []types.Event err = common.WithTransaction(d.db, func(txn *sql.Tx) error { eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil || len(eventJSONs) == 0 { return nil } + results = make([]types.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { result := &results[i] result.EventNID = eventJSON.EventNID @@ -286,13 +292,10 @@ func (d *Database) AddState( err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn) + stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state) if err != nil { return err } - if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil { - return err - } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) @@ -602,8 +605,9 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, -) (types.MembershipUpdater, error) { - txn, err := d.db.Begin() +) (updater types.MembershipUpdater, err error) { + var txn *sql.Tx + txn, err = d.db.Begin() if err != nil { return nil, err } @@ -611,6 +615,18 @@ func (d *Database) MembershipUpdater( defer func() { if !succeeded { txn.Rollback() // nolint: errcheck + } else { + // TODO: We should be holding open this transaction but we cannot have + // multiple write transactions on sqlite. The code will perform additional + // write transactions independent of this one which will consistently cause + // 'database is locked' errors. For now, we'll break up the transaction and + // hope we don't race too catastrophically. Long term, we should be able to + // thread in txn objects where appropriate (either at the interface level or + // bring matrix business logic into the storage layer). + txerr := txn.Commit() + if err == nil && txerr != nil { + err = txerr + } } }() @@ -624,7 +640,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -658,7 +674,8 @@ func (d *Database) membershipUpdaterTxn( } return &membershipUpdater{ - transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + // purposefully set the txn to nil so if we try to use it we panic and fail fast + transaction{ctx, nil}, d, roomNID, targetUserNID, membership, }, nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index ba1b7dc5..5dbef4b7 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -158,6 +158,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ "event": string(msg.Event.JSON()), + "pdupos": pduPos, log.ErrorKey: err, }).Panicf("roomserver output log: write invite failure") return nil diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 3274e66e..71105d0c 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -19,15 +19,13 @@ import ( "context" "database/sql" - "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER PRIMARY KEY, user_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, @@ -43,9 +41,7 @@ const insertAccountDataSQL = "" + const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " AND ( $4 IS NULL OR type IN ($4) )" + - " AND ( $5 IS NULL OR NOT(type IN ($5)) )" + - " ORDER BY id ASC LIMIT $6" + " ORDER BY id ASC" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" @@ -53,8 +49,8 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt - selectAccountDataInRangeStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt + selectAccountDataInRangeStmt *sql.Stmt } func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { @@ -66,10 +62,10 @@ func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { return } - if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { + if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { return } - if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { + if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { return } return @@ -83,8 +79,7 @@ func (s *accountDataStatements) insertAccountData( if err != nil { return } - insertStmt := common.TxStmt(txn, s.insertAccountDataStmt) - _, err = insertStmt.ExecContext(ctx, pos, userID, roomID, dataType) + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) return } @@ -103,14 +98,13 @@ func (s *accountDataStatements) selectAccountDataInRange( oldPos-- } - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos, - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)), - accountDataFilterPart.Limit, - ) + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) if err != nil { return } + defer rows.Close() // nolint: errcheck + + var entries int for rows.Next() { var dataType string @@ -120,22 +114,41 @@ func (s *accountDataStatements) selectAccountDataInRange( return } + // check if we should add this by looking at the filter. + // It would be nice if we could do this in SQL-land, but the mix of variadic + // and positional parameters makes the query annoyingly hard to do, it's easier + // and clearer to do it in Go-land. If there are no filters for [not]types then + // this gets skipped. + for _, includeType := range accountDataFilterPart.Types { + if includeType != dataType { // TODO: wildcard support + continue + } + } + for _, excludeType := range accountDataFilterPart.NotTypes { + if excludeType == dataType { // TODO: wildcard support + continue + } + } + if len(data[roomID]) > 0 { data[roomID] = append(data[roomID], dataType) } else { data[roomID] = []string{dataType} } + entries++ + if entries >= accountDataFilterPart.Limit { + break + } } - return + return data, nil } func (s *accountDataStatements) selectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxAccountDataIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 4ce94666..eb969c95 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "strings" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -88,7 +89,6 @@ type currentRoomStateStatements struct { selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt - selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt } @@ -113,9 +113,6 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDState if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return - } if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { return } @@ -233,8 +230,12 @@ func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) selectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + rows, err := txn.QueryContext(ctx, query, iEventIDs...) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 74dba245..baf8871b 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -26,7 +26,7 @@ import ( const inviteEventsSchema = ` CREATE TABLE IF NOT EXISTS syncapi_invite_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER PRIMARY KEY, event_id TEXT NOT NULL, room_id TEXT NOT NULL, target_user_id TEXT NOT NULL, @@ -39,11 +39,8 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events const insertInviteEventSQL = "" + "INSERT INTO syncapi_invite_events" + - " (room_id, event_id, target_user_id, event_json)" + - " VALUES ($1, $2, $3, $4)" - -const selectLastInsertedInviteEventSQL = "" + - "SELECT id FROM syncapi_invite_events WHERE rowid = last_insert_rowid()" + " (id, room_id, event_id, target_user_id, event_json)" + + " VALUES ($1, $2, $3, $4, $5)" const deleteInviteEventSQL = "" + "DELETE FROM syncapi_invite_events WHERE event_id = $1" @@ -57,12 +54,11 @@ const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" type inviteEventsStatements struct { - streamIDStatements *streamIDStatements - insertInviteEventStmt *sql.Stmt - selectLastInsertedInviteEventStmt *sql.Stmt - selectInviteEventsInRangeStmt *sql.Stmt - deleteInviteEventStmt *sql.Stmt - selectMaxInviteIDStmt *sql.Stmt + streamIDStatements *streamIDStatements + insertInviteEventStmt *sql.Stmt + selectInviteEventsInRangeStmt *sql.Stmt + deleteInviteEventStmt *sql.Stmt + selectMaxInviteIDStmt *sql.Stmt } func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { @@ -74,9 +70,6 @@ func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatement if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { return } - if s.selectLastInsertedInviteEventStmt, err = db.Prepare(selectLastInsertedInviteEventSQL); err != nil { - return - } if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { return } @@ -90,19 +83,16 @@ func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatement } func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (streamPos types.StreamPosition, err error) { - _, err = s.insertInviteEventStmt.ExecContext( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.Event, streamPos types.StreamPosition, +) (err error) { + _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( ctx, + streamPos, inviteEvent.RoomID(), inviteEvent.EventID(), *inviteEvent.StateKey(), inviteEvent.JSON(), ) - if err != nil { - return - } - err = s.selectLastInsertedInviteEventStmt.QueryRowContext(ctx).Scan(&streamPos) return } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 8c01f2ce..4535688d 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -54,9 +54,6 @@ const insertEventSQL = "" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $11" -const selectLastInsertedEventSQL = "" + - "SELECT id FROM syncapi_output_room_events WHERE rowid = last_insert_rowid()" - const selectEventsSQL = "" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" @@ -105,7 +102,6 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt - selectLastInsertedEventStmt *sql.Stmt selectEventsStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt @@ -123,9 +119,6 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDState if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { return } - if s.selectLastInsertedEventStmt, err = db.Prepare(selectLastInsertedEventSQL); err != nil { - return - } if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { return } @@ -270,7 +263,6 @@ func (s *outputRoomEventsStatements) insertEvent( } insertStmt := common.TxStmt(txn, s.insertEventStmt) - selectStmt := common.TxStmt(txn, s.selectLastInsertedEventStmt) _, err = insertStmt.ExecContext( ctx, streamPos, @@ -286,10 +278,6 @@ func (s *outputRoomEventsStatements) insertEvent( txnID, excludeFromSync, ) - if err != nil { - return - } - err = selectStmt.QueryRowContext(ctx).Scan(&streamPos) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 8cfc1884..6ad3419c 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -193,24 +193,20 @@ func (d *SyncServerDatasource) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { - fmt.Println("d.events.insertEvent:", err) return err } pduPosition = pos if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { - fmt.Println("d.topology.insertEventInTopology:", err) return err } if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - fmt.Println("d.handleBackwardExtremities:", err) return err } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. - fmt.Println("nothing to do") return nil } @@ -340,8 +336,12 @@ func (d *SyncServerDatasource) GetEventsInRange( } // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { - return d.syncPositionTx(ctx, nil) +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + tok, err = d.syncPositionTx(ctx, txn) + return err + }) + return } // BackwardExtremitiesForRoom returns the event IDs of all of the backward @@ -380,8 +380,12 @@ func (d *SyncServerDatasource) EventPositionInTopology( } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) +func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (pos types.StreamPosition, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + pos, err = d.syncStreamPositionTx(ctx, txn) + return err + }) + return } func (d *SyncServerDatasource) syncStreamPositionTx( @@ -625,18 +629,15 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( if err != nil { return } - fmt.Println("Joined rooms:", joinedRoomIDs) stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { - fmt.Println("WE'RE ON", roomID) var stateEvents []gomatrixserverlib.Event stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) if err != nil { - fmt.Println("d.roomstate.selectCurrentState:", err) return } //fmt.Println("State events:", stateEvents) @@ -648,7 +649,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom, true, true, ) if err != nil { - fmt.Println("d.events.selectRecentEvents:", err) return } //fmt.Println("Recent stream events:", recentStreamEvents) @@ -658,10 +658,9 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( var backwardTopologyPos types.StreamPosition backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) if err != nil { - fmt.Println("d.topology.selectPositionInTopology:", err) return nil, types.PaginationToken{}, []string{}, err } - fmt.Println("Backward topology position:", backwardTopologyPos) + if backwardTopologyPos-1 <= 0 { backwardTopologyPos = types.StreamPosition(1) } else { @@ -683,7 +682,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { - fmt.Println("d.addInvitesToResponse:", err) return } @@ -744,18 +742,10 @@ func (d *SyncServerDatasource) GetAccountDataInRange( func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - txn, err := d.db.BeginTx(ctx, nil) - if err != nil { - return types.StreamPosition(0), err - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) return } @@ -764,8 +754,15 @@ func (d *SyncServerDatasource) UpsertAccountData( // Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (types.StreamPosition, error) { - return d.invites.insertInviteEvent(ctx, inviteEvent) +) (streamPos types.StreamPosition, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + streamPos, err = d.streamID.nextStreamID(ctx, txn) + if err != nil { + return err + } + return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos) + }) + return } // RetireInviteEvent removes an old invite event from the database.