Don't retrieve same state events over and over again (#1737)

main
Neil Alexander 2021-01-26 09:12:17 +00:00 committed by GitHub
parent ef9d5ad4fe
commit 64fb6de6d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 27 deletions

View File

@ -33,19 +33,21 @@ import (
type StateResolution struct {
db storage.Database
roomInfo types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event
}
func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,
events: make(map[types.EventNID]*gomatrixserverlib.Event),
}
}
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) LoadStateAtSnapshot(
func (v *StateResolution) LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
@ -83,7 +85,7 @@ func (v StateResolution) LoadStateAtSnapshot(
}
// LoadStateAtEvent loads the full state of a room before a particular event.
func (v StateResolution) LoadStateAtEvent(
func (v *StateResolution) LoadStateAtEvent(
ctx context.Context, eventID string,
) ([]types.StateEntry, error) {
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
@ -105,7 +107,7 @@ func (v StateResolution) LoadStateAtEvent(
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list. At this point it is
// possible to run into duplicate (type, state key) tuples.
func (v StateResolution) LoadCombinedStateAfterEvents(
func (v *StateResolution) LoadCombinedStateAfterEvents(
ctx context.Context, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) {
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
@ -177,7 +179,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents(
}
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
func (v StateResolution) DifferenceBetweeenStateSnapshots(
func (v *StateResolution) DifferenceBetweeenStateSnapshots(
ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error) {
if oldStateNID == newStateNID {
@ -236,7 +238,7 @@ func (v StateResolution) DifferenceBetweeenStateSnapshots(
// If there is no entry for a given event type and state key pair then it will be discarded.
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) LoadStateAtSnapshotForStringTuples(
func (v *StateResolution) LoadStateAtSnapshotForStringTuples(
ctx context.Context,
stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
@ -251,7 +253,7 @@ func (v StateResolution) LoadStateAtSnapshotForStringTuples(
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
// Returns an error if there was a problem talking to the database.
func (v StateResolution) stringTuplesToNumericTuples(
func (v *StateResolution) stringTuplesToNumericTuples(
ctx context.Context,
stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) {
@ -292,7 +294,7 @@ func (v StateResolution) stringTuplesToNumericTuples(
// If there is no entry for a given event type and state key pair then it will be discarded.
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) loadStateAtSnapshotForNumericTuples(
func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
ctx context.Context,
stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple,
@ -340,7 +342,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples(
// If there is no entry for a given event type and state key pair then it will be discarded.
// This is typically the state before an event.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func (v StateResolution) LoadStateAfterEventsForStringTuples(
func (v *StateResolution) LoadStateAfterEventsForStringTuples(
ctx context.Context,
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
@ -352,7 +354,7 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples(
return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples)
}
func (v StateResolution) loadStateAfterEventsForNumericTuples(
func (v *StateResolution) loadStateAfterEventsForNumericTuples(
ctx context.Context,
prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple,
@ -520,7 +522,7 @@ func init() {
// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event.
// Stores the snapshot of the state in the database.
// Returns a numeric ID for the snapshot of the state before the event.
func (v StateResolution) CalculateAndStoreStateBeforeEvent(
func (v *StateResolution) CalculateAndStoreStateBeforeEvent(
ctx context.Context,
event *gomatrixserverlib.Event,
isRejected bool,
@ -537,7 +539,7 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent(
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
func (v StateResolution) CalculateAndStoreStateAfterEvents(
func (v *StateResolution) CalculateAndStoreStateAfterEvents(
ctx context.Context,
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) {
@ -607,7 +609,7 @@ const maxStateBlockNIDs = 64
// calculateAndStoreStateAfterManyEvents finds the room state after the given events.
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
// Stores the resulting state and returns a numeric ID for the snapshot.
func (v StateResolution) calculateAndStoreStateAfterManyEvents(
func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
ctx context.Context,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
@ -627,7 +629,7 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents(
return metrics.stop(v.db.AddState(ctx, roomNID, nil, state))
}
func (v StateResolution) calculateStateAfterManyEvents(
func (v *StateResolution) calculateStateAfterManyEvents(
ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
@ -754,7 +756,7 @@ func ResolveConflictsAdhoc(
return resolved, nil
}
func (v StateResolution) resolveConflicts(
func (v *StateResolution) resolveConflicts(
ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
@ -778,7 +780,7 @@ func (v StateResolution) resolveConflicts(
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
// The returned list is sorted by state key tuple.
// Returns an error if there was a problem talking to the database.
func (v StateResolution) resolveConflictsV1(
func (v *StateResolution) resolveConflictsV1(
ctx context.Context,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
@ -842,7 +844,7 @@ func (v StateResolution) resolveConflictsV1(
// The returned list is sorted by state key tuple.
// Returns an error if there was a problem talking to the database.
// nolint:gocyclo
func (v StateResolution) resolveConflictsV2(
func (v *StateResolution) resolveConflictsV2(
ctx context.Context,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
@ -959,7 +961,7 @@ func (v StateResolution) resolveConflictsV2(
}
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func (v StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{
@ -1004,26 +1006,33 @@ func (v StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.Ev
// Returns a list of state events in no particular order and a map from string event ID back to state entry.
// The map can be used to recover which numeric state entry a given event is for.
// Returns an error if there was a problem talking to the database.
func (v StateResolution) loadStateEvents(
func (v *StateResolution) loadStateEvents(
ctx context.Context, entries []types.StateEntry,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventNIDs := make([]types.EventNID, len(entries))
for i := range entries {
eventNIDs[i] = entries[i].EventNID
result := make([]*gomatrixserverlib.Event, 0, len(entries))
eventEntries := make([]types.StateEntry, 0, len(entries))
eventNIDs := make([]types.EventNID, 0, len(entries))
for _, entry := range entries {
if e, ok := v.events[entry.EventNID]; ok {
result = append(result, e)
} else {
eventEntries = append(eventEntries, entry)
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
events, err := v.db.Events(ctx, eventNIDs)
if err != nil {
return nil, nil, err
}
eventIDMap := map[string]types.StateEntry{}
result := make([]*gomatrixserverlib.Event, len(entries))
for i := range entries {
event, ok := eventMap(events).lookup(entries[i].EventNID)
for _, entry := range eventEntries {
event, ok := eventMap(events).lookup(entry.EventNID)
if !ok {
panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID))
panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entry.EventNID))
}
result[i] = event.Event
eventIDMap[event.Event.EventID()] = entries[i]
result = append(result, event.Event)
eventIDMap[event.Event.EventID()] = entry
v.events[entry.EventNID] = event.Event
}
return result, eventIDMap, nil
}