diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 661c4472..a3262926 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,32 +59,28 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventJSONSchema) +func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( +func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +) ([]tables.EventJSONPair, error) { rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -94,7 +91,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c28fa8e6..9c464946 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,13 +137,14 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -160,8 +162,9 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -179,8 +182,8 @@ func (s *eventStatements) insertEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } -func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, +func (s *eventStatements) SelectEvent( + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 @@ -190,7 +193,7 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( +func (s *eventStatements) BulkSelectStateEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -233,7 +236,7 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( +func (s *eventStatements) BulkSelectStateAtEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -270,14 +273,14 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, nil } -func (s *eventStatements) updateEventState( +func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { stmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) @@ -285,13 +288,13 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { stmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { stmt := internal.TxStmt(txn, s.selectEventIDStmt) @@ -299,7 +302,7 @@ func (s *eventStatements) selectEventID( return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { stmt := internal.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) @@ -341,8 +344,8 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( - ctx context.Context, eventNIDs []types.EventNID, +func (s *eventStatements) BulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { @@ -367,7 +370,7 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -394,7 +397,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err @@ -412,7 +415,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str return results, rows.Err() } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 stmt := s.selectMaxEventDepthStmt err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) @@ -422,7 +425,7 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []t return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( +func (s *eventStatements) SelectRoomNIDForEventNID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index e41c5a39..964dabbb 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -39,8 +39,6 @@ func (s *statements) prepare(db *sql.DB) error { for _, prepare := range []func(db *sql.DB) error{ s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6fcceced..0022c617 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -36,8 +36,10 @@ import ( type Database struct { shared.Database statements statements + events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys + eventJSON tables.EventJSON db *sql.DB } @@ -59,9 +61,19 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.eventJSON, err = NewPostgresEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewPostgresEventsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, + EventsTable: d.events, } return &d, nil } @@ -120,8 +132,9 @@ func (d *Database) StoreEvent( } } - if eventNID, stateNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.events.InsertEvent( ctx, + nil, roomNID, eventTypeNID, eventStateKeyNID, @@ -132,14 +145,14 @@ func (d *Database) StoreEvent( ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID()) } if err != nil { return 0, types.StateAtEvent{}, err } } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil { return 0, types.StateAtEvent{}, err } @@ -230,25 +243,11 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) if err != nil { return nil, err } @@ -258,7 +257,7 @@ func (d *Database) Events( var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) + roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) if err != nil { return nil, err } @@ -297,20 +296,6 @@ func (d *Database) AddState( return d.statements.insertState(ctx, roomNID, stateBlockNIDs) } -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, @@ -325,21 +310,6 @@ func (d *Database) StateEntries( return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -354,14 +324,14 @@ func (d *Database) GetLatestEventsForUpdate( txn.Rollback() // nolint: errcheck return nil, err } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + stateAndRefs, err := d.events.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) if err != nil { txn.Rollback() // nolint: errcheck return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + lastEventIDSent, err = d.events.SelectEventID(ctx, txn, lastEventNIDSent) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -450,12 +420,12 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.SelectEventSentToOutput(u.ctx, u.txn, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) } func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { @@ -491,20 +461,24 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return err + } + references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return } // GetInvitesForUser implements query.RoomserverQueryAPIDatabase diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 7a8da865..a3b2c2e2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -8,6 +8,8 @@ import ( ) type Database struct { + EventsTable tables.Events + EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys } @@ -32,3 +34,46 @@ func (d *Database) EventStateKeyNIDs( ) (map[string]types.EventStateKeyNID, error) { return d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) } + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return stateNID, err +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index fbf35e71..34b067cb 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -51,40 +52,36 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} s.db = db - _, err = db.Exec(eventJSONSchema) + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON( +func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := internal.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +func (s *eventJSONStatements) BulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) - rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { return nil, err } @@ -94,7 +91,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 55113495..a41a8737 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,14 +112,15 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} s.db = db - _, err = db.Exec(eventsSchema) + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -135,7 +137,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, @@ -145,7 +147,7 @@ func (s *eventStatements) insertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, -) (types.EventNID, error) { +) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID insertStmt := internal.TxStmt(txn, s.insertEventStmt) result, err := insertStmt.ExecContext( @@ -153,17 +155,17 @@ func (s *eventStatements) insertEvent( eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ) if err != nil { - return 0, err + return 0, 0, err } modified, err := result.RowsAffected() if modified == 0 && err == nil { - return 0, sql.ErrNoRows + return 0, 0, sql.ErrNoRows } eventNID, err := result.LastInsertId() - return types.EventNID(eventNID), err + return types.EventNID(eventNID), 0, err } -func (s *eventStatements) selectEvent( +func (s *eventStatements) SelectEvent( ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 @@ -175,8 +177,8 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -184,13 +186,12 @@ func (s *eventStatements) bulkSelectStateEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -228,8 +229,8 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( - ctx context.Context, txn *sql.Tx, eventIDs []string, +func (s *eventStatements) BulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -237,13 +238,11 @@ func (s *eventStatements) bulkSelectStateAtEventByID( iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -275,15 +274,14 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, err } -func (s *eventStatements) updateEventState( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, +func (s *eventStatements) UpdateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - updateStmt := internal.TxStmt(txn, s.updateEventStateStmt) - _, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { selectStmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) @@ -294,14 +292,14 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { updateStmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { selectStmt := internal.TxStmt(txn, s.selectEventIDStmt) @@ -309,7 +307,7 @@ func (s *eventStatements) selectEventID( return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { /////////////// @@ -355,7 +353,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( +func (s *eventStatements) BulkSelectEventReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { /////////////// @@ -391,20 +389,19 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", internal.QueryVariadic(len(iEventNIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) if err != nil { return nil, err @@ -428,20 +425,18 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", internal.QueryVariadic(len(iEventIDs)), 1) - selectPrep, err := txn.Prepare(selectOrig) + selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } /////////////// - - selectStmt := internal.TxStmt(txn, selectPrep) rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { return nil, err @@ -459,7 +454,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e return results, nil } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 iEventIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { @@ -473,7 +468,7 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( +func (s *eventStatements) SelectRoomNIDForEventNID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index bb3318b2..ac44b398 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -39,8 +39,6 @@ func (s *statements) prepare(db *sql.DB) error { for _, prepare := range []func(db *sql.DB) error{ s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index b9157e3a..05bad7fb 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -37,6 +37,8 @@ import ( type Database struct { shared.Database statements statements + events tables.Events + eventJSON tables.EventJSON eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys db *sql.DB @@ -79,9 +81,19 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.eventJSON, err = NewSqliteEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewSqliteEventsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ + EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, } return &d, nil } @@ -141,7 +153,7 @@ func (d *Database) StoreEvent( } } - if eventNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.events.InsertEvent( ctx, txn, roomNID, @@ -154,14 +166,14 @@ func (d *Database) StoreEvent( ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) + eventNID, stateNID, err = d.events.SelectEvent(ctx, txn, event.EventID()) } if err != nil { return err } } - if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return err } @@ -255,47 +267,25 @@ func (d *Database) assignStateKeyNID( return } -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateEntry, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) - return err - }) - return -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (out map[string]types.EventNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) - return err - }) - return -} - // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - var eventJSONs []eventJSONPair + var eventJSONs []tables.EventJSONPair var err error var results []types.Event + eventJSONs, err = d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil || len(eventJSONs) == 0 { + return nil, nil + } err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) - if err != nil || len(eventJSONs) == 0 { - return nil - } results = make([]types.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { var roomNID types.RoomNID var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) + roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, txn, eventJSON.EventNID) if err != nil { return err } @@ -343,27 +333,6 @@ func (d *Database) AddState( return } -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - e := internal.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.statements.updateEventState(ctx, txn, eventNID, stateNID) - }) - return e -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) (se []types.StateAtEvent, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) - return err - }) - return -} - // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, @@ -386,28 +355,6 @@ func (d *Database) StateEntries( return } -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (stateNID types.StateSnapshotNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) - return err - }) - return -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (out map[types.EventNID]string, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) - return err - }) - return -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -422,14 +369,14 @@ func (d *Database) GetLatestEventsForUpdate( txn.Rollback() // nolint: errcheck return nil, err } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + stateAndRefs, err := d.events.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) if err != nil { txn.Rollback() // nolint: errcheck return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + lastEventIDSent, err = d.events.SelectEventID(ctx, txn, lastEventNIDSent) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -539,7 +486,7 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { err = internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) + res, err = u.d.events.SelectEventSentToOutput(u.ctx, txn, eventNID) return err }) return @@ -548,7 +495,7 @@ func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { err := internal.WithTransaction(u.d.db, func(txn *sql.Tx) error { - return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) + return u.d.events.UpdateEventSentToOutput(u.ctx, txn, eventNID) }) return err } @@ -601,11 +548,11 @@ func (d *Database) LatestEventIDs( if err != nil { return err } - references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) + references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) if err != nil { return err } - depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) + depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) if err != nil { return err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index d607865d..78ddc5fe 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -5,8 +5,19 @@ import ( "database/sql" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) +type EventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +type EventJSON interface { + InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error + BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) +} + type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) @@ -19,3 +30,28 @@ type EventStateKeys interface { BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } + +type Events interface { + InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) + SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) + // bulkSelectStateEventByID lookups a list of state events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError + BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. + // If any of the requested events are missing from the database it returns a types.MissingEventError. + // If we do not have the state for any of the requested events it returns a types.MissingEventError. + BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + UpdateEventState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) + UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error + SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) + BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) + BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) + // BulkSelectEventID returns a map from numeric event ID to string event ID. + BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. + // If an event ID is not in the database then it is omitted from the map. + BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) + SelectRoomNIDForEventNID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (roomNID types.RoomNID, err error) +}