From 5e9dce1c0c66736937eeddd5c33c92700d9a65a7 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 13 May 2020 12:14:50 +0100 Subject: [PATCH] syncapi: Rename and split out tokens (#1025) * syncapi: Rename and split out tokens Previously we used the badly named `PaginationToken` which was used for both `/sync` and `/messages` requests. This quickly became confusing because named fields like `PDUPosition` meant different things depending on the token type. Instead, we now have two token types: `TopologyToken` and `StreamingToken`, both of which have fields which make more sense for their specific situations. Updated the codebase to use one or the other. `PaginationToken` still lives on as `syncToken`, an unexported type which both tokens rely on. This allows us to guarantee that the specific mappings of positions to a string remain solely under the control of the `types` package. This enables us to move high-level conceptual things like "decrement this topological token" to function calls e.g `TopologicalToken.Decrement()`. Currently broken because `/messages` seemingly used both stream and topological tokens, though I need to confirm this. * final tweaks/hacks * spurious logging * Review comments and linting --- syncapi/consumers/clientapi.go | 2 +- syncapi/consumers/eduserver.go | 6 +- syncapi/consumers/roomserver.go | 4 +- syncapi/routing/messages.go | 77 +++---- syncapi/storage/interface.go | 11 +- syncapi/storage/postgres/syncserver.go | 127 +++++------ syncapi/storage/sqlite3/syncserver.go | 154 ++++++------- syncapi/storage/storage_test.go | 64 +++--- syncapi/sync/notifier.go | 10 +- syncapi/sync/notifier_test.go | 70 ++---- syncapi/sync/request.go | 26 +-- syncapi/sync/requestpool.go | 6 +- syncapi/sync/userstream.go | 12 +- syncapi/types/types.go | 293 ++++++++++++++++--------- syncapi/types/types_test.go | 34 +-- 15 files changed, 457 insertions(+), 439 deletions(-) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f5b8c43e..b65d01a0 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0)) return nil } diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver.go index 249452af..ece999d5 100644 --- a/syncapi/consumers/eduserver.go +++ b/syncapi/consumers/eduserver.go @@ -65,9 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.notifier.OnNewEvent( nil, roomID, nil, - types.PaginationToken{ - EDUTypingPosition: types.StreamPosition(latestSyncPosition), - }, + types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)), ) }) @@ -96,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) } - s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos}) + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos)) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 987cc5df..368420a6 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -146,7 +146,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0)) return nil } @@ -164,7 +164,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0)) return nil } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 270b0ee9..72c306d4 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -38,8 +38,9 @@ type messagesReq struct { federation *gomatrixserverlib.FederationClient cfg *config.Dendrite roomID string - from *types.PaginationToken - to *types.PaginationToken + from *types.TopologyToken + to *types.TopologyToken + fromStream *types.StreamingToken wasToProvided bool limit int backwardOrdering bool @@ -66,11 +67,16 @@ func OnIncomingMessagesRequest( // Extract parameters from the request's URL. // Pagination tokens. - from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) + var fromStream *types.StreamingToken + from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from")) if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from")) + fromStream = &fs + if err2 != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()), + } } } @@ -88,10 +94,10 @@ func OnIncomingMessagesRequest( // Pagination tokens. To is optional, and its default value depends on the // direction ("b" or "f"). - var to *types.PaginationToken + var to types.TopologyToken wasToProvided := true if s := req.URL.Query().Get("to"); len(s) > 0 { - to, err = types.NewPaginationTokenFromString(s) + to, err = types.NewTopologyTokenFromString(s) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -139,8 +145,9 @@ func OnIncomingMessagesRequest( federation: federation, cfg: cfg, roomID: roomID, - from: from, - to: to, + from: &from, + to: &to, + fromStream: fromStream, wasToProvided: wasToProvided, limit: limit, backwardOrdering: backwardOrdering, @@ -178,12 +185,20 @@ func OnIncomingMessagesRequest( // remote homeserver. func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, - end *types.PaginationToken, err error, + end types.TopologyToken, err error, ) { // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, - ) + var streamEvents []types.StreamEvent + if r.fromStream != nil { + toStream := r.to.StreamToken() + streamEvents, err = r.db.GetEventsInStreamingRange( + r.ctx, r.fromStream, &toStream, r.roomID, r.limit, r.backwardOrdering, + ) + } else { + streamEvents, err = r.db.GetEventsInTopologicalRange( + r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + ) + } if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return @@ -206,7 +221,7 @@ func (r *messagesReq) retrieveEvents() ( // If we didn't get any event, we don't need to proceed any further. if len(events) == 0 { - return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } // Sort the events to ensure we send them in the right order. @@ -246,12 +261,8 @@ func (r *messagesReq) retrieveEvents() ( } // Generate pagination tokens to send to the client using the positions // retrieved previously. - start = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, startPos, startStreamPos, - ) - end = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, endPos, endStreamPos, - ) + start = types.NewTopologyToken(startPos, startStreamPos) + end = types.NewTopologyToken(endPos, endStreamPos) if r.backwardOrdering { // A stream/topological position is a cursor located between two events. @@ -259,14 +270,7 @@ func (r *messagesReq) retrieveEvents() ( // we consider a left to right chronological order), tokens need to refer // to them by the event on their left, therefore we need to decrement the // end position we send in the response if we're going backward. - end.PDUPosition-- - end.EDUTypingPosition += 1000 - } - - // The lowest token value is 1, therefore we need to manually set it to that - // value if we're below it. - if end.PDUPosition < types.StreamPosition(1) { - end.PDUPosition = types.StreamPosition(1) + end.Decrement() } return clientEvents, start, end, err @@ -317,11 +321,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // The condition in the SQL query is a strict "greater than" so // we need to check against to-1. streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) - isSetLargeEnough = (r.to.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos) } } else { streamPos := types.StreamPosition(streamEvents[0].StreamPosition) - isSetLargeEnough = (r.from.PDUPosition-1 == streamPos) + isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos) } } @@ -424,18 +428,17 @@ func (r *messagesReq) backfill(roomID string, fromEventIDs []string, limit int) func setToDefault( ctx context.Context, db storage.Database, backwardOrdering bool, roomID string, -) (to *types.PaginationToken, err error) { +) (to types.TopologyToken, err error) { if backwardOrdering { // go 1 earlier than the first event so we correctly fetch the earliest event - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to = types.NewTopologyToken(0, 0) } else { - var pos, stream types.StreamPosition - pos, stream, err = db.MaxTopologicalPosition(ctx, roomID) + var depth, stream types.StreamPosition + depth, stream, err = db.MaxTopologicalPosition(ctx, roomID) if err != nil { return } - - to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, stream) + to = types.NewTopologyToken(depth, stream) } return diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 7d637643..63af1136 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -50,13 +50,13 @@ type Database interface { // Returns an error if there was an issue with the retrieval. GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) // SyncPosition returns the latest positions for syncing. - SyncPosition(ctx context.Context) (types.PaginationToken, error) + SyncPosition(ctx context.Context) (types.StreamingToken, error) // IncrementalSync returns all the data needed in order to create an incremental // sync response for the given user. Events returned will include any client // transaction IDs associated with the given device. These transaction IDs come // from when the device sent the event via an API that included a transaction // ID. - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) // CompleteSync returns a complete /sync API response for the given user. CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or @@ -88,9 +88,10 @@ type Database interface { // RemoveTypingUser removes a typing user from the typing cache. // Returns the newly calculated sync position for typing notifications. RemoveTypingUser(userID, roomID string) types.StreamPosition - // GetEventsInRange retrieves all of the events on a given ordering using the - // given extremities and limit. - GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (depth types.StreamPosition, stream types.StreamPosition, err error) // EventsAtTopologicalPosition returns all of the events matching a given diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 1845ac38..d45bc09e 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -228,69 +228,68 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } -func (d *SyncServerDatasource) GetEventsInRange( +func (d *SyncServerDatasource) GetEventsInTopologicalRange( ctx context.Context, - from, to *types.PaginationToken, + from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool, ) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - forwardMicroLimit = from.EDUTypingPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.Depth() + forwardLimit = from.Depth() + forwardMicroLimit = from.PDUPosition() + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.Depth() + forwardLimit = to.Depth() + } - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, + ) + if err != nil { return } - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return +} +// GetEventsInStreamingRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInStreamingRange( + ctx context.Context, + from, to *types.StreamingToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { if backwardOrdering { // When using backward ordering, we want the most recent events first. if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, ); err != nil { return } } else { // When using forward ordering, we want the least recent events first. if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, ); err != nil { return } } - - return + return events, err } -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.StreamingToken, error) { return d.syncPositionTx(ctx, nil) } @@ -353,7 +352,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { +) (sp types.StreamingToken, err error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { @@ -373,8 +372,7 @@ func (d *SyncServerDatasource) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) + sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.eduCache.GetLatestSyncPosition())) return } @@ -439,7 +437,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, + since types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -448,7 +446,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), + roomID, int64(since.EDUPosition()), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -473,12 +471,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { + if fromPos.EDUPosition() != toPos.EDUPosition() { err = d.addTypingDeltaToResponse( fromPos, joinedRoomIDs, res, ) @@ -490,7 +488,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { @@ -499,9 +497,9 @@ func (d *SyncServerDatasource) IncrementalSync( var joinedRoomIDs []string var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { + if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, + ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ) } else { joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( @@ -530,7 +528,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.PaginationToken, + toPos types.StreamingToken, joinedRoomIDs []string, err error, ) { @@ -577,7 +575,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), numRecentEventsPerRoom, true, true, ) if err != nil { @@ -588,27 +586,25 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // oldest event in the room's topology. var backwardTopologyPos, backwardStreamPos types.StreamPosition backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- + if err != nil { + return } + prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos) + prevBatch.Decrement() // We don't include a device here as we don't need to send down // transaction IDs for complete syncs recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil { return } @@ -628,7 +624,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, + types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -757,14 +753,15 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) + prevBatch := types.NewTopologyToken( + backwardTopologyPos, backwardStreamPos, + ) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -775,9 +772,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 314ea2aa..212f882b 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -269,63 +269,63 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } -// GetEventsInRange retrieves all of the events on a given ordering using the +// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the // given extremities and limit. -func (d *SyncServerDatasource) GetEventsInRange( +func (d *SyncServerDatasource) GetEventsInTopologicalRange( ctx context.Context, - from, to *types.PaginationToken, + from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool, ) (events []types.StreamEvent, err error) { - // If the pagination token's type is types.PaginationTokenTypeTopology, the - // events must be retrieved from the rooms' topology table rather than the - // table contaning the syncapi server's whole stream of events. - if from.Type == types.PaginationTokenTypeTopology { - // TODO: ARGH CONFUSING - // Determine the backward and forward limit, i.e. the upper and lower - // limits to the selection in the room's topology, from the direction. - var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition - if backwardOrdering { - // Backward ordering is antichronological (latest event to oldest - // one). - backwardLimit = to.PDUPosition - forwardLimit = from.PDUPosition - forwardMicroLimit = from.EDUTypingPosition - } else { - // Forward ordering is chronological (oldest event to latest one). - backwardLimit = from.PDUPosition - forwardLimit = to.PDUPosition - } + // TODO: ARGH CONFUSING + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.Depth() + forwardLimit = from.Depth() + forwardMicroLimit = from.PDUPosition() + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.Depth() + forwardLimit = to.Depth() + } - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( - ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, + ) + if err != nil { return } - // If the pagination token's type is types.PaginationTokenTypeStream, the - // events must be retrieved from the table contaning the syncapi server's - // whole stream of events. + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return +} +// GetEventsInStreamingRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInStreamingRange( + ctx context.Context, + from, to *types.StreamingToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { if backwardOrdering { // When using backward ordering, we want the most recent events first. if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, ); err != nil { return } } else { // When using forward ordering, we want the least recent events first. if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, ); err != nil { return } @@ -334,10 +334,14 @@ func (d *SyncServerDatasource) GetEventsInRange( } // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - tok, err = d.syncPositionTx(ctx, txn) - return err + pos, err := d.syncPositionTx(ctx, txn) + if err != nil { + return err + } + tok = *pos + return nil }) return } @@ -412,30 +416,31 @@ func (d *SyncServerDatasource) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.PaginationToken, err error) { +) (*types.StreamingToken, error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { - return sp, err + return nil, err } maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) if err != nil { - return sp, err + return nil, err } if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) if err != nil { - return sp, err + return nil, err } if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = types.StreamPosition(maxEventID) - sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) - sp.Type = types.PaginationTokenTypeStream - return + sp := types.NewStreamToken( + types.StreamPosition(maxEventID), + types.StreamPosition(d.eduCache.GetLatestSyncPosition()), + ) + return &sp, nil } // addPDUDeltaToResponse adds all PDU deltas to a sync response. @@ -499,7 +504,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since types.PaginationToken, + since types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -508,7 +513,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( - roomID, int64(since.EDUTypingPosition), + roomID, int64(since.EDUPosition()), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -533,12 +538,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { + if fromPos.EDUPosition() != toPos.EDUPosition() { err = d.addTypingDeltaToResponse( fromPos, joinedRoomIDs, res, ) @@ -555,18 +560,21 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.PaginationToken, + fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { + fmt.Println("from ", fromPos, "to", toPos) nextBatchPos := fromPos.WithUpdates(toPos) res := types.NewResponse(nextBatchPos) + fmt.Println("from ", fromPos, "to", toPos, "next", nextBatchPos) var joinedRoomIDs []string var err error - if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { + fmt.Println("from", fromPos.PDUPosition(), "to", toPos.PDUPosition()) + if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { joinedRoomIDs, err = d.addPDUDeltaToResponse( - ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, + ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ) } else { joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( @@ -595,7 +603,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.PaginationToken, + toPos *types.StreamingToken, joinedRoomIDs []string, err error, ) { @@ -621,7 +629,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( return } - res = types.NewResponse(toPos) + res = types.NewResponse(*toPos) // Extract room state and recent events for all rooms the user is joined to. joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -643,7 +651,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), numRecentEventsPerRoom, true, true, ) if err != nil { @@ -655,28 +663,22 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // oldest event in the room's topology. var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) - if backwardTopologyPos-1 <= 0 { - backwardTopologyPos = types.StreamPosition(1) - } else { - backwardTopologyPos-- - backwardTopologyStreamPos += 1000 // this has to be bigger than the number of events we backfill per request - } + prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardTopologyStreamPos) + prevBatch.Decrement() // We don't include a device here as we don't need to send down // transaction IDs for complete syncs recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardTopologyStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil { return } @@ -697,7 +699,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.PaginationToken{}, toPos, joinedRoomIDs, res, + types.NewStreamToken(0, 0), *toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -860,14 +862,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) + prevBatch := types.NewTopologyToken( + backwardTopologyPos, backwardStreamPos, + ) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - - jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -878,9 +880,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, - ).String() + lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index b951efa4..f7fa1a87 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -154,10 +154,10 @@ func TestSyncResponse(t *testing.T) { { Name: "IncrementalSync penultimate", DoSync: func() (*types.Response, error) { - from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are at the penultimate event - types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + from := types.NewStreamToken( // pretend we are at the penultimate event + positions[len(positions)-2], types.StreamPosition(0), ) - return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) }, WantTimeline: events[len(events)-1:], }, @@ -166,11 +166,11 @@ func TestSyncResponse(t *testing.T) { { Name: "IncrementalSync limited", DoSync: func() (*types.Response, error) { - from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are 10 events behind - types.PaginationTokenTypeStream, positions[len(positions)-11], types.StreamPosition(0), + from := types.NewStreamToken( // pretend we are 10 events behind + positions[len(positions)-11], types.StreamPosition(0), ) // limit is set to 5 - return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) }, // want the last 5 events, NOT the last 10. WantTimeline: events[len(events)-5:], @@ -207,7 +207,7 @@ func TestSyncResponse(t *testing.T) { if err != nil { st.Fatalf("failed to do sync: %s", err) } - next := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeStream, latest.PDUPosition, latest.EDUTypingPosition) + next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition()) if res.NextBatch != next.String() { st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) } @@ -230,11 +230,11 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { if err != nil { t.Fatalf("failed to get SyncPosition: %s", err) } - from := types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + from := types.NewStreamToken( + positions[len(positions)-2], types.StreamPosition(0), ) - res, err := db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) if err != nil { t.Fatalf("failed to IncrementalSync with latest token") } @@ -249,14 +249,14 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { if prev == "" { t.Fatalf("IncrementalSync expected prev_batch token") } - prevBatchToken, err := types.NewPaginationTokenFromString(prev) + prevBatchToken, err := types.NewTopologyTokenFromString(prev) if err != nil { - t.Fatalf("failed to NewPaginationTokenFromString : %s", err) + t.Fatalf("failed to NewTopologyTokenFromString : %s", err) } // backpaginate 5 messages starting at the latest position. // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) - paginatedEvents, err := db.GetEventsInRange(ctx, prevBatchToken, to, testRoomID, 5, true) + to := types.NewTopologyToken(0, 0) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -275,10 +275,10 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) { t.Fatalf("failed to get SyncPosition: %s", err) } // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewStreamToken(0, 0) // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, &latest, to, testRoomID, 5, true) + paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -296,12 +296,12 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { if err != nil { t.Fatalf("failed to get MaxTopologicalPosition: %s", err) } - from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream) + from := types.NewTopologyToken(latest, latestStream) // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, 5, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -366,14 +366,14 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { if err != nil { t.Fatalf("failed to get EventPositionInTopology for event: %s", err) } - fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos) - fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos) + fromLatest := types.NewTopologyToken(latestPos, latestStreamPos) + fromFork := types.NewTopologyToken(topoPos, streamPos) // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) testCases := []struct { Name string - From *types.PaginationToken + From types.TopologyToken Limit int Wants []gomatrixserverlib.HeaderedEvent }{ @@ -399,7 +399,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { for _, tc := range testCases { // backpaginate messages starting at the latest position. - paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &tc.From, &to, testRoomID, tc.Limit, true) if err != nil { t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err) } @@ -446,13 +446,13 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { } // head towards the beginning of time - to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + to := types.NewTopologyToken(0, 0) // starting at `from`, backpaginate to the beginning of time, asserting as we go. chunkSize = 3 events = reversed(events) for i := 0; i < len(events); i += chunkSize { - paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, chunkSize, true) + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, from, &to, testRoomID, chunkSize, true) if err != nil { t.Fatalf("GetEventsInRange returned an error: %s", err) } @@ -506,19 +506,15 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr } } -func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.PaginationToken { +func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.TopologyToken { pos, spos, err := db.EventPositionInTopology(ctx, eventID) if err != nil { t.Fatalf("failed to get EventPositionInTopology: %s", err) } - if pos-1 <= 0 { - pos = types.StreamPosition(1) - } else { - pos = pos - 1 - spos += 1000 // this has to be bigger than the chunk limit - } - return types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, spos) + tok := types.NewTopologyToken(pos, spos) + tok.Decrement() + return &tok } func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 0d805011..b3ed5cd0 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -36,7 +36,7 @@ type Notifier struct { // Protects currPos and userStreams. streamLock *sync.Mutex // The latest sync position - currPos types.PaginationToken + currPos types.StreamingToken // A map of user_id => UserStream which can be used to wake a given user's /sync request. userStreams map[string]*UserStream // The last time we cleaned out stale entries from the userStreams map @@ -46,7 +46,7 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.PaginationToken) *Notifier { +func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), @@ -68,7 +68,7 @@ func NewNotifier(pos types.PaginationToken) *Notifier { // event type it handles, leaving other fields as 0. func (n *Notifier) OnNewEvent( ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, - posUpdate types.PaginationToken, + posUpdate types.StreamingToken, ) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. @@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { } // CurrentPosition returns the current sync position -func (n *Notifier) CurrentPosition() types.PaginationToken { +func (n *Notifier) CurrentPosition() types.StreamingToken { n.streamLock.Lock() defer n.streamLock.Unlock() @@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) { +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { stream := n.fetchUserStream(userID, false) if stream != nil { diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 350d757c..7d979fcc 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -33,11 +33,11 @@ var ( randomMessageEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent - syncPositionVeryOld types.PaginationToken - syncPositionBefore types.PaginationToken - syncPositionAfter types.PaginationToken - syncPositionNewEDU types.PaginationToken - syncPositionAfter2 types.PaginationToken + syncPositionVeryOld = types.NewStreamToken(5, 0) + syncPositionBefore = types.NewStreamToken(11, 0) + syncPositionAfter = types.NewStreamToken(12, 0) + syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1) + syncPositionAfter2 = types.NewStreamToken(13, 0) ) var ( @@ -47,26 +47,6 @@ var ( ) func init() { - baseSyncPos := types.PaginationToken{ - PDUPosition: 0, - EDUTypingPosition: 0, - } - - syncPositionVeryOld = baseSyncPos - syncPositionVeryOld.PDUPosition = 5 - - syncPositionBefore = baseSyncPos - syncPositionBefore.PDUPosition = 11 - - syncPositionAfter = baseSyncPos - syncPositionAfter.PDUPosition = 12 - - syncPositionNewEDU = syncPositionAfter - syncPositionNewEDU.EDUTypingPosition = 1 - - syncPositionAfter2 = baseSyncPos - syncPositionAfter2.PDUPosition = 13 - var err error err = json.Unmarshal([]byte(`{ "_room_version": "1", @@ -118,6 +98,12 @@ func init() { } } +func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { + if got.String() != want.String() { + t.Fatalf("mustEqualPositions got %s want %s", got.String(), want.String()) + } +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { n := NewNotifier(syncPositionBefore) @@ -125,9 +111,7 @@ func TestImmediateNotification(t *testing.T) { if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } - if pos != syncPositionBefore { - t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos) - } + mustEqualPositions(t, pos, syncPositionBefore) } // Test that new events to a joined room unblocks the request. @@ -144,9 +128,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() @@ -172,9 +154,7 @@ func TestNewInviteEventForUser(t *testing.T) { if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() }() @@ -200,9 +180,7 @@ func TestEDUWakeup(t *testing.T) { if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } - if pos != syncPositionNewEDU { - t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos) - } + mustEqualPositions(t, pos, syncPositionNewEDU) wg.Done() }() @@ -228,9 +206,7 @@ func TestMultipleRequestWakeup(t *testing.T) { if err != nil { t.Errorf("TestMultipleRequestWakeup error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) wg.Done() } go poll() @@ -268,9 +244,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos) - } + mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() bobStream := lockedFetchUserStream(n, bob) @@ -287,9 +261,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } - if pos != syncPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos) - } + mustEqualPositions(t, pos, syncPositionAfter2) aliceWG.Done() }() @@ -312,13 +284,13 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.PaginationToken{}, fmt.Errorf( + return types.StreamingToken{}, fmt.Errorf( "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): @@ -344,7 +316,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream { return n.fetchUserStream(userID, true) } -func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest { +func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { return syncRequest{ device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index f2e199d2..66663cf0 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -36,7 +36,7 @@ type syncRequest struct { device authtypes.Device limit int timeout time.Duration - since *types.PaginationToken // nil means that no since token was supplied + since *types.StreamingToken // nil means that no since token was supplied wantFullState bool log *log.Entry } @@ -45,9 +45,14 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" - since, err := getPaginationToken(req.URL.Query().Get("since")) - if err != nil { - return nil, err + var since *types.StreamingToken + sinceStr := req.URL.Query().Get("since") + if sinceStr != "" { + tok, err := types.NewStreamTokenFromString(sinceStr) + if err != nil { + return nil, err + } + since = &tok } // TODO: Additional query params: set_presence, filter return &syncRequest{ @@ -71,16 +76,3 @@ func getTimeout(timeoutMS string) time.Duration { } return time.Duration(i) * time.Millisecond } - -// getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// types.PaginationToken. If the string is empty then (nil, nil) is returned. -// There are two forms of tokens: The full length form containing all PDU and EDU -// positions separated by "_", and the short form containing only the PDU -// position. Short form can be used for, e.g., `prev_batch` tokens. -func getPaginationToken(since string) (*types.PaginationToken, error) { - if since == "" { - return nil, nil - } - - return types.NewPaginationTokenFromString(since) -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 69efd8aa..126e76f5 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -132,7 +132,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) { +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) @@ -145,7 +145,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Pagin } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead - res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) return } @@ -187,7 +187,7 @@ func (rp *RequestPool) appendAccountData( // Sync is not initial, get all account data since the latest sync dataTypes, err := rp.db.GetAccountDataInRange( req.ctx, userID, - types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos), + types.StreamPosition(req.since.PDUPosition()), types.StreamPosition(currentPos), accountDataFilter, ) if err != nil { diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index 88867005..b2eafa3d 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -34,7 +34,7 @@ type UserStream struct { // Closed when there is an update. signalChannel chan struct{} // The last sync position that there may have been an update for the user - pos types.PaginationToken + pos types.StreamingToken // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting @@ -50,7 +50,7 @@ type UserStreamListener struct { } // NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { +func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { return &UserStream{ UserID: userID, timeOfLastChannel: time.Now(), @@ -83,7 +83,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.PaginationToken) { +func (s *UserStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -116,9 +116,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { return s.timeOfLastChannel } -// GetStreamPosition returns last sync position which the UserStream was +// GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { +func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +130,7 @@ func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} { +func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() diff --git a/syncapi/types/types.go b/syncapi/types/types.go index c04fe521..c1b6d7dd 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -27,19 +27,19 @@ import ( ) var ( - // ErrInvalidPaginationTokenType is returned when an attempt at creating a - // new instance of PaginationToken with an invalid type (i.e. neither "s" + // ErrInvalidSyncTokenType is returned when an attempt at creating a + // new instance of SyncToken with an invalid type (i.e. neither "s" // nor "t"). - ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") - // ErrInvalidPaginationTokenLen is returned when the pagination token is an + ErrInvalidSyncTokenType = fmt.Errorf("Sync token has an unknown prefix (should be either s or t)") + // ErrInvalidSyncTokenLen is returned when the pagination token is an // invalid length - ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length") + ErrInvalidSyncTokenLen = fmt.Errorf("Sync token has an invalid length") ) // StreamPosition represents the offset in the sync stream a client is at. type StreamPosition int64 -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. +// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { gomatrixserverlib.HeaderedEvent StreamPosition StreamPosition @@ -47,118 +47,201 @@ type StreamEvent struct { ExcludeFromSync bool } -// PaginationTokenType represents the type of a pagination token. +// SyncTokenType represents the type of a sync token. // It can be either "s" (representing a position in the whole stream of events) // or "t" (representing a position in a room's topology/depth). -type PaginationTokenType string +type SyncTokenType string const ( - // PaginationTokenTypeStream represents a position in the server's whole + // SyncTokenTypeStream represents a position in the server's whole // stream of events - PaginationTokenTypeStream PaginationTokenType = "s" - // PaginationTokenTypeTopology represents a position in a room's topology. - PaginationTokenTypeTopology PaginationTokenType = "t" + SyncTokenTypeStream SyncTokenType = "s" + // SyncTokenTypeTopology represents a position in a room's topology. + SyncTokenTypeTopology SyncTokenType = "t" ) -// PaginationToken represents a pagination token, used for interactions with -// /sync or /messages, for example. -type PaginationToken struct { - //Position StreamPosition - Type PaginationTokenType - // For /sync, this is the PDU position. For /messages, this is the topological position (depth). - // TODO: Given how different the positions are depending on the token type, they should probably be renamed - // or use different structs altogether. - PDUPosition StreamPosition - // For /sync, this is the EDU position. For /messages, this is the stream (PDU) position. - // TODO: Given how different the positions are depending on the token type, they should probably be renamed - // or use different structs altogether. - EDUTypingPosition StreamPosition +type StreamingToken struct { + syncToken } -// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" -// represents the type of a pagination token and "yyyy..." the token itself, and -// parses it in order to create a new instance of PaginationToken. Returns an -// error if the token couldn't be parsed into an int64, or if the token type -// isn't a known type (returns ErrInvalidPaginationTokenType in the latter -// case). -func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) { - if len(s) == 0 { - return nil, ErrInvalidPaginationTokenLen - } +func (t *StreamingToken) PDUPosition() StreamPosition { + return t.Positions[0] +} +func (t *StreamingToken) EDUPosition() StreamPosition { + return t.Positions[1] +} - token = new(PaginationToken) - var positions []string - - switch t := PaginationTokenType(s[:1]); t { - case PaginationTokenTypeStream, PaginationTokenTypeTopology: - token.Type = t - positions = strings.Split(s[1:], "_") - default: - token.Type = PaginationTokenTypeStream - positions = strings.Split(s, "_") - } - - // Try to get the PDU position. - if len(positions) >= 1 { - if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil { - return nil, err - } else if pduPos < 0 { - return nil, errors.New("negative PDU position not allowed") - } else { - token.PDUPosition = StreamPosition(pduPos) +// IsAfter returns true if ANY position in this token is greater than `other`. +func (t *StreamingToken) IsAfter(other StreamingToken) bool { + for i := range other.Positions { + if t.Positions[i] > other.Positions[i] { + return true } } + return false +} - // Try to get the typing position. - if len(positions) >= 2 { - if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil { - return nil, err - } else if typPos < 0 { - return nil, errors.New("negative EDU typing position not allowed") - } else { - token.EDUTypingPosition = StreamPosition(typPos) +// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. +// If the latter StreamingToken contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. +func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { + ret.Type = t.Type + ret.Positions = make([]StreamPosition, len(t.Positions)) + for i := range t.Positions { + ret.Positions[i] = t.Positions[i] + if other.Positions[i] == 0 { + continue } - } - - return -} - -// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a -// StreamPosition and returns an instance of PaginationToken. -func NewPaginationTokenFromTypeAndPosition( - t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition, -) (p *PaginationToken) { - return &PaginationToken{ - Type: t, - PDUPosition: pdupos, - EDUTypingPosition: typpos, - } -} - -// String translates a PaginationToken to a string of the "xyyyy..." (see -// NewPaginationToken to know what it represents). -func (p *PaginationToken) String() string { - return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition) -} - -// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken. -// If the latter PaginationToken contains a field that is not 0, it is considered an update, -// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called. -func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken { - ret := *pt - if other.PDUPosition != 0 { - ret.PDUPosition = other.PDUPosition - } - if other.EDUTypingPosition != 0 { - ret.EDUTypingPosition = other.EDUTypingPosition + ret.Positions[i] = other.Positions[i] } return ret } -// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken. -func (sp *PaginationToken) IsAfter(other PaginationToken) bool { - return sp.PDUPosition > other.PDUPosition || - sp.EDUTypingPosition > other.EDUTypingPosition +type TopologyToken struct { + syncToken +} + +func (t *TopologyToken) Depth() StreamPosition { + return t.Positions[0] +} +func (t *TopologyToken) PDUPosition() StreamPosition { + return t.Positions[1] +} +func (t *TopologyToken) StreamToken() StreamingToken { + return NewStreamToken(t.PDUPosition(), 0) +} +func (t *TopologyToken) String() string { + return t.syncToken.String() +} + +// Decrement the topology token to one event earlier. +func (t *TopologyToken) Decrement() { + depth := t.Positions[0] + pduPos := t.Positions[1] + if depth-1 <= 0 { + depth = 1 + } else { + depth-- + pduPos += 1000 + } + // The lowest token value is 1, therefore we need to manually set it to that + // value if we're below it. + if depth < 1 { + depth = 1 + } + t.Positions = []StreamPosition{ + depth, pduPos, + } +} + +// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x" +// represents the type of a pagination token and "yyyy..." the token itself, and +// parses it in order to create a new instance of SyncToken. Returns an +// error if the token couldn't be parsed into an int64, or if the token type +// isn't a known type (returns ErrInvalidSyncTokenType in the latter +// case). +func newSyncTokenFromString(s string) (token *syncToken, err error) { + if len(s) == 0 { + return nil, ErrInvalidSyncTokenLen + } + + token = new(syncToken) + var positions []string + + switch t := SyncTokenType(s[:1]); t { + case SyncTokenTypeStream, SyncTokenTypeTopology: + token.Type = t + positions = strings.Split(s[1:], "_") + default: + return nil, ErrInvalidSyncTokenType + } + + for _, pos := range positions { + if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { + return nil, err + } else if posInt < 0 { + return nil, errors.New("negative position not allowed") + } else { + token.Positions = append(token.Positions, StreamPosition(posInt)) + } + } + return +} + +// NewTopologyToken creates a new sync token for /messages +func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken { + if depth < 0 { + depth = 1 + } + return TopologyToken{ + syncToken: syncToken{ + Type: SyncTokenTypeTopology, + Positions: []StreamPosition{depth, streamPos}, + }, + } +} +func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeTopology { + err = fmt.Errorf("token %s is not a topology token", tok) + return + } + if len(t.Positions) != 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions)) + return + } + return TopologyToken{ + syncToken: *t, + }, nil +} + +// NewStreamToken creates a new sync token for /sync +func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken { + return StreamingToken{ + syncToken: syncToken{ + Type: SyncTokenTypeStream, + Positions: []StreamPosition{pduPos, eduPos}, + }, + } +} +func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { + t, err := newSyncTokenFromString(tok) + if err != nil { + return + } + if t.Type != SyncTokenTypeStream { + err = fmt.Errorf("token %s is not a streaming token", tok) + return + } + if len(t.Positions) != 2 { + err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions)) + return + } + return StreamingToken{ + syncToken: *t, + }, nil +} + +// syncToken represents a syncapi token, used for interactions with +// /sync or /messages, for example. +type syncToken struct { + Type SyncTokenType + // A list of stream positions, their meanings vary depending on the token type. + Positions []StreamPosition +} + +// String translates a SyncToken to a string of the "xyyyy..." (see +// NewSyncToken to know what it represents). +func (p *syncToken) String() string { + posStr := make([]string, len(p.Positions)) + for i := range p.Positions { + posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10) + } + + return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_")) } // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -185,7 +268,7 @@ type Response struct { } // NewResponse creates an empty response with initialised maps. -func NewResponse(token PaginationToken) *Response { +func NewResponse(token StreamingToken) *Response { res := Response{ NextBatch: token.String(), } @@ -202,14 +285,6 @@ func NewResponse(token PaginationToken) *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) - // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume - // we'll always return a stream token. - res.NextBatch = NewPaginationTokenFromTypeAndPosition( - PaginationTokenTypeStream, - StreamPosition(token.PDUPosition), - StreamPosition(token.EDUTypingPosition), - ).String() - return &res } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index f4c84e0d..1e27a8e3 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,26 +2,11 @@ package types import "testing" -func TestNewPaginationTokenFromString(t *testing.T) { - shouldPass := map[string]PaginationToken{ - "2": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 2, - }, - "s4": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 4, - }, - "s3_1": PaginationToken{ - Type: PaginationTokenTypeStream, - PDUPosition: 3, - EDUTypingPosition: 1, - }, - "t3_1_4": PaginationToken{ - Type: PaginationTokenTypeTopology, - PDUPosition: 3, - EDUTypingPosition: 1, - }, +func TestNewSyncTokenFromString(t *testing.T) { + shouldPass := map[string]syncToken{ + "s4_0": NewStreamToken(4, 0).syncToken, + "s3_1": NewStreamToken(3, 1).syncToken, + "t3_1": NewTopologyToken(3, 1).syncToken, } shouldFail := []string{ @@ -32,20 +17,21 @@ func TestNewPaginationTokenFromString(t *testing.T) { "b", "b-1", "-4", + "2", } for test, expected := range shouldPass { - result, err := NewPaginationTokenFromString(test) + result, err := newSyncTokenFromString(test) if err != nil { t.Error(err) } - if *result != expected { - t.Errorf("expected %v but got %v", expected.String(), result.String()) + if result.String() != expected.String() { + t.Errorf("%s expected %v but got %v", test, expected.String(), result.String()) } } for _, test := range shouldFail { - if _, err := NewPaginationTokenFromString(test); err == nil { + if _, err := newSyncTokenFromString(test); err == nil { t.Errorf("input '%v' should have errored but didn't", test) } }