diff --git a/src/github.com/matrix-org/dendrite/roomserver/alias/alias.go b/src/github.com/matrix-org/dendrite/roomserver/alias/alias.go index 7cfe6083..0ea3f238 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/alias/alias.go +++ b/src/github.com/matrix-org/dendrite/roomserver/alias/alias.go @@ -32,16 +32,16 @@ import ( type RoomserverAliasAPIDatabase interface { // Save a given room alias with the room ID it refers to. // Returns an error if there was a problem talking to the database. - SetRoomAlias(alias string, roomID string) error + SetRoomAlias(ctx context.Context, alias string, roomID string) error // Look up the room ID a given alias refers to. // Returns an error if there was a problem talking to the database. - GetRoomIDFromAlias(alias string) (string, error) + GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) // Look up all aliases referring to a given room ID. // Returns an error if there was a problem talking to the database. - GetAliasesFromRoomID(roomID string) ([]string, error) + GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) // Remove a given room alias. // Returns an error if there was a problem talking to the database. - RemoveRoomAlias(alias string) error + RemoveRoomAlias(ctx context.Context, alias string) error } // RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI @@ -59,7 +59,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias( response *api.SetRoomAliasResponse, ) error { // Check if the alias isn't already referring to a room - roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) + roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias) if err != nil { return err } @@ -71,7 +71,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias( response.AliasExists = false // Save the new alias - if err := r.DB.SetRoomAlias(request.Alias, request.RoomID); err != nil { + if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil { return err } @@ -93,7 +93,7 @@ func (r *RoomserverAliasAPI) GetAliasRoomID( response *api.GetAliasRoomIDResponse, ) error { // Look up the room ID in the database - roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) + roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias) if err != nil { return err } @@ -109,18 +109,21 @@ func (r *RoomserverAliasAPI) RemoveRoomAlias( response *api.RemoveRoomAliasResponse, ) error { // Look up the room ID in the database - roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) + roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias) if err != nil { return err } // Remove the dalias from the database - if err := r.DB.RemoveRoomAlias(request.Alias); err != nil { + if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil { return err } // Send an updated m.room.aliases event - if err := r.sendUpdatedAliasesEvent(ctx, request.UserID, roomID); err != nil { + // At this point we've already committed the alias to the database so we + // shouldn't cancel this request. + // TODO: Ensure that we send unsent events when if server restarts. + if err := r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID); err != nil { return err } @@ -147,7 +150,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent( // Retrieve the updated list of aliases, marhal it and set it as the // event's content - aliases, err := r.DB.GetAliasesFromRoomID(roomID) + aliases, err := r.DB.GetAliasesFromRoomID(ctx, roomID) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go index fbb7d7c0..74be2ed3 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go @@ -15,16 +15,23 @@ package input import ( + "context" + "sort" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - "sort" ) // checkAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. -func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) { +func checkAuthEvents( + ctx context.Context, + db RoomEventDatabase, + event gomatrixserverlib.Event, + authEventIDs []string, +) ([]types.EventNID, error) { // Grab the numeric IDs for the supplied auth state events from the database. - authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) + authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs) if err != nil { return nil, err } @@ -34,7 +41,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) if err != nil { return nil, err } @@ -84,7 +91,10 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even } func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { - eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) + eventNID, ok := ae.state.lookup(types.StateKeyTuple{ + EventTypeNID: typeNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) if !ok { return nil } @@ -100,7 +110,10 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * if !ok { return nil } - eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID}) + eventNID, ok := ae.state.lookup(types.StateKeyTuple{ + EventTypeNID: typeNID, + EventStateKeyNID: stateKeyNID, + }) if !ok { return nil } @@ -113,6 +126,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( + ctx context.Context, db RoomEventDatabase, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, @@ -121,7 +135,7 @@ func loadAuthEvents( var neededStateKeys []string neededStateKeys = append(neededStateKeys, needed.Member...) neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil { + if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil { return } @@ -135,34 +149,52 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(eventNIDs); err != nil { + if result.events, err = db.Events(ctx, eventNIDs); err != nil { return } return } // stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. -func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { +func stateKeyTuplesNeeded( + stateKeyNIDMap map[string]types.EventStateKeyNID, + stateNeeded gomatrixserverlib.StateNeeded, +) []types.StateKeyTuple { var keyTuples []types.StateKeyTuple if stateNeeded.Create { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } if stateNeeded.PowerLevels { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } if stateNeeded.JoinRules { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } for _, member := range stateNeeded.Member { stateKeyNID, ok := stateKeyNIDMap[member] if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) } } for _, token := range stateNeeded.ThirdPartyInvite { stateKeyNID, ok := stateKeyNIDMap[token] if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) } } return keyTuples diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index 88c60447..c7a33f52 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -15,6 +15,7 @@ package input import ( + "context" "fmt" "github.com/matrix-org/dendrite/common" @@ -28,22 +29,38 @@ import ( type RoomEventDatabase interface { state.RoomStateDatabase // Stores a matrix room event in the database - StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) + StoreEvent( + ctx context.Context, + event gomatrixserverlib.Event, + authEventNIDs []types.EventNID, + ) (types.RoomNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. - StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) + StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, + ) ([]types.StateEntry, error) // Set the state at an event. - SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error + SetState( + ctx context.Context, + eventNID types.EventNID, + stateNID types.StateSnapshotNID, + ) error // Look up the latest events in a room in preparation for an update. // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error) + GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, + ) (updater types.RoomRecentEventsUpdater, err error) // Look up the string event IDs for a list of numeric event IDs - EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) + EventIDs( + ctx context.Context, eventNIDs []types.EventNID, + ) (map[types.EventNID]string, error) // Build a membership updater for the target user in a room. - MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error) + MembershipUpdater( + ctx context.Context, roomID, targerUserID string, + ) (types.MembershipUpdater, error) } // OutputRoomEventWriter has the APIs needed to write an event to the output logs. @@ -52,18 +69,23 @@ type OutputRoomEventWriter interface { WriteOutputEvents(roomID string, updates []api.OutputEvent) error } -func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error { +func processRoomEvent( + ctx context.Context, + db RoomEventDatabase, + ow OutputRoomEventWriter, + input api.InputRoomEvent, +) error { // Parse and validate the event JSON event := input.Event // Check that the event passes authentication checks and work out the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs) + authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs) if err != nil { return err } // Store the event - roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs) + roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, authEventNIDs) if err != nil { return err } @@ -82,20 +104,20 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api. // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil { + if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { return err } - if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { return nil } } else { // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { return err } } - db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) } if input.Kind == api.KindBackfill { @@ -104,14 +126,19 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api. } // Update the extremities of the event graph for the room - if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil { + if err := updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil { return err } return nil } -func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) { +func processInviteEvent( + ctx context.Context, + db RoomEventDatabase, + ow OutputRoomEventWriter, + input api.InputInviteEvent, +) (err error) { if input.Event.StateKey() == nil { return fmt.Errorf("invite must be a state event") } @@ -119,7 +146,7 @@ func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input ap roomID := input.Event.RoomID() targetUserID := *input.Event.StateKey() - updater, err := db.MembershipUpdater(roomID, targetUserID) + updater, err := db.MembershipUpdater(ctx, roomID, targetUserID) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/input.go b/src/github.com/matrix-org/dendrite/roomserver/input/input.go index e3918a3c..27797096 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/input.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/input.go @@ -59,12 +59,12 @@ func (r *RoomserverInputAPI) InputRoomEvents( response *api.InputRoomEventsResponse, ) error { for i := range request.InputRoomEvents { - if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil { + if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { return err } } for i := range request.InputInviteEvents { - if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil { + if err := processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { return err } } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go index fe485607..c20613db 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go @@ -16,6 +16,7 @@ package input import ( "bytes" + "context" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" @@ -42,6 +43,7 @@ import ( // 7 <----- latest // func updateLatestEvents( + ctx context.Context, db RoomEventDatabase, ow OutputRoomEventWriter, roomNID types.RoomNID, @@ -49,7 +51,7 @@ func updateLatestEvents( event gomatrixserverlib.Event, sendAsServer string, ) (err error) { - updater, err := db.GetLatestEventsForUpdate(roomNID) + updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { return } @@ -57,7 +59,7 @@ func updateLatestEvents( defer common.EndTransaction(updater, &succeeded) u := latestEventsUpdater{ - db: db, updater: updater, ow: ow, roomNID: roomNID, + ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, } if err = u.doUpdateLatestEvents(); err != nil { @@ -73,6 +75,7 @@ func updateLatestEvents( // The state could be passed using function arguments, but it becomes impractical // when there are so many variables to pass around. type latestEventsUpdater struct { + ctx context.Context db RoomEventDatabase updater types.RoomRecentEventsUpdater ow OutputRoomEventWriter @@ -133,7 +136,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return err } - updates, err := updateMemberships(u.db, u.updater, u.removed, u.added) + updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) if err != nil { return err } @@ -174,18 +177,22 @@ func (u *latestEventsUpdater) latestState() error { for i := range u.latest { latestStateAtEvents[i] = u.latest[i].StateAtEvent } - u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(u.db, u.roomNID, latestStateAtEvents) + u.newStateNID, err = state.CalculateAndStoreStateAfterEvents( + u.ctx, u.db, u.roomNID, latestStateAtEvents, + ) if err != nil { return err } - u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(u.db, u.oldStateNID, u.newStateNID) + u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots( + u.ctx, u.db, u.oldStateNID, u.newStateNID, + ) if err != nil { return err } u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots( - u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, + u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) if err != nil { return err @@ -193,7 +200,12 @@ func (u *latestEventsUpdater) latestState() error { return nil } -func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference { +func calculateLatest( + oldLatest []types.StateAtEventAndReference, + alreadyReferenced bool, + prevEvents []gomatrixserverlib.EventReference, + newEvent types.StateAtEventAndReference, +) []types.StateAtEventAndReference { var alreadyInLatest bool var newLatest []types.StateAtEventAndReference for _, l := range oldLatest { @@ -253,7 +265,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - eventIDMap, err := u.db.EventIDs(stateEventNIDs) + eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go index 6eeb0914..f4d8e02c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go @@ -15,6 +15,7 @@ package input import ( + "context" "fmt" "github.com/matrix-org/dendrite/roomserver/api" @@ -27,7 +28,10 @@ import ( // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. func updateMemberships( - db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry, + ctx context.Context, + db RoomEventDatabase, + updater types.RoomRecentEventsUpdater, + removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) var eventNIDs []types.EventNID @@ -43,7 +47,7 @@ func updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := db.Events(eventNIDs) + events, err := db.Events(ctx, eventNIDs) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/query/query.go b/src/github.com/matrix-org/dendrite/roomserver/query/query.go index a741c28a..902bf56a 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/query/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/query/query.go @@ -33,35 +33,47 @@ type RoomserverQueryAPIDatabase interface { // Look up the numeric ID for the room. // Returns 0 if the room doesn't exists. // Returns an error if there was a problem talking to the database. - RoomNID(roomID string) (types.RoomNID, error) + RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. - LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, + ) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) // Look up the numeric IDs for a list of events. // Returns an error if there was a problem talking to the database. - EventNIDs(eventIDs []string) (map[string]types.EventNID, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. - EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // Lookup the membership of a given user in a given room. // Returns the numeric ID of the latest membership event sent from this user // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, + ) (membershipEventNID types.EventNID, stillInRoom bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. // Returns an error if there was a problem talking to the database. - GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) + GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ) ([]types.EventNID, error) // Look up the active invites targeting a user in a room and return the // numeric state key IDs for the user IDs who sent them. // Returns an error if there was a problem talking to the database. - GetInvitesForUser(roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserNIDs []types.EventStateKeyNID, err error) + GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, + ) (senderUserNIDs []types.EventStateKeyNID, err error) // Look up the string event state keys for a list of numeric event state keys // Returns an error if there was a problem talking to the database. - EventStateKeys([]types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + EventStateKeys( + context.Context, []types.EventStateKeyNID, + ) (map[types.EventStateKeyNID]string, error) } // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI @@ -76,7 +88,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( response *api.QueryLatestEventsAndStateResponse, ) error { response.QueryLatestEventsAndStateRequest = *request - roomNID, err := r.DB.RoomNID(request.RoomID) + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { return err } @@ -85,18 +97,21 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( } response.RoomExists = true var currentStateSnapshotNID types.StateSnapshotNID - response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID) + response.LatestEvents, currentStateSnapshotNID, response.Depth, err = + r.DB.LatestEventIDs(ctx, roomNID) if err != nil { return err } // Look up the currrent state for the requested tuples. - stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch) + stateEntries, err := state.LoadStateAtSnapshotForStringTuples( + ctx, r.DB, currentStateSnapshotNID, request.StateToFetch, + ) if err != nil { return err } - stateEvents, err := r.loadStateEvents(stateEntries) + stateEvents, err := r.loadStateEvents(ctx, stateEntries) if err != nil { return err } @@ -112,7 +127,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( response *api.QueryStateAfterEventsResponse, ) error { response.QueryStateAfterEventsRequest = *request - roomNID, err := r.DB.RoomNID(request.RoomID) + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { return err } @@ -121,7 +136,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( } response.RoomExists = true - prevStates, err := r.DB.StateAtEventIDs(request.PrevEventIDs) + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { switch err.(type) { case types.MissingEventError: @@ -133,12 +148,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( response.PrevEventsExist = true // Look up the currrent state for the requested tuples. - stateEntries, err := state.LoadStateAfterEventsForStringTuples(r.DB, prevStates, request.StateToFetch) + stateEntries, err := state.LoadStateAfterEventsForStringTuples( + ctx, r.DB, prevStates, request.StateToFetch, + ) if err != nil { return err } - stateEvents, err := r.loadStateEvents(stateEntries) + stateEvents, err := r.loadStateEvents(ctx, stateEntries) if err != nil { return err } @@ -155,7 +172,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID( ) error { response.QueryEventsByIDRequest = *request - eventNIDMap, err := r.DB.EventNIDs(request.EventIDs) + eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) if err != nil { return err } @@ -165,7 +182,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID( eventNIDs = append(eventNIDs, nid) } - events, err := r.loadEvents(eventNIDs) + events, err := r.loadEvents(ctx, eventNIDs) if err != nil { return err } @@ -174,16 +191,20 @@ func (r *RoomserverQueryAPI) QueryEventsByID( return nil } -func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) { +func (r *RoomserverQueryAPI) loadStateEvents( + ctx context.Context, stateEntries []types.StateEntry, +) ([]gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return r.loadEvents(eventNIDs) + return r.loadEvents(ctx, eventNIDs) } -func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) { - stateEvents, err := r.DB.Events(eventNIDs) +func (r *RoomserverQueryAPI) loadEvents( + ctx context.Context, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.Event, error) { + stateEvents, err := r.DB.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -201,12 +222,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( request *api.QueryMembershipsForRoomRequest, response *api.QueryMembershipsForRoomResponse, ) error { - roomNID, err := r.DB.RoomNID(request.RoomID) + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { return err } - membershipEventNID, stillInRoom, err := r.DB.GetMembership(roomNID, request.Sender) + membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender) if err != nil { return nil } @@ -223,14 +244,14 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( var events []types.Event if stillInRoom { var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(roomNID, request.JoinedOnly) + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) if err != nil { return err } - events, err = r.DB.Events(eventNIDs) + events, err = r.DB.Events(ctx, eventNIDs) } else { - events, err = r.getMembershipsBeforeEventNID(membershipEventNID, request.JoinedOnly) + events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly) } if err != nil { @@ -249,22 +270,24 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( // of the event's room as it was when this event was fired, then filters the state events to // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. -func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNID, joinedOnly bool) ([]types.Event, error) { +func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( + ctx context.Context, eventNID types.EventNID, joinedOnly bool, +) ([]types.Event, error) { events := []types.Event{} // Lookup the event NID - eIDs, err := r.DB.EventIDs([]types.EventNID{eventNID}) + eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { return nil, err } eventIDs := []string{eIDs[eventNID]} - prevState, err := r.DB.StateAtEventIDs(eventIDs) + prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { return nil, err } // Fetch the state as it was when this event was fired - stateEntries, err := state.LoadCombinedStateAfterEvents(r.DB, prevState) + stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState) if err != nil { return nil, err } @@ -278,7 +301,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI } // Get all of the events in this state - stateEvents, err := r.DB.Events(eventNIDs) + stateEvents, err := r.DB.Events(ctx, eventNIDs) if err != nil { return nil, err } @@ -304,27 +327,27 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI // QueryInvitesForUser implements api.RoomserverQueryAPI func (r *RoomserverQueryAPI) QueryInvitesForUser( - _ context.Context, + ctx context.Context, request *api.QueryInvitesForUserRequest, response *api.QueryInvitesForUserResponse, ) error { - roomNID, err := r.DB.RoomNID(request.RoomID) + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { return err } - targetUserNIDs, err := r.DB.EventStateKeyNIDs([]string{request.TargetUserID}) + targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID}) if err != nil { return err } targetUserNID := targetUserNIDs[request.TargetUserID] - senderUserNIDs, err := r.DB.GetInvitesForUser(roomNID, targetUserNID) + senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) if err != nil { return err } - senderUserIDs, err := r.DB.EventStateKeys(senderUserNIDs) + senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs) if err != nil { return err } @@ -342,14 +365,14 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, ) error { - stateEntries, err := state.LoadStateAtEvent(r.DB, request.EventID) + stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, request.EventID) if err != nil { return err } // TODO: We probably want to make it so that we don't have to pull // out all the state if possible. - stateAtEvent, err := r.loadStateEvents(stateEntries) + stateAtEvent, err := r.loadStateEvents(ctx, stateEntries) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/state/state.go b/src/github.com/matrix-org/dendrite/roomserver/state/state.go index 0323be37..2a0b7f57 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/state/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/state/state.go @@ -17,6 +17,7 @@ package state import ( + "context" "fmt" "sort" "time" @@ -30,49 +31,58 @@ import ( // A RoomStateDatabase has the storage APIs needed to load state from the database type RoomStateDatabase interface { // Store the room state at an event in the database - AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, + ) (types.StateSnapshotNID, error) // Look up the state of a room at each event for a list of string event IDs. // Returns an error if there is an error talking to the database // Returns a types.MissingEventError if the room state for the event IDs aren't in the database - StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) // Look up the numeric IDs for a list of string event types. // Returns a map from string event type to numeric ID for the event type. - EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) // Look up the numeric IDs for a list of string event state keys. // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) // Look up the numeric state data IDs for each numeric state snapshot ID // The returned slice is sorted by numeric state snapshot ID. - StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) // Look up the state data for each numeric state data ID // The returned slice is sorted by numeric state data ID. - StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) // Look up the state data for the state key tuples for each numeric state block ID // This is used to fetch a subset of the room state at a snapshot. // If a block doesn't contain any of the requested tuples then it can be discarded from the result. // The returned slice is sorted by numeric state block ID. - StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ( - []types.StateEntryList, error, - ) + StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, + ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string - SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) } // LoadStateAtSnapshot loads the full state of a room at a particular snapshot. // This is typically the state before an event or the current state of a room. // Returns a sorted list of state entries or an error if there was a problem talking to the database. -func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) { - stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) +func LoadStateAtSnapshot( + ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err } // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. stateBlockNIDList := stateBlockNIDLists[0] - stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) + stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) if err != nil { return nil, err } @@ -100,13 +110,15 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) } // LoadStateAtEvent loads the full state of a room at a particular event. -func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, error) { - snapshotNID, err := db.SnapshotNIDFromEventID(eventID) +func LoadStateAtEvent( + ctx context.Context, db RoomStateDatabase, eventID string, +) ([]types.StateEntry, error) { + snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { return nil, err } - stateEntries, err := LoadStateAtSnapshot(db, snapshotNID) + stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID) if err != nil { return nil, err } @@ -116,7 +128,9 @@ func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, // LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events // and combines those snapshots together into a single list. -func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { +func LoadCombinedStateAfterEvents( + ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, +) ([]types.StateEntry, error) { stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) for i, state := range prevStates { stateNIDs[i] = state.BeforeStateSnapshotNID @@ -125,7 +139,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State // Deduplicate the IDs before passing them to the database. // There could be duplicates because the events could be state events where // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs)) + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) if err != nil { return nil, err } @@ -138,7 +152,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State // Deduplicate the IDs before passing them to the database. // There could be duplicates because a block of state entries could be reused by // multiple snapshots. - stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs)) + stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) if err != nil { return nil, err } @@ -186,9 +200,9 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State } // DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. -func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) ( - removed, added []types.StateEntry, err error, -) { +func DifferenceBetweeenStateSnapshots( + ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID, +) (removed, added []types.StateEntry, err error) { if oldStateNID == newStateNID { // If the snapshot NIDs are the same then nothing has changed return nil, nil, nil @@ -197,13 +211,13 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat var oldEntries []types.StateEntry var newEntries []types.StateEntry if oldStateNID != 0 { - oldEntries, err = LoadStateAtSnapshot(db, oldStateNID) + oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID) if err != nil { return nil, nil, err } } if newStateNID != 0 { - newEntries, err = LoadStateAtSnapshot(db, newStateNID) + newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID) if err != nil { return nil, nil, err } @@ -246,19 +260,26 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat // This is typically the state before an event or the current state of a room. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func LoadStateAtSnapshotForStringTuples( - db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, + ctx context.Context, + db RoomStateDatabase, + stateNID types.StateSnapshotNID, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) + numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) if err != nil { return nil, err } - return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples) + return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples) } // stringTuplesToNumericTuples converts the string state key tuples into numeric IDs // If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. // Returns an error if there was a problem talking to the database. -func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateKeyTuple, error) { +func stringTuplesToNumericTuples( + ctx context.Context, + db RoomStateDatabase, + stringTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateKeyTuple, error) { eventTypes := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples)) for i := range stringTuples { @@ -266,12 +287,12 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse stateKeys[i] = stringTuples[i].StateKey } eventTypes = util.UniqueStrings(eventTypes) - eventTypeMap, err := db.EventTypeNIDs(eventTypes) + eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes) if err != nil { return nil, err } stateKeys = util.UniqueStrings(stateKeys) - stateKeyMap, err := db.EventStateKeyNIDs(stateKeys) + stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys) if err != nil { return nil, err } @@ -297,16 +318,21 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse // This is typically the state before an event or the current state of a room. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func loadStateAtSnapshotForNumericTuples( - db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, + ctx context.Context, + db RoomStateDatabase, + stateNID types.StateSnapshotNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err } // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. stateBlockNIDList := stateBlockNIDLists[0] - stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples) + stateEntryLists, err := db.StateEntriesForTuples( + ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, + ) if err != nil { return nil, err } @@ -341,23 +367,29 @@ func loadStateAtSnapshotForNumericTuples( // This is typically the state before an event. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func LoadStateAfterEventsForStringTuples( - db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, + ctx context.Context, + db RoomStateDatabase, + prevStates []types.StateAtEvent, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { - numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) + numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) if err != nil { return nil, err } - return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples) + return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples) } func loadStateAfterEventsForNumericTuples( - db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, + ctx context.Context, + db RoomStateDatabase, + prevStates []types.StateAtEvent, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] result, err := loadStateAtSnapshotForNumericTuples( - db, prevState.BeforeStateSnapshotNID, stateKeyTuples, + ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples, ) if err != nil { return nil, err @@ -390,7 +422,7 @@ func loadStateAfterEventsForNumericTuples( // TODO: Add metrics for this as it could take a long time for big rooms // with large conflicts. - fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates) + fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates) if err != nil { return nil, err } @@ -403,7 +435,10 @@ func loadStateAfterEventsForNumericTuples( for _, tuple := range stateKeyTuples { eventNID, ok := stateEntryMap(fullState).lookup(tuple) if ok { - result = append(result, types.StateEntry{tuple, eventNID}) + result = append(result, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) } } sort.Sort(stateEntrySorter(result)) @@ -509,7 +544,10 @@ func init() { // Stores the snapshot of the state in the database. // Returns a numeric ID for the snapshot of the state before the event. func CalculateAndStoreStateBeforeEvent( - db RoomStateDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, + ctx context.Context, + db RoomStateDatabase, + event gomatrixserverlib.Event, + roomNID types.RoomNID, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -518,25 +556,30 @@ func CalculateAndStoreStateBeforeEvent( prevEventIDs[i] = prevEventRefs[i].EventID } - prevStates, err := db.StateAtEventIDs(prevEventIDs) + prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs) if err != nil { return 0, err } // The state before this event will be the state after the events that came before it. - return CalculateAndStoreStateAfterEvents(db, roomNID, prevStates) + return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates) } // CalculateAndStoreStateAfterEvents finds the room state after the given events. // Stores the resulting state in the database and returns a numeric ID for that snapshot. -func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { +func CalculateAndStoreStateAfterEvents( + ctx context.Context, + db RoomStateDatabase, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, +) (types.StateSnapshotNID, error) { metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} if len(prevStates) == 0 { // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - return metrics.stop(db.AddState(roomNID, nil, nil)) + return metrics.stop(db.AddState(ctx, roomNID, nil, nil)) } if len(prevStates) == 1 { @@ -551,7 +594,9 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN } // The previous event was a state event so we need to store a copy // of the previous state updated with that event. - stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}) + stateBlockNIDLists, err := db.StateBlockNIDs( + ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, + ) if err != nil { metrics.algorithm = "_load_state_blocks" return metrics.stop(0, err) @@ -562,14 +607,14 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" return metrics.stop(db.AddState( - roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, )) } // If there are too many deltas then we need to calculate the full state // So fall through to calculateAndStoreStateAfterManyEvents } - return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates, metrics) + return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics) } // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. @@ -583,10 +628,15 @@ const maxStateBlockNIDs = 64 // This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. // Stores the resulting state and returns a numeric ID for the snapshot. func calculateAndStoreStateAfterManyEvents( - db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics, + ctx context.Context, + db RoomStateDatabase, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, + metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates) + state, algorithm, conflictLength, err := + calculateStateAfterManyEvents(ctx, db, prevStates) metrics.algorithm = algorithm if err != nil { return metrics.stop(0, err) @@ -596,16 +646,16 @@ func calculateAndStoreStateAfterManyEvents( // previous state. metrics.conflictLength = conflictLength metrics.fullStateLength = len(state) - return metrics.stop(db.AddState(roomNID, nil, state)) + return metrics.stop(db.AddState(ctx, roomNID, nil, state)) } func calculateStateAfterManyEvents( - db RoomStateDatabase, prevStates []types.StateAtEvent, + ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm string, conflictLength int, err error) { var combined []types.StateEntry // Conflict resolution. // First stage: load the state after each of the prev events. - combined, err = LoadCombinedStateAfterEvents(db, prevStates) + combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates) if err != nil { algorithm = "_load_combined_state" return @@ -635,7 +685,7 @@ func calculateStateAfterManyEvents( } var resolved []types.StateEntry - resolved, err = resolveConflicts(db, notConflicted, conflicts) + resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts) if err != nil { algorithm = "_resolve_conflicts" return @@ -657,10 +707,14 @@ func calculateStateAfterManyEvents( // Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. // The returned list is sorted by state key tuple. // Returns an error if there was a problem talking to the database. -func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) { +func resolveConflicts( + ctx context.Context, + db RoomStateDatabase, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { // Load the conflicted events - conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted) + conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted) if err != nil { return nil, err } @@ -672,7 +726,7 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St var neededStateKeys []string neededStateKeys = append(neededStateKeys, needed.Member...) neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys) + stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys) if err != nil { return nil, err } @@ -682,10 +736,13 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St var authEntries []types.StateEntry for _, tuple := range tuplesNeeded { if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { - authEntries = append(authEntries, types.StateEntry{tuple, eventNID}) + authEntries = append(authEntries, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) } } - authEvents, _, err := loadStateEvents(db, authEntries) + authEvents, _, err := loadStateEvents(ctx, db, authEntries) if err != nil { return nil, err } @@ -711,24 +768,39 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { var keyTuples []types.StateKeyTuple if stateNeeded.Create { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } if stateNeeded.PowerLevels { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } if stateNeeded.JoinRules { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) } for _, member := range stateNeeded.Member { stateKeyNID, ok := stateKeyNIDMap[member] if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) } } for _, token := range stateNeeded.ThirdPartyInvite { stateKeyNID, ok := stateKeyNIDMap[token] if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID}) + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) } } return keyTuples @@ -738,12 +810,14 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stat // Returns a list of state events in no particular order and a map from string event ID back to state entry. // The map can be used to recover which numeric state entry a given event is for. // Returns an error if there was a problem talking to the database. -func loadStateEvents(db RoomStateDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { +func loadStateEvents( + ctx context.Context, db RoomStateDatabase, entries []types.StateEntry, +) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { eventNIDs := make([]types.EventNID, len(entries)) for i := range entries { eventNIDs[i] = entries[i].EventNID } - events, err := db.Events(eventNIDs) + events, err := db.Events(ctx, eventNIDs) if err != nil { return nil, nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go index 5e203b4f..63792d15 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/matrix-org/dendrite/roomserver/types" @@ -65,8 +66,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error { - _, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON) +func (s *eventJSONStatements) insertEventJSON( + ctx context.Context, eventNID types.EventNID, eventJSON []byte, +) error { + _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } @@ -75,8 +78,10 @@ type eventJSONPair struct { EventJSON []byte } -func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs)) +func (s *eventJSONStatements) bulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]eventJSONPair, error) { + rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go index b06f5b2a..46f2434a 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/lib/pq" @@ -91,20 +92,30 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { +func (s *eventStateKeyStatements) insertEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) + stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { +func (s *eventStateKeyStatements) selectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) + stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) +func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + ctx, pq.StringArray(eventStateKeys), + ) if err != nil { return nil, err } @@ -122,18 +133,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st return result, nil } -func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) { +func (s *eventStateKeyStatements) selectEventStateKey( + ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID, +) (string, error) { var eventStateKey string - err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey) + stmt := common.TxStmt(txn, s.selectEventStateKeyStmt) + err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey) return eventStateKey, err } -func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) { +func (s *eventStateKeyStatements) bulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { var nIDs pq.Int64Array for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs) + rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go index 7c0bf9b1..2d9f290e 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/lib/pq" @@ -107,20 +108,26 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) { +func (s *eventTypeStatements) insertEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) { +func (s *eventTypeStatements) selectEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } -func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes)) +func (s *eventTypeStatements) bulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index 2d2b8562..bcf17c26 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "fmt" @@ -154,7 +155,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { } func (s *eventStatements) insertEvent( - roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, + ctx context.Context, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, @@ -162,24 +166,28 @@ func (s *eventStatements) insertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRow( - int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, - eventNIDsAsArray(authEventNIDs), depth, + err := s.insertEventStmt.QueryRowContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } -func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) { +func (s *eventStatements) selectEvent( + ctx context.Context, eventID string, +) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID) + err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs)) +func (s *eventStatements) bulkSelectStateEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -216,8 +224,10 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs)) +func (s *eventStatements) bulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -248,28 +258,40 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types return results, err } -func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { - _, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID)) +func (s *eventStatements) updateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } -func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) { - err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) +func (s *eventStatements) selectEventSentToOutput( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (sentToOutput bool, err error) { + stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) return } -func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error { - _, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID)) +func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) { - err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) +func (s *eventStatements) selectEventID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (eventID string, err error) { + stmt := common.TxStmt(txn, s.selectEventIDStmt) + err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) return } -func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { - rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) bulkSelectStateAtEventAndReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.StateAtEventAndReference, error) { + stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -304,8 +326,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN return results, nil } -func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) { - rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) bulkSelectEventReference( + ctx context.Context, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.EventReference, error) { + rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -325,8 +349,8 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -349,8 +373,8 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs)) +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { + rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -367,9 +391,10 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type return results, nil } -func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { var result int64 - err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result) + stmt := s.selectMaxEventDepthStmt + err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) if err != nil { return 0, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go index 8bae2b78..76fa3e04 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/matrix-org/dendrite/common" @@ -91,12 +92,13 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) { } func (s *inviteStatements) insertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec( - inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) if err != nil { return false, err @@ -109,9 +111,11 @@ func (s *inviteStatements) insertInviteEvent( } func (s *inviteStatements) updateInviteRetired( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { - rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) + stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) if err != nil { return nil, err } @@ -129,10 +133,11 @@ func (s *inviteStatements) updateInviteRetired( // selectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) selectInviteActiveForUserInRoom( + ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.Query( - targetUserNID, roomNID, + rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + ctx, targetUserNID, roomNID, ) if err != nil { return nil, err diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go index 5db772aa..88a9ed72 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/matrix-org/dendrite/common" @@ -114,34 +115,38 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { } func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) error { - _, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID) + stmt := common.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) return err } func (s *membershipStatements) selectMembershipForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership membershipState, err error) { - err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow( - roomNID, targetUserNID, + err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( + ctx, roomNID, targetUserNID, ).Scan(&membership) return } func (s *membershipStatements) selectMembershipFromRoomAndTarget( + ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership membershipState, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRow( - roomNID, targetUserNID, + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } func (s *membershipStatements) selectMembershipsFromRoom( - roomNID types.RoomNID, + ctx context.Context, roomNID types.RoomNID, ) (eventNIDs []types.EventNID, err error) { - rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID) + rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) if err != nil { return } @@ -156,9 +161,11 @@ func (s *membershipStatements) selectMembershipsFromRoom( return } func (s *membershipStatements) selectMembershipsFromRoomAndMembership( + ctx context.Context, roomNID types.RoomNID, membership membershipState, ) (eventNIDs []types.EventNID, err error) { - rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership) + stmt := s.selectMembershipsFromRoomAndMembershipStmt + rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return } @@ -174,12 +181,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( } func (s *membershipStatements) updateMembership( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership membershipState, eventNID types.EventNID, ) error { - _, err := common.TxStmt(txn, s.updateMembershipStmt).Exec( - roomNID, targetUserNID, senderUserNID, membership, eventNID, + _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext( + ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, ) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go index 9fcf1cb5..81d581a9 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/matrix-org/dendrite/common" @@ -73,14 +74,26 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error { - _, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) +func (s *previousEventStatements) insertPreviousEvent( + ctx context.Context, + txn *sql.Tx, + previousEventID string, + previousEventReferenceSHA256 []byte, + eventNID types.EventNID, +) error { + stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) return err } // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. -func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error { +func (s *previousEventStatements) selectPreviousEventExists( + ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, +) error { var ok int64 - return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) + stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/room_aliases_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/room_aliases_table.go index bfd6cc09..f640c37f 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/room_aliases_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/room_aliases_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" ) @@ -62,22 +63,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomAliasesStatements) insertRoomAlias(alias string, roomID string) (err error) { - _, err = s.insertRoomAliasStmt.Exec(alias, roomID) +func (s *roomAliasesStatements) insertRoomAlias( + ctx context.Context, alias string, roomID string, +) (err error) { + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID) return } -func (s *roomAliasesStatements) selectRoomIDFromAlias(alias string) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRow(alias).Scan(&roomID) +func (s *roomAliasesStatements) selectRoomIDFromAlias( + ctx context.Context, alias string, +) (roomID string, err error) { + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } return } -func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases []string, err error) { +func (s *roomAliasesStatements) selectAliasesFromRoomID( + ctx context.Context, roomID string, +) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.Query(roomID) + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) if err != nil { return } @@ -94,7 +101,9 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases return } -func (s *roomAliasesStatements) deleteRoomAlias(alias string) (err error) { - _, err = s.deleteRoomAliasStmt.Exec(alias) +func (s *roomAliasesStatements) deleteRoomAlias( + ctx context.Context, alias string, +) (err error) { + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) return } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index 4ba329f3..64193ffe 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/lib/pq" @@ -81,22 +82,31 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { +func (s *roomStatements) insertRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { var roomNID int64 - err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) + stmt := common.TxStmt(txn, s.insertRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { +func (s *roomStatements) selectRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { var roomNID int64 - err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) + stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) { +func (s *roomStatements) selectLatestEventNIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID) + stmt := s.selectLatestEventNIDsStmt + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { return nil, 0, err } @@ -107,13 +117,14 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ( - []types.EventNID, types.EventNID, types.StateSnapshotNID, error, -) { +func (s *roomStatements) selectLatestEventsNIDsForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var lastEventSentNID int64 var stateSnapshotNID int64 - err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err } @@ -125,11 +136,20 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty } func (s *roomStatements) updateLatestEventNIDs( - txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, + lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - _, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec( - roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), + stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + roomNID, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), ) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go index 343e9395..136f61c4 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "fmt" "sort" @@ -97,9 +98,14 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error { +func (s *stateBlockStatements) bulkInsertStateData( + ctx context.Context, + stateBlockNID types.StateBlockNID, + entries []types.StateEntry, +) error { for _, entry := range entries { - _, err := s.insertStateDataStmt.Exec( + _, err := s.insertStateDataStmt.ExecContext( + ctx, int64(stateBlockNID), int64(entry.EventTypeNID), int64(entry.EventStateKeyNID), @@ -112,18 +118,22 @@ func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBloc return nil } -func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) { +func (s *stateBlockStatements) selectNextStateBlockNID( + ctx context.Context, +) (types.StateBlockNID, error) { var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID) + err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) return types.StateBlockNID(stateBlockNID), err } -func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { +func (s *stateBlockStatements) bulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids)) + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } @@ -165,15 +175,20 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []type } func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( - stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { tuples := stateKeyTupleSorter(stateKeyTuples) // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query( - stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray, + rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( + ctx, + stateBlockNIDsAsArray(stateBlockNIDs), + eventTypeNIDArray, + eventStateKeyNIDArray, ) if err != nil { return nil, err diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go index dd0a1b1d..f891b5bc 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go @@ -15,29 +15,30 @@ package storage import ( - "github.com/matrix-org/dendrite/roomserver/types" "sort" "testing" + + "github.com/matrix-org/dendrite/roomserver/types" ) func TestStateKeyTupleSorter(t *testing.T) { input := stateKeyTupleSorter{ - {1, 2}, - {1, 4}, - {2, 2}, - {1, 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, } want := []types.StateKeyTuple{ - {1, 1}, - {1, 2}, - {1, 4}, - {2, 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, } doNotWant := []types.StateKeyTuple{ - {0, 0}, - {1, 3}, - {2, 1}, - {3, 1}, + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, } wantTypeNIDs := []int64{1, 2} wantStateKeyNIDs := []int64{1, 2, 4} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/state_snapshot_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/state_snapshot_table.go index 4d588662..a0dace82 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/state_snapshot_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/state_snapshot_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "fmt" @@ -74,21 +75,25 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) { +func (s *stateSnapshotStatements) insertState( + ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +) (stateNID types.StateSnapshotNID, err error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) return } -func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { +func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids)) + rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 85c5160a..239837fc 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" // Import the postgres database driver. @@ -43,7 +44,9 @@ func Open(dataSourceName string) (*Database, error) { } // StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) { +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { var ( roomNID types.RoomNID eventTypeNID types.EventTypeNID @@ -53,11 +56,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ err error ) - if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil { + if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { return 0, types.StateAtEvent{}, err } - if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil { + if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { return 0, types.StateAtEvent{}, err } @@ -65,12 +68,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ // Assigned a numeric ID for the state_key if there is one present. // Otherwise set the numeric ID for the state_key to 0. if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(nil, *eventStateKey); err != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { return 0, types.StateAtEvent{}, err } } if eventNID, stateNID, err = d.statements.insertEvent( + ctx, roomNID, eventTypeNID, eventStateKeyNID, @@ -81,14 +85,14 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(event.EventID()) + eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) } if err != nil { return 0, types.StateAtEvent{}, err } } - if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil { + if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { return 0, types.StateAtEvent{}, err } @@ -104,76 +108,94 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ }, nil } -func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(txn, roomID) + roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(txn, roomID) + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(txn, roomID) + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) } } return roomNID, err } -func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) { +func (d *Database) assignEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(eventType) + eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(eventType) + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(eventType) + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) } } return eventTypeNID, err } -func (d *Database) assignStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(txn, eventStateKey) + eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(txn, eventStateKey) + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(txn, eventStateKey) + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) } } return eventStateKeyNID, err } // StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(eventIDs) +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.statements.bulkSelectStateEventByID(ctx, eventIDs) } // EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(eventTypes) +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) } // EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) } // EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(eventStateKeyNIDs) +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) } // EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(eventIDs) +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.statements.bulkSelectEventNID(ctx, eventIDs) } // Events implements input.EventDatabase -func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) if err != nil { return nil, err } @@ -191,78 +213,98 @@ func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { } // AddState implements input.EventDatabase -func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) { +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (types.StateSnapshotNID, error) { if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID() + stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) if err != nil { return 0, err } - if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil { + if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { return 0, err } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } - return d.statements.insertState(roomNID, stateBlockNIDs) + return d.statements.insertState(ctx, roomNID, stateBlockNIDs) } // SetState implements input.EventDatabase -func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { - return d.statements.updateEventState(eventNID, stateNID) +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.statements.updateEventState(ctx, eventNID, stateNID) } // StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(eventIDs) +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) } // StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(stateNIDs) +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) } // StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs) +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) } // SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(eventID) +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.statements.selectEvent(ctx, eventID) return stateNID, err } // EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(eventNIDs) +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.statements.bulkSelectEventID(ctx, eventNIDs) } // GetLatestEventsForUpdate implements input.EventDatabase -func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) { +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { txn, err := d.db.Begin() if err != nil { return nil, err } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID) + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) if err != nil { txn.Rollback() return nil, err } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs) + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) if err != nil { txn.Rollback() return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent) + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) if err != nil { txn.Rollback() return nil, err } } return &roomRecentEventsUpdater{ - transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } @@ -293,7 +335,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { return err } } @@ -302,7 +344,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p // IsReferenced implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256) + err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) if err == nil { return true, nil } @@ -321,26 +363,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) } // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.txn, eventNID) + return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.txn, eventNID) + return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) } func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetUserNID) + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) } // RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(nil, roomID) +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) if err == sql.ErrNoRows { return 0, nil } @@ -348,16 +390,18 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { } // LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID) +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { + eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) if err != nil { return nil, 0, 0, err } - references, err := d.statements.bulkSelectEventReference(eventNIDs) + references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) if err != nil { return nil, 0, 0, err } - depth, err := d.statements.selectMaxEventDepth(eventNIDs) + depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) if err != nil { return nil, 0, 0, err } @@ -366,40 +410,48 @@ func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.Ev // GetInvitesForUser implements query.RoomserverQueryAPIDatabase func (d *Database) GetInvitesForUser( - roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, ) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(targetUserNID, roomNID) + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) } // SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(alias string, roomID string) error { - return d.statements.insertRoomAlias(alias, roomID) +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error { + return d.statements.insertRoomAlias(ctx, alias, roomID) } // GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDFromAlias(alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(alias) +func (d *Database) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, alias) } // GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesFromRoomID(roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(roomID) +func (d *Database) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, roomID) } // RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(alias string) error { - return d.statements.deleteRoomAlias(alias) +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, alias) } // StateEntriesForTuples implements state.RoomStateDatabase func (d *Database) StateEntriesForTuples( - stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples) + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) } // MembershipUpdater implements input.RoomEventDatabase -func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) { +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { txn, err := d.db.Begin() if err != nil { return nil, err @@ -411,17 +463,17 @@ func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.Members } }() - roomNID, err := d.assignRoomNID(txn, roomID) + roomNID, err := d.assignRoomNID(ctx, txn, roomID) if err != nil { return nil, err } - targetUserNID, err := d.assignStateKeyNID(txn, targetUserID) + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) if err != nil { return nil, err } - updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID) + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -439,20 +491,23 @@ type membershipUpdater struct { } func (d *Database) membershipUpdaterTxn( - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetUserNID) + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } return &membershipUpdater{ - transaction{txn}, d, roomNID, targetUserNID, membership, + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, }, nil } @@ -473,19 +528,19 @@ func (u *membershipUpdater) IsLeave() bool { // SetToInvite implements types.MembershipUpdater func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) if err != nil { return false, err } inserted, err := u.d.statements.insertInviteEvent( - u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { return false, err } if u.membership != membershipStateInvite { if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, ); err != nil { return false, err } @@ -497,7 +552,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { var inviteEventIDs []string - senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { return nil, err } @@ -505,7 +560,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.txn, u.roomNID, u.targetUserNID, + u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { return nil, err @@ -513,14 +568,15 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs([]string{eventID}) + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { return nil, err } if u.membership != membershipStateJoin || isUpdate { if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID], + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], ); err != nil { return nil, err } @@ -531,26 +587,27 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // SetToLeave implements types.MembershipUpdater func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { return nil, err } inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.txn, u.roomNID, u.targetUserNID, + u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { return nil, err } // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs([]string{eventID}) + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { return nil, err } if u.membership != membershipStateLeaveOrBan { if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID], + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return nil, err } @@ -559,19 +616,18 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - txn, err := d.db.Begin() - if err != nil { - return - } - defer txn.Commit() - - requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID) +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) if err != nil { return } - senderMembershipEventNID, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID) + senderMembershipEventNID, senderMembership, err := + d.statements.selectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) if err == sql.ErrNoRows { // The user has never been a member of that room return 0, false, nil @@ -583,15 +639,20 @@ func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID stri } // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) { +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) ([]types.EventNID, error) { if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin) + return d.statements.selectMembershipsFromRoomAndMembership( + ctx, roomNID, membershipStateJoin, + ) } - return d.statements.selectMembershipsFromRoom(roomNID) + return d.statements.selectMembershipsFromRoom(ctx, roomNID) } type transaction struct { + ctx context.Context txn *sql.Tx }