From 19aa44ecaef70a2be7294e8ad738467da41d1f2e Mon Sep 17 00:00:00 2001 From: Kegsay Date: Tue, 26 May 2020 18:23:39 +0100 Subject: [PATCH] Convert transactions/rooms table to share more code (#1063) * Convert rooms table * Convert transactions table * Convert rooms table and factor out lots of functions * I think you'll be needing this.. --- roomserver/storage/postgres/events_table.go | 5 +- roomserver/storage/postgres/rooms_table.go | 31 +- roomserver/storage/postgres/sql.go | 2 - roomserver/storage/postgres/storage.go | 281 ++-------------- .../storage/postgres/transactions_table.go | 17 +- roomserver/storage/shared/storage.go | 301 ++++++++++++++++++ roomserver/storage/sqlite3/events_table.go | 5 +- roomserver/storage/sqlite3/rooms_table.go | 31 +- roomserver/storage/sqlite3/sql.go | 2 - roomserver/storage/sqlite3/storage.go | 298 ++--------------- .../storage/sqlite3/transactions_table.go | 19 +- roomserver/storage/tables/interface.go | 17 +- 12 files changed, 409 insertions(+), 600 deletions(-) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 9c464946..5a567bf2 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -426,10 +426,9 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index fc64489d..98881390 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -82,12 +83,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -98,7 +100,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { @@ -108,7 +110,7 @@ func (s *roomStatements) insertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 @@ -117,8 +119,8 @@ func (s *roomStatements) selectRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( - ctx context.Context, roomNID types.RoomNID, +func (s *roomStatements) SelectLatestEventNIDs( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 @@ -134,7 +136,7 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array @@ -152,7 +154,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -171,7 +173,7 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion @@ -183,12 +185,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index 964dabbb..914f269c 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -38,14 +38,12 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.roomStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, - s.transactionStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 0022c617..d44da858 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -18,14 +18,12 @@ package postgres import ( "context" "database/sql" - "encoding/json" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" // Import the postgres database driver. _ "github.com/lib/pq" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -40,6 +38,8 @@ type Database struct { eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys eventJSON tables.EventJSON + rooms tables.Rooms + transactions tables.Transactions db *sql.DB } @@ -69,164 +69,43 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.rooms, err = NewPostgresRoomsTable(d.db) + if err != nil { + return nil, err + } + d.transactions, err = NewPostgresTransactionsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ + DB: d.db, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, EventJSONTable: d.eventJSON, EventsTable: d.events, + RoomsTable: d.rooms, + TransactionsTable: d.transactions, } return &d, nil } -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return 0, types.StateAtEvent{}, err - } - - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID(), roomVersion); err != nil { - return 0, types.StateAtEvent{}, err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { - return 0, types.StateAtEvent{}, err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if eventNID, stateNID, err = d.events.InsertEvent( - ctx, - nil, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID()) - } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) + roomNID, err := d.rooms.SelectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) + roomNID, err = d.rooms.InsertRoomNID(ctx, txn, roomID, roomVersion) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID) } } return roomNID, err } -func (d *Database) assignEventTypeNID( - ctx context.Context, eventType string, -) (eventTypeNID types.EventTypeNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - // Check if we already have a numeric ID in the database. - eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) - } - } - return err - }) - return eventTypeNID, err -} - func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { @@ -243,38 +122,6 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) - if err != nil { - return nil, err - } - results := make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) - if err != nil { - return nil, err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, nil, roomNID) - if err != nil { - return nil, err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil, err - } - } - return results, nil -} - // AddState implements input.EventDatabase func (d *Database) AddState( ctx context.Context, @@ -319,7 +166,7 @@ func (d *Database) GetLatestEventsForUpdate( return nil, err } eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + d.rooms.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -342,18 +189,6 @@ func (d *Database) GetLatestEventsForUpdate( }, nil } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - type roomRecentEventsUpdater struct { transaction d *Database @@ -415,7 +250,7 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.rooms.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) } // HasEventBeenSent implements types.RoomRecentEventsUpdater @@ -432,55 +267,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) - if err == sql.ErrNoRows { - return 0, nil - } - return roomNID, err -} - -// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB -func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - roomNID, err = d.RoomNID(ctx, roomID) - if err != nil { - return - } - latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return - } - if len(latestEvents) == 0 { - roomNID = 0 - return - } - return -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return err - } - references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) - return -} - // GetInvitesForUser implements query.RoomserverQueryAPIDatabase func (d *Database) GetInvitesForUser( ctx context.Context, @@ -733,37 +519,6 @@ func (d *Database) GetMembershipEventNIDsForRoom( return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) } -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 87c1caca..7f7ef76a 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -18,6 +18,8 @@ package postgres import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -51,20 +53,21 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.prepare(db) } -func (s *transactionStatements) insertTransaction( - ctx context.Context, +func (s *transactionStatements) InsertTransaction( + ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, @@ -76,7 +79,7 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( +func (s *transactionStatements) SelectTransactionEventID( ctx context.Context, transactionID string, sessionID int64, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a3b2c2e2..814b6e81 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -2,16 +2,24 @@ package shared import ( "context" + "database/sql" + "encoding/json" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) type Database struct { + DB *sql.DB EventsTable tables.Events EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + TransactionsTable tables.Transactions } // EventTypeNIDs implements state.RoomStateDatabase @@ -77,3 +85,296 @@ func (d *Database) EventIDs( ) (map[types.EventNID]string, error) { return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) } + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB +func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + roomNID, err = d.RoomNID(ctx, roomID) + if err != nil { + return + } + latestEvents, _, err := d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) + if err != nil { + return + } + if len(latestEvents) == 0 { + roomNID = 0 + return + } + return +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return +} + +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomID string, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomID( + ctx, nil, roomID, + ) +} + +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.RoomsTable.SelectRoomVersionForRoomNID( + ctx, roomNID, + ) +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion + result := &results[i] + result.EventNID = eventJSON.EventNID + roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID) + if err != nil { + return nil, err + } + roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } + result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( + eventJSON.EventJSON, false, roomVersion, + ) + if err != nil { + return nil, err + } + } + return results, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.TransactionsTable.InsertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + // TODO: Here we should aim to have two different code paths for new rooms + // vs existing ones. + + // Get the default room version. If the client doesn't supply a room_version + // then we will use our configured default to create the room. + // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom + // Note that the below logic depends on the m.room.create event being the + // first event that is persisted to the database when creating or joining a + // room. + var roomVersion gomatrixserverlib.RoomVersion + if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { + return err + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } + } + + if eventNID, stateNID, err = d.EventsTable.InsertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } + } + + if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, + roomID string, roomVersion gomatrixserverlib.RoomVersion, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.RoomsTable.InsertRoomNID(ctx, txn, roomID, roomVersion) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.EventTypesTable.InsertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) + } + } + return +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.EventStateKeysTable.InsertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( + gomatrixserverlib.RoomVersion, error, +) { + var err error + var roomVersion gomatrixserverlib.RoomVersion + // Look for m.room.create events. + if event.Type() != gomatrixserverlib.MRoomCreate { + return gomatrixserverlib.RoomVersion(""), nil + } + roomVersion = gomatrixserverlib.RoomVersionV1 + var createContent gomatrixserverlib.CreateContent + // The m.room.create event contains an optional "room_version" key in + // the event content, so we need to unmarshal that first. + if err = json.Unmarshal(event.Content(), &createContent); err != nil { + return gomatrixserverlib.RoomVersion(""), err + } + // A room version was specified in the event content? + if createContent.RoomVersion != nil { + roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) + } + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index a41a8737..247faa68 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -469,10 +469,9 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDForEventNID( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, + ctx context.Context, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { - selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) - err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) + err = s.selectRoomNIDForEventNIDStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&roomNID) return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index ea949d1e..75b8fec9 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -71,12 +72,13 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomsSchema) +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + _, err := db.Exec(roomsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -87,20 +89,20 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomStatements) insertRoomNID( +func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (types.RoomNID, error) { var err error insertStmt := internal.TxStmt(txn, s.insertRoomNIDStmt) if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil { - return s.selectRoomNID(ctx, txn, roomID) + return s.SelectRoomNID(ctx, txn, roomID) } else { return types.RoomNID(0), err } } -func (s *roomStatements) selectRoomNID( +func (s *roomStatements) SelectRoomNID( ctx context.Context, txn *sql.Tx, roomID string, ) (types.RoomNID, error) { var roomNID int64 @@ -109,7 +111,7 @@ func (s *roomStatements) selectRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs( +func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID @@ -126,7 +128,7 @@ func (s *roomStatements) selectLatestEventNIDs( return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate( +func (s *roomStatements) SelectLatestEventsNIDsForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var eventNIDs []types.EventNID @@ -144,7 +146,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate( return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs( +func (s *roomStatements) UpdateLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -163,7 +165,7 @@ func (s *roomStatements) updateLatestEventNIDs( return err } -func (s *roomStatements) selectRoomVersionForRoomID( +func (s *roomStatements) SelectRoomVersionForRoomID( ctx context.Context, txn *sql.Tx, roomID string, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion @@ -175,12 +177,11 @@ func (s *roomStatements) selectRoomVersionForRoomID( return roomVersion, err } -func (s *roomStatements) selectRoomVersionForRoomNID( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +func (s *roomStatements) SelectRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { var roomVersion gomatrixserverlib.RoomVersion - stmt := internal.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) - err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + err := s.selectRoomVersionForRoomNIDStmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) if err == sql.ErrNoRows { return roomVersion, errors.New("room not found") } diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index ac44b398..fe899174 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -38,14 +38,12 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.roomStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, s.roomAliasesStatements.prepare, s.inviteStatements.prepare, s.membershipStatements.prepare, - s.transactionStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 05bad7fb..a0a1b568 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,14 +18,12 @@ package sqlite3 import ( "context" "database/sql" - "encoding/json" "errors" "net/url" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,6 +39,8 @@ type Database struct { eventJSON tables.EventJSON eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys + rooms tables.Rooms + transactions tables.Transactions db *sql.DB } @@ -89,163 +89,38 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.rooms, err = NewSqliteRoomsTable(d.db) + if err != nil { + return nil, err + } + d.transactions, err = NewSqliteTransactionsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ + DB: d.db, EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, EventJSONTable: d.eventJSON, + RoomsTable: d.rooms, + TransactionsTable: d.transactions, } return &d, nil } -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txn, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return err - } - } - - // TODO: Here we should aim to have two different code paths for new rooms - // vs existing ones. - - // Get the default room version. If the client doesn't supply a room_version - // then we will use our configured default to create the room. - // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom - // Note that the below logic depends on the m.room.create event being the - // first event that is persisted to the database when creating or joining a - // room. - var roomVersion gomatrixserverlib.RoomVersion - if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return err - } - - if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { - return err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return err - } - } - - if eventNID, stateNID, err = d.events.InsertEvent( - ctx, - txn, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.events.SelectEvent(ctx, txn, event.EventID()) - } - if err != nil { - return err - } - } - - if err = d.eventJSON.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { - return err - } - - return nil - }) - if err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( - gomatrixserverlib.RoomVersion, error, -) { - var err error - var roomVersion gomatrixserverlib.RoomVersion - // Look for m.room.create events. - if event.Type() != gomatrixserverlib.MRoomCreate { - return gomatrixserverlib.RoomVersion(""), nil - } - roomVersion = gomatrixserverlib.RoomVersionV1 - var createContent gomatrixserverlib.CreateContent - // The m.room.create event contains an optional "room_version" key in - // the event content, so we need to unmarshal that first. - if err = json.Unmarshal(event.Content(), &createContent); err != nil { - return gomatrixserverlib.RoomVersion(""), err - } - // A room version was specified in the event content? - if createContent.RoomVersion != nil { - roomVersion = gomatrixserverlib.RoomVersion(*createContent.RoomVersion) - } - return roomVersion, err -} - func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { // Check if we already have a numeric ID in the database. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID, roomVersion) + roomNID, err = d.rooms.InsertRoomNID(ctx, txn, roomID, roomVersion) if err == nil { // Now get the numeric ID back out of the database - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, txn *sql.Tx, eventType string, -) (eventTypeNID types.EventTypeNID, err error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.eventTypes.InsertEventTypeNID(ctx, txn, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.eventTypes.SelectEventTypeNID(ctx, txn, eventType) + roomNID, err = d.rooms.SelectRoomNID(ctx, txn, roomID) } } return @@ -267,47 +142,6 @@ func (d *Database) assignStateKeyNID( return } -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - var eventJSONs []tables.EventJSONPair - var err error - var results []types.Event - eventJSONs, err = d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) - if err != nil || len(eventJSONs) == 0 { - return nil, nil - } - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - results = make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion - result := &results[i] - result.EventNID = eventJSON.EventNID - roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) - if err != nil { - return err - } - roomVersion, err = d.statements.selectRoomVersionForRoomNID(ctx, txn, roomNID) - if err != nil { - return err - } - result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( - eventJSON.EventJSON, false, roomVersion, - ) - if err != nil { - return nil - } - } - return nil - }) - if err != nil { - return []types.Event{}, err - } - return results, nil -} - // AddState implements input.EventDatabase func (d *Database) AddState( ctx context.Context, @@ -364,7 +198,7 @@ func (d *Database) GetLatestEventsForUpdate( return nil, err } eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + d.rooms.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -396,18 +230,6 @@ func (d *Database) GetLatestEventsForUpdate( }, nil } -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - type roomRecentEventsUpdater struct { transaction d *Database @@ -478,7 +300,7 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.rooms.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) }) return err } @@ -508,59 +330,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return } -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - roomNID = 0 - err = nil - } - return err - }) - return -} - -// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB -func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { - roomNID, err = d.RoomNID(ctx, roomID) - if err != nil { - return - } - latestEvents, _, err := d.statements.selectLatestEventNIDs(ctx, nil, roomNID) - if err != nil { - return - } - if len(latestEvents) == 0 { - roomNID = 0 - return - } - return -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) - if err != nil { - return err - } - references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) - return -} - // GetInvitesForUser implements query.RoomserverQueryAPIDatabase func (d *Database) GetInvitesForUser( ctx context.Context, @@ -844,37 +613,6 @@ func (d *Database) GetMembershipEventNIDsForRoom( return } -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - return d.statements.selectRoomVersionForRoomNID( - ctx, nil, roomNID, - ) -} - type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index d22c7384..37ea15c0 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" ) const transactionsSchema = ` @@ -46,19 +47,20 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func (s *transactionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(transactionsSchema) +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} + _, err := db.Exec(transactionsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, }.prepare(db) } -func (s *transactionStatements) insertTransaction( +func (s *transactionStatements) InsertTransaction( ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, @@ -72,14 +74,13 @@ func (s *transactionStatements) insertTransaction( return } -func (s *transactionStatements) selectTransactionEventID( - ctx context.Context, txn *sql.Tx, +func (s *transactionStatements) SelectTransactionEventID( + ctx context.Context, transactionID string, sessionID int64, userID string, ) (eventID string, err error) { - stmt := internal.TxStmt(txn, s.selectTransactionEventIDStmt) - err = stmt.QueryRowContext( + err = s.selectTransactionEventIDStmt.QueryRowContext( ctx, transactionID, sessionID, userID, ).Scan(&eventID) return diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 78ddc5fe..e913de6b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -53,5 +53,20 @@ type Events interface { // If an event ID is not in the database then it is omitted from the map. BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) - SelectRoomNIDForEventNID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (roomNID types.RoomNID, err error) + SelectRoomNIDForEventNID(ctx context.Context, eventNID types.EventNID) (roomNID types.RoomNID, err error) +} + +type Rooms interface { + InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) + SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) + SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) + UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error + SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) + SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) +} + +type Transactions interface { + InsertTransaction(ctx context.Context, txn *sql.Tx, transactionID string, sessionID int64, userID string, eventID string) error + SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) }