diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 1e0232d2..52e6a96b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -63,29 +63,80 @@ type Database interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) // Look up a room version from the room NID. GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) + // Stores a matrix room event in the database StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) + // Look up the state entries for a list of string event IDs + // Returns an error if the there is an error talking to the database + // Returns a types.MissingEventError if the event IDs aren't in the database. StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + // Look up the string event state keys for a list of numeric event state keys + // Returns an error if there was a problem talking to the database. EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + // Look up the numeric IDs for a list of events. + // Returns an error if there was a problem talking to the database. EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + // Set the state at an event. FIXME TODO: "at" SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + // Lookup the event IDs for a batch of event numeric IDs. + // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // Look up the latest events in a room in preparation for an update. + // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. + // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // If this returns an error then no further action is required. GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) + // Look up event ID by transaction's info. + // This is used to determine if the room event is processed/processing already. + // Returns an empty string if no such event exists. GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) + // Look up the numeric ID for the room. + // Returns 0 if the room doesn't exists. + // Returns an error if there was a problem talking to the database. RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) // RoomNIDExcludingStubs is a special variation of RoomNID that will return 0 as if the room // does not exist if the room has no latest events. This can happen when we've received an // invite over federation for a room that we don't know anything else about yet. RoomNIDExcludingStubs(ctx context.Context, roomID string) (types.RoomNID, error) + // Look up event references for the latest events in the room and the current state snapshot. + // Returns the latest events, the current state and the maximum depth of the latest events plus 1. + // Returns an error if there was a problem talking to the database. LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + // Look up the active invites targeting a user in a room and return the + // numeric state key IDs for the user IDs who sent them. + // Returns an error if there was a problem talking to the database. GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) + // Save a given room alias with the room ID it refers to. + // Returns an error if there was a problem talking to the database. SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error + // Look up the room ID a given alias refers to. + // Returns an error if there was a problem talking to the database. GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + // Look up all aliases referring to a given room ID. + // Returns an error if there was a problem talking to the database. GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + // Get the user ID of the creator of an alias. + // Returns an error if there was a problem talking to the database. GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) + // Remove a given room alias. + // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error + // Build a membership updater for the target user in a room. MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) + // Lookup the membership of a given user in a given room. + // Returns the numeric ID of the latest membership event sent from this user + // in this room, along a boolean set to true if the user is still in this room, + // false if not. + // Returns an error if there was a problem talking to the database. GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + // Lookup the membership event numeric IDs for all user that are or have + // been members of a given room. Only lookup events of "join" membership if + // joinOnly is set to true. + // Returns an error if there was a problem talking to the database. GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) + // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was + // not found. + // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) + // Look up the room version for a given room. GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index a3262926..3bce9f08 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -65,10 +66,10 @@ func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventJSONStatements) InsertEventJSON( diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 81b9b06e..b7e2cb6b 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -81,12 +82,12 @@ func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventStateKeyStatements) InsertEventStateKeyNID( diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index aaba614a..f8a7bff1 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -106,11 +107,11 @@ func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventTypeStatements) InsertEventTypeNID( diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 5a567bf2..d39e9511 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -144,7 +145,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -159,7 +160,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventStatements) InsertEvent( diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index f0fb919e..55a5bc7d 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -79,20 +81,21 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, @@ -111,7 +114,7 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { @@ -133,8 +136,8 @@ func (s *inviteStatements) updateInviteRetired( return eventIDs, rows.Err() } -// selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index f290a05f..c635ce28 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -20,17 +20,11 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` -- The membership table is used to coordinate updates between the invite table -- and the room state tables. @@ -115,13 +109,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -130,10 +125,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -143,27 +138,27 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( +func (s *membershipStatements) SelectMembershipsFromRoom( ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -188,9 +183,9 @@ func (s *membershipStatements) selectMembershipsFromRoom( return eventNIDs, rows.Err() } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, localOnly bool, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows var stmt *sql.Stmt @@ -215,10 +210,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return eventNIDs, rows.Err() } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { _, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext( diff --git a/roomserver/storage/postgres/prepare.go b/roomserver/storage/postgres/prepare.go deleted file mode 100644 index 70b6e516..00000000 --- a/roomserver/storage/postgres/prepare.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string -} - -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { - for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { - return - } - } - return -} diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index b3d32c95..ff030696 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -71,10 +72,10 @@ func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } func (s *previousEventStatements) InsertPreviousEvent( diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index f869cf4f..85042c54 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -66,13 +67,13 @@ func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } func (s *roomAliasesStatements) InsertRoomAlias( diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 98881390..5a353ed1 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -89,7 +90,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -97,7 +98,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *roomStatements) InsertRoomNID( diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go deleted file mode 100644 index eb626dd8..00000000 --- a/roomserver/storage/postgres/sql.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.inviteStatements.prepare, - s.membershipStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index d1aaaa00..dbe21d98 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -95,12 +96,12 @@ func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 8971292f..0f8f1c51 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -72,10 +73,10 @@ func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } func (s *stateSnapshotStatements) InsertState( diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 03cfb7f0..53a58076 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -33,7 +33,6 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database - statements statements events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys @@ -41,6 +40,8 @@ type Database struct { rooms tables.Rooms transactions tables.Transactions prevEvents tables.PreviousEvents + invites tables.Invites + membership tables.Membership db *sql.DB } @@ -52,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db) if err != nil { return nil, err @@ -95,6 +93,14 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.invites, err = NewPostgresInvitesTable(d.db) + if err != nil { + return nil, err + } + d.membership, err = NewPostgresMembershipTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -107,6 +113,8 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, StateSnapshotTable: stateSnapshot, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, + InvitesTable: d.invites, + MembershipTable: d.membership, } return &d, nil } @@ -254,15 +262,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, @@ -303,7 +302,7 @@ type membershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership membershipState + membership tables.MembershipState } func (d *Database) membershipUpdaterTxn( @@ -314,11 +313,11 @@ func (d *Database) membershipUpdaterTxn( targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -330,17 +329,17 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite + return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin + return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan + return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater @@ -349,15 +348,15 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er if err != nil { return false, err } - inserted, err := u.d.statements.insertInviteEvent( + inserted, err := u.d.invites.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { return false, err } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + if u.membership != tables.MembershipStateInvite { + if err = u.d.membership.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { return false, err } @@ -376,7 +375,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -390,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd return nil, err } - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], + tables.MembershipStateJoin, nIDs[eventID], ); err != nil { return nil, err } @@ -408,7 +407,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s if err != nil { return nil, err } - inviteEventIDs, err := u.d.statements.updateInviteRetired( + inviteEventIDs, err := u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -421,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s return nil, err } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], + tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return nil, err } @@ -432,42 +431,6 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s return inviteEventIDs, nil } -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } - - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } - - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, localOnly, - ) - } - - return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) -} - type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 7f7ef76a..5e59ae16 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -60,10 +61,10 @@ func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *transactionStatements) InsertTransaction( diff --git a/roomserver/storage/sqlite3/prepare.go b/roomserver/storage/shared/prepare.go similarity index 66% rename from roomserver/storage/sqlite3/prepare.go rename to roomserver/storage/shared/prepare.go index 482dfa2b..1b3497fd 100644 --- a/roomserver/storage/sqlite3/prepare.go +++ b/roomserver/storage/shared/prepare.go @@ -13,22 +13,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sqlite3 +package shared import ( "database/sql" ) -// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. -type statementList []struct { - statement **sql.Stmt - sql string +// StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type StatementList []struct { + Statement **sql.Stmt + SQL string } -// prepare the SQL for each statement in the list and assign the result to the prepared statement. -func (s statementList) prepare(db *sql.DB) (err error) { +// Prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s StatementList) Prepare(db *sql.DB) (err error) { for _, statement := range s { - if *statement.statement, err = db.Prepare(statement.sql); err != nil { + if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { return } } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 29c6f73e..17bcd696 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -24,37 +24,34 @@ type Database struct { StateBlockTable tables.StateBlock RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership } -// EventTypeNIDs implements state.RoomStateDatabase func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { return d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) } -// EventStateKeys implements query.RoomserverQueryAPIDatabase func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) } -// EventStateKeyNIDs implements state.RoomStateDatabase func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { return d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) } -// StateEntriesForEventIDs implements input.EventDatabase func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) } -// StateEntriesForTuples implements state.RoomStateDatabase func (d *Database) StateEntriesForTuples( ctx context.Context, stateBlockNIDs []types.StateBlockNID, @@ -65,7 +62,6 @@ func (d *Database) StateEntriesForTuples( ) } -// AddState implements input.EventDatabase func (d *Database) AddState( ctx context.Context, roomNID types.RoomNID, @@ -90,28 +86,24 @@ func (d *Database) AddState( return } -// EventNIDs implements query.RoomserverQueryAPIDatabase func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) } -// SetState implements input.EventDatabase func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) } -// StateAtEventIDs implements input.EventDatabase func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) } -// SnapshotNIDFromEventID implements state.RoomStateDatabase func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { @@ -119,14 +111,12 @@ func (d *Database) SnapshotNIDFromEventID( return stateNID, err } -// EventIDs implements input.RoomEventDatabase func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) } -// EventsFromIDs implements query.RoomserverQueryAPIEventDB func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { nidMap, err := d.EventNIDs(ctx, eventIDs) if err != nil { @@ -141,7 +131,6 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type return d.Events(ctx, nids) } -// RoomNID implements query.RoomserverQueryAPIDB func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID) if err == sql.ErrNoRows { @@ -150,7 +139,6 @@ func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, e return roomNID, err } -// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { roomNID, err = d.RoomNID(ctx, roomID) if err != nil { @@ -167,7 +155,6 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro return } -// LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, ) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { @@ -190,14 +177,12 @@ func (d *Database) LatestEventIDs( return } -// StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) } -// StateEntries implements state.RoomStateDatabase func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { @@ -220,34 +205,70 @@ func (d *Database) GetRoomVersionForRoomNID( ) } -// SetRoomAlias implements alias.RoomserverAliasAPIDB func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) } -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) } -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) } -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) } -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) } -// Events implements input.EventDatabase +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.MembershipTable.SelectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil +} + +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.MembershipTable.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + } + + return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) +} + +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { @@ -279,7 +300,6 @@ func (d *Database) Events( return results, nil } -// GetTransactionEventID implements input.EventDatabase func (d *Database) GetTransactionEventID( ctx context.Context, transactionID string, sessionID int64, userID string, @@ -291,7 +311,6 @@ func (d *Database) GetTransactionEventID( return eventID, err } -// StoreEvent implements input.EventDatabase func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 34b067cb..7ff4c1d4 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -59,10 +60,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventJSONStatements) InsertEventJSON( diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 0d3d323f..2d2c04d2 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -75,12 +76,12 @@ func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventStateKeyStatements) InsertEventStateKeyNID( @@ -89,7 +90,7 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID( var eventStateKeyNID int64 var err error var res sql.Result - insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt) + insertStmt := internal.TxStmt(txn, s.insertEventStateKeyNIDStmt) if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { eventStateKeyNID, err = res.LastInsertId() } @@ -100,7 +101,7 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := txn.Stmt(s.selectEventStateKeyNIDStmt) + stmt := internal.TxStmt(txn, s.selectEventStateKeyNIDStmt) err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index d47be545..6d229bbd 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -90,12 +91,12 @@ func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventTypeStatements) InsertEventTypeNID( diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 247faa68..491399ce 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -120,7 +121,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -134,7 +135,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *eventStatements) InsertEvent( diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index a42d18a7..fbe2b072 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,21 +68,22 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.selectInvitesAboutToRetireStmt, selectInvitesAboutToRetireSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, @@ -101,7 +104,7 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { @@ -127,7 +130,7 @@ func (s *inviteStatements) updateInviteRetired( } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 34108af4..5f3076c5 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -20,17 +20,11 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` CREATE TABLE IF NOT EXISTS roomserver_membership ( room_nid INTEGER NOT NULL, @@ -91,13 +85,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -106,10 +101,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -119,10 +114,10 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, @@ -130,26 +125,25 @@ func (s *membershipStatements) selectMembershipForUpdate( return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( + ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { - selectStmt := internal.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) - err = selectStmt.QueryRowContext( +) (eventNID types.EventNID, membership tables.MembershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt if localOnly { - selectStmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + selectStmt = s.selectLocalMembershipsFromRoomStmt } else { - selectStmt = internal.TxStmt(txn, s.selectMembershipsFromRoomStmt) + selectStmt = s.selectMembershipsFromRoomStmt } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { @@ -167,15 +161,15 @@ func (s *membershipStatements) selectMembershipsFromRoom( return } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( - ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, localOnly bool, +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { - stmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt } else { - stmt = internal.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + stmt = s.selectMembershipsFromRoomAndMembershipStmt } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { @@ -193,10 +187,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { stmt := internal.TxStmt(txn, s.updateMembershipStmt) diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 6b758ccc..2b187bba 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -63,10 +64,10 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, - }.prepare(db) + }.Prepare(db) } func (s *previousEventStatements) InsertPreviousEvent( diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 686391f6..da5f9161 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -67,13 +68,13 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, - }.prepare(db) + }.Prepare(db) } func (s *roomAliasesStatements) InsertRoomAlias( diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 75b8fec9..d22fceb3 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -78,7 +79,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { if err != nil { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, @@ -86,7 +87,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *roomStatements) InsertRoomNID( diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go deleted file mode 100644 index df994d50..00000000 --- a/roomserver/storage/sqlite3/sql.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.inviteStatements.prepare, - s.membershipStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index e3ca0b3c..399fdbf6 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -86,12 +87,12 @@ func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.prepare(db) + }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 42eae1a7..5de3f639 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -60,10 +61,10 @@ func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.prepare(db) + }.Prepare(db) } func (s *stateSnapshotStatements) InsertState( diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 16697f1b..16d89304 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -34,7 +34,6 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database - statements statements events tables.Events eventJSON tables.EventJSON eventTypes tables.EventTypes @@ -42,6 +41,8 @@ type Database struct { rooms tables.Rooms transactions tables.Transactions prevEvents tables.PreviousEvents + invites tables.Invites + membership tables.Membership db *sql.DB } @@ -72,9 +73,7 @@ func Open(dataSourceName string) (*Database, error) { // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. d.db.SetMaxOpenConns(20) - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) if err != nil { return nil, err @@ -115,6 +114,14 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.invites, err = NewSqliteInvitesTable(d.db) + if err != nil { + return nil, err + } + d.membership, err = NewSqliteMembershipTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventsTable: d.events, @@ -127,6 +134,8 @@ func Open(dataSourceName string) (*Database, error) { StateSnapshotTable: stateSnapshot, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, + InvitesTable: d.invites, + MembershipTable: d.membership, } return &d, nil } @@ -305,15 +314,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return } -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, @@ -367,7 +367,7 @@ type membershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership membershipState + membership tables.MembershipState } func (d *Database) membershipUpdaterTxn( @@ -378,11 +378,11 @@ func (d *Database) membershipUpdaterTxn( targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -395,17 +395,17 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite + return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin + return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan + return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater @@ -415,15 +415,15 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted if err != nil { return err } - inserted, err = u.d.statements.insertInviteEvent( + inserted, err = u.d.invites.InsertInviteEvent( u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { return err } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + if u.membership != tables.MembershipStateInvite { + if err = u.d.membership.UpdateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { return err } @@ -443,7 +443,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -457,10 +457,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd return err } - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.membership.UpdateMembership( u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], + tables.MembershipStateJoin, nIDs[eventID], ); err != nil { return err } @@ -478,7 +478,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv if err != nil { return err } - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -491,10 +491,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv return err } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.membership.UpdateMembership( u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], + tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return err } @@ -504,52 +504,6 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv return } -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) - if err != nil { - return err - } - - membershipEventNID, _, err = - d.statements.selectMembershipFromRoomAndTarget( - ctx, txn, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return nil - } - if err != nil { - return err - } - stillInRoom = true - return nil - }) - - return -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) (eventNIDs []types.EventNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if joinOnly { - eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, localOnly, - ) - return nil - } - - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) - return nil - }) - return -} - type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 37ea15c0..7b489ac6 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -54,10 +55,10 @@ func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { return nil, err } - return s, statementList{ + return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, - }.prepare(db) + }.Prepare(db) } func (s *transactionStatements) InsertTransaction( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 56ef5dd3..11cff8a8 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -96,3 +96,27 @@ type PreviousEvents interface { // Returns sql.ErrNoRows if the event reference doesn't exist. SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error } + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) + UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) + // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs + SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) +} + +type MembershipState int64 + +const ( + MembershipStateLeaveOrBan MembershipState = 1 + MembershipStateInvite MembershipState = 2 + MembershipStateJoin MembershipState = 3 +) + +type Membership interface { + InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error + SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error +}