From ebbfc125920beb321713e28a2a137d768406fa15 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 30 Apr 2020 17:15:29 +0100 Subject: [PATCH] Add tests for the storage interface (#995) * Move docs to interface * Add tests around syncing * Add topology token test * Linting --- syncapi/storage/interface.go | 59 ++++- syncapi/storage/postgres/syncserver.go | 54 ----- syncapi/storage/storage_test.go | 313 +++++++++++++++++++++++++ 3 files changed, 371 insertions(+), 55 deletions(-) create mode 100644 syncapi/storage/storage_test.go diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a3efd8d5..bd9504db 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -28,26 +28,83 @@ import ( type Database interface { common.PartitionStorer + // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + // Events lookups a list of event by their event ID. + // Returns a list of events matching the requested IDs found in the database. + // If an event is not found in the database then it will be omitted from the list. + // Returns an error if there was a problem talking with the database. + // Does not include any transaction IDs in the returned events. Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) - WriteEvent(context.Context, *gomatrixserverlib.HeaderedEvent, []gomatrixserverlib.HeaderedEvent, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error) + // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races + // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. + // Returns an error if there was a problem inserting this event. + WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, + addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) + // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetStateEventsForRoom fetches the state events for a given room. + // Returns an empty slice if no state events could be found for this room. + // Returns an error if there was an issue with the retrieval. GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) + // SyncPosition returns the latest positions for syncing. SyncPosition(ctx context.Context) (types.PaginationToken, error) + // IncrementalSync returns all the data needed in order to create an incremental + // sync response for the given user. Events returned will include any client + // transaction IDs associated with the given device. These transaction IDs come + // from when the device sent the event via an API that included a transaction + // ID. IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + // CompleteSync returns a complete /sync API response for the given user. CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) + // GetAccountDataInRange returns all account data for a given user inserted or + // updated between two given positions + // Returns a map following the format data[roomID] = []dataTypes + // If no data is retrieved, returns an empty map + // If there was an issue with the retrieval, returns an error GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) + // UpsertAccountData keeps track of new or updated account data, by saving the type + // of the new/updated data, and the user ID and room ID the data is related to (empty) + // room ID means the data isn't specific to any room) + // If no data with the given type, user ID and room ID exists in the database, + // creates a new row, else update the existing one + // Returns an error if there was an issue with the upsert UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) + // AddInviteEvent stores a new invite event for a user. + // If the invite was successfully stored this returns the stream ID it was stored at. + // Returns an error if there was a problem communicating with the database. AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) + // RetireInviteEvent removes an old invite event from the database. + // Returns an error if there was a problem communicating with the database. RetireInviteEvent(ctx context.Context, inviteEventID string) error + // SetTypingTimeoutCallback sets a callback function that is called right after + // a user is removed from the typing user list due to timeout. SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) + // AddTypingUser adds a typing user to the typing cache. + // Returns the newly calculated sync position for typing notifications. AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition + // RemoveTypingUser removes a typing user from the typing cache. + // Returns the newly calculated sync position for typing notifications. RemoveTypingUser(userID, roomID string) types.StreamPosition + // GetEventsInRange retrieves all of the events on a given ordering using the + // given extremities and limit. GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // EventPositionInTopology returns the depth of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) + // EventsAtTopologicalPosition returns all of the events matching a given + // position in the topology of a given room. EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error) + // BackwardExtremitiesForRoom returns the event IDs of all of the backward + // extremities we know of for a given room. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) + // MaxTopologicalPosition returns the highest topological position for a given room. MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) + // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and + // matches the streamevent.transactionID device then the transaction ID gets + // added to the unsigned section of the output event. StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent + // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 9d61ccfc..ef970739 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -93,16 +93,10 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er return &d, nil } -// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { return d.roomstate.selectJoinedUsers(ctx) } -// Events lookups a list of event by their event ID. -// Returns a list of events matching the requested IDs found in the database. -// If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database. -// Does not include any transaction IDs in the returned events. func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) if err != nil { @@ -148,9 +142,6 @@ func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, tx return nil } -// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races -// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. -// Returns an error if there was a problem inserting this event. func (d *SyncServerDatasource) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, @@ -221,18 +212,12 @@ func (d *SyncServerDatasource) updateRoomState( return nil } -// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key -// If no event could be found, returns nil -// If there was an issue during the retrieval, returns an error func (d *SyncServerDatasource) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) } -// GetStateEventsForRoom fetches the state events for a given room. -// Returns an empty slice if no state events could be found for this room. -// Returns an error if there was an issue with the retrieval. func (d *SyncServerDatasource) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { @@ -243,8 +228,6 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } -// GetEventsInRange retrieves all of the events on a given ordering using the -// given extremities and limit. func (d *SyncServerDatasource) GetEventsInRange( ctx context.Context, from, to *types.PaginationToken, @@ -306,29 +289,22 @@ func (d *SyncServerDatasource) GetEventsInRange( return } -// SyncPosition returns the latest positions for syncing. func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { return d.syncPositionTx(ctx, nil) } -// BackwardExtremitiesForRoom returns the event IDs of all of the backward -// extremities we know of for a given room. func (d *SyncServerDatasource) BackwardExtremitiesForRoom( ctx context.Context, roomID string, ) (backwardExtremities []string, err error) { return d.backwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) } -// MaxTopologicalPosition returns the highest topological position for a given -// room. func (d *SyncServerDatasource) MaxTopologicalPosition( ctx context.Context, roomID string, ) (types.StreamPosition, error) { return d.topology.selectMaxPositionInTopology(ctx, roomID) } -// EventsAtTopologicalPosition returns all of the events matching a given -// position in the topology of a given room. func (d *SyncServerDatasource) EventsAtTopologicalPosition( ctx context.Context, roomID string, pos types.StreamPosition, ) ([]types.StreamEvent, error) { @@ -346,7 +322,6 @@ func (d *SyncServerDatasource) EventPositionInTopology( return d.topology.selectPositionInTopology(ctx, eventID) } -// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { return d.syncStreamPositionTx(ctx, nil) } @@ -511,11 +486,6 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( return } -// IncrementalSync returns all the data needed in order to create an incremental -// sync response for the given user. Events returned will include any client -// transaction IDs associated with the given device. These transaction IDs come -// from when the device sent the event via an API that included a transaction -// ID. func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, @@ -645,7 +615,6 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( return res, toPos, joinedRoomIDs, err } -// CompleteSync returns a complete /sync API response for the given user. func (d *SyncServerDatasource) CompleteSync( ctx context.Context, userID string, numRecentEventsPerRoom int, ) (*types.Response, error) { @@ -677,11 +646,6 @@ var txReadOnlySnapshot = sql.TxOptions{ ReadOnly: true, } -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error func (d *SyncServerDatasource) GetAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrixserverlib.EventFilter, @@ -689,29 +653,18 @@ func (d *SyncServerDatasource) GetAccountDataInRange( return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) } -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (types.StreamPosition, error) { return d.accountData.insertAccountData(ctx, userID, roomID, dataType) } -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (types.StreamPosition, error) { return d.invites.insertInviteEvent(ctx, inviteEvent) } -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { @@ -725,16 +678,12 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback d.eduCache.SetTimeoutCallback(fn) } -// AddTypingUser adds a typing user to the typing cache. -// Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) AddTypingUser( userID, roomID string, expireTime *time.Time, ) types.StreamPosition { return types.StreamPosition(d.eduCache.AddTypingUser(userID, roomID, expireTime)) } -// RemoveTypingUser removes a typing user from the typing cache. -// Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) RemoveTypingUser( userID, roomID string, ) types.StreamPosition { @@ -1073,9 +1022,6 @@ func (d *SyncServerDatasource) currentStateStreamEventsForRoom( return s, nil } -// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and -// matches the streamevent.transactionID device then the transaction ID gets -// added to the unsigned section of the output event. func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { out := make([]gomatrixserverlib.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go new file mode 100644 index 00000000..e591e7ed --- /dev/null +++ b/syncapi/storage/storage_test.go @@ -0,0 +1,313 @@ +package storage_test + +import ( + "context" + "crypto/ed25519" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + ctx = context.Background() + emptyStateKey = "" + testOrigin = gomatrixserverlib.ServerName("hollow.knight") + testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin) + testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin) + testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin) + testUserDeviceA = authtypes.Device{ + UserID: testUserIDA, + ID: "device_id_A", + DisplayName: "Device A", + } + testRoomVersion = gomatrixserverlib.RoomVersionV4 + testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test") + testPrivateKey = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + }) +) + +func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) gomatrixserverlib.HeaderedEvent { + b.RoomID = roomID + if prevs != nil { + prevIDs := make([]string, len(prevs)) + for i := range prevs { + prevIDs[i] = prevs[i].EventID() + } + b.PrevEvents = prevIDs + } + e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion) + if err != nil { + t.Fatalf("failed to build event: %s", err) + } + return e.Headered(testRoomVersion) +} + +func MustCreateDatabase(t *testing.T) storage.Database { + db, err := sqlite3.NewSyncServerDatasource("file::memory:") + if err != nil { + t.Fatalf("NewSyncServerDatasource returned %s", err) + } + return db +} + +// Create a list of events which include a create event, join event and some messages. +func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserverlib.HeaderedEvent, state []gomatrixserverlib.HeaderedEvent) { + var events []gomatrixserverlib.HeaderedEvent + events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)), + Type: "m.room.create", + StateKey: &emptyStateKey, + Sender: userA, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &userA, + Sender: userA, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + for i := 0; i < 10; i++ { + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), + Type: "m.room.message", + Sender: userA, + Depth: int64(len(events) + 1), + })) + } + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &userB, + Sender: userB, + Depth: int64(len(events) + 1), + })) + state = append(state, events[len(events)-1]) + for i := 0; i < 10; i++ { + events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)), + Type: "m.room.message", + Sender: userB, + Depth: int64(len(events) + 1), + })) + } + + return events, state +} + +func MustWriteEvents(t *testing.T, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { + for _, ev := range events { + var addStateEvents []gomatrixserverlib.HeaderedEvent + var addStateEventIDs []string + var removeStateEventIDs []string + if ev.StateKey() != nil { + addStateEvents = append(addStateEvents, ev) + addStateEventIDs = append(addStateEventIDs, ev.EventID()) + } + pos, err := db.WriteEvent(ctx, &ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false) + if err != nil { + t.Fatalf("WriteEvent failed: %s", err) + } + positions = append(positions, pos) + } + return +} + +func TestWriteEvents(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) +} + +// These tests assert basic functionality of the IncrementalSync and CompleteSync functions. +func TestSyncResponse(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + positions := MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + + testCases := []struct { + Name string + DoSync func() (*types.Response, error) + WantTimeline []gomatrixserverlib.HeaderedEvent + WantState []gomatrixserverlib.HeaderedEvent + }{ + // The purpose of this test is to make sure that incremental syncs are including up to the latest events. + // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It makes sure the response includes the final event. + { + Name: "IncrementalSync penultimate", + DoSync: func() (*types.Response, error) { + from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are at the penultimate event + types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + ) + return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + }, + WantTimeline: events[len(events)-1:], + }, + // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the + // number of returned events. This is critical for big rooms hence the test here. + { + Name: "IncrementalSync limited", + DoSync: func() (*types.Response, error) { + from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are 10 events behind + types.PaginationTokenTypeStream, positions[len(positions)-11], types.StreamPosition(0), + ) + // limit is set to 5 + return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + }, + // want the last 5 events, NOT the last 10. + WantTimeline: events[len(events)-5:], + }, + // The purpose of this test is to check that CompleteSync returns all the current state as well as + // honouring the `numRecentEventsPerRoom` value + { + Name: "CompleteSync limited", + DoSync: func() (*types.Response, error) { + // limit set to 5 + return db.CompleteSync(ctx, testUserIDA, 5) + }, + // want the last 5 events, NOT the last 10. + WantTimeline: events[len(events)-5:], + // want all state for the room + WantState: state, + }, + // The purpose of this test is to check that CompleteSync can return everything with a high enough + // `numRecentEventsPerRoom`. + { + Name: "CompleteSync", + DoSync: func() (*types.Response, error) { + return db.CompleteSync(ctx, testUserIDA, len(events)+1) + }, + WantTimeline: events, + // We want no state at all as that field in /sync is the delta between the token (beginning of time) + // and the START of the timeline. + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(st *testing.T) { + res, err := tc.DoSync() + if err != nil { + st.Fatalf("failed to do sync: %s", err) + } + next := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeStream, latest.PDUPosition, latest.EDUTypingPosition) + if res.NextBatch != next.String() { + st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) + } + roomRes, ok := res.Rooms.Join[testRoomID] + if !ok { + st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) + } + assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState) + assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline) + }) + } +} + +// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. +func TestGetEventsInRangeWithStreamToken(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + // head towards the beginning of time + to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInRange(ctx, &latest, to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) +} + +// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token +func TestGetEventsInRangeWithTopologyToken(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + MustWriteEvents(t, db, events) + latest, err := db.MaxTopologicalPosition(ctx, testRoomID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, 0) + // head towards the beginning of time + to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) +} + +func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { + if len(gots) != len(wants) { + t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) + } + for i := range gots { + g := gots[i] + w := wants[i] + if g.EventID != w.EventID() { + t.Errorf("%s event[%d] event_id mismatch: got %s want %s", msg, i, g.EventID, w.EventID()) + } + if g.Sender != w.Sender() { + t.Errorf("%s event[%d] sender mismatch: got %s want %s", msg, i, g.Sender, w.Sender()) + } + if checkRoomID && g.RoomID != w.RoomID() { + t.Errorf("%s event[%d] room_id mismatch: got %s want %s", msg, i, g.RoomID, w.RoomID()) + } + if g.Type != w.Type() { + t.Errorf("%s event[%d] event type mismatch: got %s want %s", msg, i, g.Type, w.Type()) + } + if g.OriginServerTS != w.OriginServerTS() { + t.Errorf("%s event[%d] origin_server_ts mismatch: got %v want %v", msg, i, g.OriginServerTS, w.OriginServerTS()) + } + if string(g.Content) != string(w.Content()) { + t.Errorf("%s event[%d] content mismatch: got %s want %s", msg, i, string(g.Content), string(w.Content())) + } + if string(g.Unsigned) != string(w.Unsigned()) { + t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned())) + } + if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) { + t.Fatalf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey()) + } + if g.StateKey != nil { + if !w.StateKeyEquals(*g.StateKey) { + t.Errorf("%s event[%d] state_key mismatch: got %s want %s", msg, i, *g.StateKey, *w.StateKey()) + } + } + } +} + +func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { + out := make([]gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out +}