Per-user-per-device sync streams (#1068)
* Per-user-per-device sync streams * Tweaks * Tweaks * Pass full device into CompleteSync * Set user IDs and device IDs properly in tests * Add new test, fix TestNewEventAndWasPreviouslyJoinedToRoom * nolint a function that is not used yet * Add test for waking up single device * Hopefully unstick test * Try to ensure that TestCorrectStreamWakeup doesn't block forever * Update testsmain
parent
57841fc35e
commit
02fe38e1f7
|
@ -13,4 +13,4 @@ go build ./cmd/...
|
||||||
./scripts/find-lint.sh
|
./scripts/find-lint.sh
|
||||||
|
|
||||||
echo "Testing..."
|
echo "Testing..."
|
||||||
go test ./...
|
go test -v ./...
|
||||||
|
|
|
@ -58,7 +58,7 @@ type Database interface {
|
||||||
// ID.
|
// ID.
|
||||||
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
|
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
|
||||||
// CompleteSync returns a complete /sync API response for the given user.
|
// CompleteSync returns a complete /sync API response for the given user.
|
||||||
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error)
|
CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
|
||||||
// GetAccountDataInRange returns all account data for a given user inserted or
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
||||||
// updated between two given positions
|
// updated between two given positions
|
||||||
// Returns a map following the format data[roomID] = []dataTypes
|
// Returns a map following the format data[roomID] = []dataTypes
|
||||||
|
|
|
@ -666,10 +666,10 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) CompleteSync(
|
func (d *Database) CompleteSync(
|
||||||
ctx context.Context, userID string, numRecentEventsPerRoom int,
|
ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int,
|
||||||
) (*types.Response, error) {
|
) (*types.Response, error) {
|
||||||
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
|
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
|
||||||
ctx, userID, numRecentEventsPerRoom,
|
ctx, device.UserID, numRecentEventsPerRoom,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -181,7 +181,7 @@ func TestSyncResponse(t *testing.T) {
|
||||||
Name: "CompleteSync limited",
|
Name: "CompleteSync limited",
|
||||||
DoSync: func() (*types.Response, error) {
|
DoSync: func() (*types.Response, error) {
|
||||||
// limit set to 5
|
// limit set to 5
|
||||||
return db.CompleteSync(ctx, testUserIDA, 5)
|
return db.CompleteSync(ctx, testUserDeviceA, 5)
|
||||||
},
|
},
|
||||||
// want the last 5 events
|
// want the last 5 events
|
||||||
WantTimeline: events[len(events)-5:],
|
WantTimeline: events[len(events)-5:],
|
||||||
|
@ -193,7 +193,7 @@ func TestSyncResponse(t *testing.T) {
|
||||||
{
|
{
|
||||||
Name: "CompleteSync",
|
Name: "CompleteSync",
|
||||||
DoSync: func() (*types.Response, error) {
|
DoSync: func() (*types.Response, error) {
|
||||||
return db.CompleteSync(ctx, testUserIDA, len(events)+1)
|
return db.CompleteSync(ctx, testUserDeviceA, len(events)+1)
|
||||||
},
|
},
|
||||||
WantTimeline: events,
|
WantTimeline: events,
|
||||||
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
|
||||||
|
|
|
@ -37,8 +37,8 @@ type Notifier struct {
|
||||||
streamLock *sync.Mutex
|
streamLock *sync.Mutex
|
||||||
// The latest sync position
|
// The latest sync position
|
||||||
currPos types.StreamingToken
|
currPos types.StreamingToken
|
||||||
// A map of user_id => UserStream which can be used to wake a given user's /sync request.
|
// A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request.
|
||||||
userStreams map[string]*UserStream
|
userDeviceStreams map[string]map[string]*UserDeviceStream
|
||||||
// The last time we cleaned out stale entries from the userStreams map
|
// The last time we cleaned out stale entries from the userStreams map
|
||||||
lastCleanUpTime time.Time
|
lastCleanUpTime time.Time
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ func NewNotifier(pos types.StreamingToken) *Notifier {
|
||||||
return &Notifier{
|
return &Notifier{
|
||||||
currPos: pos,
|
currPos: pos,
|
||||||
roomIDToJoinedUsers: make(map[string]userIDSet),
|
roomIDToJoinedUsers: make(map[string]userIDSet),
|
||||||
userStreams: make(map[string]*UserStream),
|
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
|
||||||
streamLock: &sync.Mutex{},
|
streamLock: &sync.Mutex{},
|
||||||
lastCleanUpTime: time.Now(),
|
lastCleanUpTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ func (n *Notifier) OnNewEvent(
|
||||||
// GetListener returns a UserStreamListener that can be used to wait for
|
// GetListener returns a UserStreamListener that can be used to wait for
|
||||||
// updates for a user. Must be closed.
|
// updates for a user. Must be closed.
|
||||||
// notify for anything before sincePos
|
// notify for anything before sincePos
|
||||||
func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
|
func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener {
|
||||||
// Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298
|
// Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298
|
||||||
// - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID
|
// - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID
|
||||||
// - Incoming events wake requests for a matching room ID
|
// - Incoming events wake requests for a matching room ID
|
||||||
|
@ -137,7 +137,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
|
||||||
|
|
||||||
n.removeEmptyUserStreams()
|
n.removeEmptyUserStreams()
|
||||||
|
|
||||||
return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx)
|
return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the membership states required to notify users correctly.
|
// Load the membership states required to notify users correctly.
|
||||||
|
@ -173,27 +173,69 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wakeupUsers will wake up the sync strems for all of the devices for all of the
|
||||||
|
// specified user IDs.
|
||||||
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
|
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
|
||||||
for _, userID := range userIDs {
|
for _, userID := range userIDs {
|
||||||
stream := n.fetchUserStream(userID, false)
|
for _, stream := range n.fetchUserStreams(userID) {
|
||||||
if stream != nil {
|
if stream == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
|
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,
|
// wakeupUserDevice will wake up the sync stream for a specific user device. Other
|
||||||
|
// device streams will be left alone.
|
||||||
|
// nolint:unused
|
||||||
|
func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) {
|
||||||
|
for userID, deviceID := range userDevices {
|
||||||
|
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
|
||||||
|
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true,
|
||||||
|
// a stream will be made for this device if one doesn't exist and it will be returned. This
|
||||||
|
// function does not wait for data to be available on the stream.
|
||||||
|
// NB: Callers should have locked the mutex before calling this function.
|
||||||
|
func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream {
|
||||||
|
_, ok := n.userDeviceStreams[userID]
|
||||||
|
if !ok {
|
||||||
|
if !makeIfNotExists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n.userDeviceStreams[userID] = map[string]*UserDeviceStream{}
|
||||||
|
}
|
||||||
|
stream, ok := n.userDeviceStreams[userID][deviceID]
|
||||||
|
if !ok {
|
||||||
|
if !makeIfNotExists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// TODO: Unbounded growth of streams (1 per user)
|
||||||
|
if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil {
|
||||||
|
n.userDeviceStreams[userID][deviceID] = stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true,
|
||||||
// a stream will be made for this user if one doesn't exist and it will be returned. This
|
// a stream will be made for this user if one doesn't exist and it will be returned. This
|
||||||
// function does not wait for data to be available on the stream.
|
// function does not wait for data to be available on the stream.
|
||||||
// NB: Callers should have locked the mutex before calling this function.
|
// NB: Callers should have locked the mutex before calling this function.
|
||||||
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
|
func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream {
|
||||||
stream, ok := n.userStreams[userID]
|
user, ok := n.userDeviceStreams[userID]
|
||||||
if !ok && makeIfNotExists {
|
if !ok {
|
||||||
// TODO: Unbounded growth of streams (1 per user)
|
return []*UserDeviceStream{}
|
||||||
stream = NewUserStream(userID, n.currPos)
|
|
||||||
n.userStreams[userID] = stream
|
|
||||||
}
|
}
|
||||||
return stream
|
streams := []*UserDeviceStream{}
|
||||||
|
for _, stream := range user {
|
||||||
|
streams = append(streams, stream)
|
||||||
|
}
|
||||||
|
return streams
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||||
|
@ -236,9 +278,14 @@ func (n *Notifier) removeEmptyUserStreams() {
|
||||||
n.lastCleanUpTime = now
|
n.lastCleanUpTime = now
|
||||||
|
|
||||||
deleteBefore := now.Add(-5 * time.Minute)
|
deleteBefore := now.Add(-5 * time.Minute)
|
||||||
for key, value := range n.userStreams {
|
for user, byUser := range n.userDeviceStreams {
|
||||||
if value.TimeOfLastNonEmpty().Before(deleteBefore) {
|
for device, stream := range byUser {
|
||||||
delete(n.userStreams, key)
|
if stream.TimeOfLastNonEmpty().Before(deleteBefore) {
|
||||||
|
delete(n.userDeviceStreams[user], device)
|
||||||
|
}
|
||||||
|
if len(n.userDeviceStreams[user]) == 0 {
|
||||||
|
delete(n.userDeviceStreams, user)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,9 @@ var (
|
||||||
var (
|
var (
|
||||||
roomID = "!test:localhost"
|
roomID = "!test:localhost"
|
||||||
alice = "@alice:localhost"
|
alice = "@alice:localhost"
|
||||||
|
aliceDev = "alicedevice"
|
||||||
bob = "@bob:localhost"
|
bob = "@bob:localhost"
|
||||||
|
bobDev = "bobdev"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -107,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
|
||||||
// Test that the current position is returned if a request is already behind.
|
// Test that the current position is returned if a request is already behind.
|
||||||
func TestImmediateNotification(t *testing.T) {
|
func TestImmediateNotification(t *testing.T) {
|
||||||
n := NewNotifier(syncPositionBefore)
|
n := NewNotifier(syncPositionBefore)
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld))
|
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TestImmediateNotification error: %s", err)
|
t.Fatalf("TestImmediateNotification error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -124,7 +126,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewEventAndJoinedToRoom error: %w", err)
|
t.Errorf("TestNewEventAndJoinedToRoom error: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -132,7 +134,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
stream := lockedFetchUserStream(n, bob)
|
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
waitForBlocking(stream, 1)
|
waitForBlocking(stream, 1)
|
||||||
|
|
||||||
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
|
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
|
||||||
|
@ -140,6 +142,43 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCorrectStream(t *testing.T) {
|
||||||
|
n := NewNotifier(syncPositionBefore)
|
||||||
|
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
|
if stream.UserID != bob {
|
||||||
|
t.Fatalf("expected user %q, got %q", bob, stream.UserID)
|
||||||
|
}
|
||||||
|
if stream.DeviceID != bobDev {
|
||||||
|
t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorrectStreamWakeup(t *testing.T) {
|
||||||
|
n := NewNotifier(syncPositionBefore)
|
||||||
|
awoken := make(chan string)
|
||||||
|
|
||||||
|
streamone := lockedFetchUserStream(n, alice, "one")
|
||||||
|
streamtwo := lockedFetchUserStream(n, alice, "two")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-streamone.signalChannel:
|
||||||
|
awoken <- "one"
|
||||||
|
case <-streamtwo.signalChannel:
|
||||||
|
awoken <- "two"
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
wake := "two"
|
||||||
|
n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter)
|
||||||
|
|
||||||
|
if result := <-awoken; result != wake {
|
||||||
|
t.Fatalf("expected to wake %q, got %q", wake, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Test that an invite unblocks the request
|
// Test that an invite unblocks the request
|
||||||
func TestNewInviteEventForUser(t *testing.T) {
|
func TestNewInviteEventForUser(t *testing.T) {
|
||||||
n := NewNotifier(syncPositionBefore)
|
n := NewNotifier(syncPositionBefore)
|
||||||
|
@ -150,7 +189,7 @@ func TestNewInviteEventForUser(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -158,7 +197,7 @@ func TestNewInviteEventForUser(t *testing.T) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
stream := lockedFetchUserStream(n, bob)
|
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
waitForBlocking(stream, 1)
|
waitForBlocking(stream, 1)
|
||||||
|
|
||||||
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
|
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
|
||||||
|
@ -176,7 +215,7 @@ func TestEDUWakeup(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
t.Errorf("TestNewInviteEventForUser error: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -184,7 +223,7 @@ func TestEDUWakeup(t *testing.T) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
stream := lockedFetchUserStream(n, bob)
|
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
waitForBlocking(stream, 1)
|
waitForBlocking(stream, 1)
|
||||||
|
|
||||||
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
|
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
|
||||||
|
@ -202,7 +241,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(3)
|
wg.Add(3)
|
||||||
poll := func() {
|
poll := func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestMultipleRequestWakeup error: %w", err)
|
t.Errorf("TestMultipleRequestWakeup error: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -213,7 +252,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
|
||||||
go poll()
|
go poll()
|
||||||
go poll()
|
go poll()
|
||||||
|
|
||||||
stream := lockedFetchUserStream(n, bob)
|
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
waitForBlocking(stream, 3)
|
waitForBlocking(stream, 3)
|
||||||
|
|
||||||
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
|
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
|
||||||
|
@ -240,24 +279,24 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
|
||||||
// Make bob leave the room
|
// Make bob leave the room
|
||||||
leaveWG.Add(1)
|
leaveWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
|
||||||
}
|
}
|
||||||
mustEqualPositions(t, pos, syncPositionAfter)
|
mustEqualPositions(t, pos, syncPositionAfter)
|
||||||
leaveWG.Done()
|
leaveWG.Done()
|
||||||
}()
|
}()
|
||||||
bobStream := lockedFetchUserStream(n, bob)
|
bobStream := lockedFetchUserStream(n, bob, bobDev)
|
||||||
waitForBlocking(bobStream, 1)
|
waitForBlocking(bobStream, 1)
|
||||||
n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
|
n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
|
||||||
leaveWG.Wait()
|
leaveWG.Wait()
|
||||||
|
|
||||||
// send an event into the room. Make sure alice gets it. Bob should not.
|
// send an event into the room. Make sure alice gets it. Bob should not.
|
||||||
var aliceWG sync.WaitGroup
|
var aliceWG sync.WaitGroup
|
||||||
aliceStream := lockedFetchUserStream(n, alice)
|
aliceStream := lockedFetchUserStream(n, alice, aliceDev)
|
||||||
aliceWG.Add(1)
|
aliceWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter))
|
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -267,7 +306,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
|
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
|
||||||
_, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
|
_, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
|
||||||
}
|
}
|
||||||
|
@ -300,7 +339,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait until something is Wait()ing on the user stream.
|
// Wait until something is Wait()ing on the user stream.
|
||||||
func waitForBlocking(s *UserStream, numBlocking uint) {
|
func waitForBlocking(s *UserDeviceStream, numBlocking uint) {
|
||||||
for numBlocking != s.NumWaiting() {
|
for numBlocking != s.NumWaiting() {
|
||||||
// This is horrible but I don't want to add a signalling mechanism JUST for testing.
|
// This is horrible but I don't want to add a signalling mechanism JUST for testing.
|
||||||
time.Sleep(1 * time.Microsecond)
|
time.Sleep(1 * time.Microsecond)
|
||||||
|
@ -309,16 +348,19 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
|
||||||
|
|
||||||
// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
|
// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
|
||||||
// A new stream is made if it doesn't exist already.
|
// A new stream is made if it doesn't exist already.
|
||||||
func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
|
func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream {
|
||||||
n.streamLock.Lock()
|
n.streamLock.Lock()
|
||||||
defer n.streamLock.Unlock()
|
defer n.streamLock.Unlock()
|
||||||
|
|
||||||
return n.fetchUserStream(userID, true)
|
return n.fetchUserDeviceStream(userID, deviceID, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest {
|
func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest {
|
||||||
return syncRequest{
|
return syncRequest{
|
||||||
device: authtypes.Device{UserID: userID},
|
device: authtypes.Device{
|
||||||
|
UserID: userID,
|
||||||
|
ID: deviceID,
|
||||||
|
},
|
||||||
timeout: 1 * time.Minute,
|
timeout: 1 * time.Minute,
|
||||||
since: &since,
|
since: &since,
|
||||||
wantFullState: false,
|
wantFullState: false,
|
||||||
|
|
|
@ -47,7 +47,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
var syncData *types.Response
|
var syncData *types.Response
|
||||||
|
|
||||||
// Extract values from request
|
// Extract values from request
|
||||||
userID := device.UserID
|
|
||||||
syncReq, err := newSyncRequest(req, *device)
|
syncReq, err := newSyncRequest(req, *device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -56,7 +55,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger := util.GetLogger(req.Context()).WithFields(log.Fields{
|
logger := util.GetLogger(req.Context()).WithFields(log.Fields{
|
||||||
"userID": userID,
|
"userID": device.UserID,
|
||||||
|
"deviceID": device.ID,
|
||||||
"since": syncReq.since,
|
"since": syncReq.since,
|
||||||
"timeout": syncReq.timeout,
|
"timeout": syncReq.timeout,
|
||||||
"limit": syncReq.limit,
|
"limit": syncReq.limit,
|
||||||
|
@ -136,7 +136,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
||||||
// TODO: handle ignored users
|
// TODO: handle ignored users
|
||||||
if req.since == nil {
|
if req.since == nil {
|
||||||
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
|
res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit)
|
||||||
} else {
|
} else {
|
||||||
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,12 +23,13 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserStream represents a communication mechanism between the /sync request goroutine
|
// UserDeviceStream represents a communication mechanism between the /sync request goroutine
|
||||||
// and the underlying sync server goroutines.
|
// and the underlying sync server goroutines.
|
||||||
// Goroutines can get a UserStreamListener to wait for updates, and can Broadcast()
|
// Goroutines can get a UserStreamListener to wait for updates, and can Broadcast()
|
||||||
// updates.
|
// updates.
|
||||||
type UserStream struct {
|
type UserDeviceStream struct {
|
||||||
UserID string
|
UserID string
|
||||||
|
DeviceID string
|
||||||
// The lock that protects changes to this struct
|
// The lock that protects changes to this struct
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
// Closed when there is an update.
|
// Closed when there is an update.
|
||||||
|
@ -41,18 +42,19 @@ type UserStream struct {
|
||||||
numWaiting uint
|
numWaiting uint
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserStreamListener allows a sync request to wait for updates for a user.
|
// UserDeviceStreamListener allows a sync request to wait for updates for a user.
|
||||||
type UserStreamListener struct {
|
type UserDeviceStreamListener struct {
|
||||||
userStream *UserStream
|
userStream *UserDeviceStream
|
||||||
|
|
||||||
// Whether the stream has been closed
|
// Whether the stream has been closed
|
||||||
hasClosed bool
|
hasClosed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserStream creates a new user stream
|
// NewUserDeviceStream creates a new user stream
|
||||||
func NewUserStream(userID string, currPos types.StreamingToken) *UserStream {
|
func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream {
|
||||||
return &UserStream{
|
return &UserDeviceStream{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
timeOfLastChannel: time.Now(),
|
timeOfLastChannel: time.Now(),
|
||||||
pos: currPos,
|
pos: currPos,
|
||||||
signalChannel: make(chan struct{}),
|
signalChannel: make(chan struct{}),
|
||||||
|
@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.StreamingToken) *UserStream {
|
||||||
// GetListener returns UserStreamListener that a sync request can use to wait
|
// GetListener returns UserStreamListener that a sync request can use to wait
|
||||||
// for new updates with.
|
// for new updates with.
|
||||||
// UserStreamListener must be closed
|
// UserStreamListener must be closed
|
||||||
func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
|
func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
s.numWaiting++ // We decrement when UserStreamListener is closed
|
s.numWaiting++ // We decrement when UserStreamListener is closed
|
||||||
|
|
||||||
listener := UserStreamListener{
|
listener := UserDeviceStreamListener{
|
||||||
userStream: s,
|
userStream: s,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lets be a bit paranoid here and check that Close() is being called
|
// Lets be a bit paranoid here and check that Close() is being called
|
||||||
runtime.SetFinalizer(&listener, func(l *UserStreamListener) {
|
runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) {
|
||||||
if !l.hasClosed {
|
if !l.hasClosed {
|
||||||
l.Close()
|
l.Close()
|
||||||
}
|
}
|
||||||
|
@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast a new sync position for this user.
|
// Broadcast a new sync position for this user.
|
||||||
func (s *UserStream) Broadcast(pos types.StreamingToken) {
|
func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.StreamingToken) {
|
||||||
|
|
||||||
// NumWaiting returns the number of goroutines waiting for waiting for updates.
|
// NumWaiting returns the number of goroutines waiting for waiting for updates.
|
||||||
// Used for metrics and testing.
|
// Used for metrics and testing.
|
||||||
func (s *UserStream) NumWaiting() uint {
|
func (s *UserDeviceStream) NumWaiting() uint {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
return s.numWaiting
|
return s.numWaiting
|
||||||
|
@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint {
|
||||||
// TimeOfLastNonEmpty returns the last time that the number of waiting listeners
|
// TimeOfLastNonEmpty returns the last time that the number of waiting listeners
|
||||||
// was non-empty, may be time.Now() if number of waiting listeners is currently
|
// was non-empty, may be time.Now() if number of waiting listeners is currently
|
||||||
// non-empty.
|
// non-empty.
|
||||||
func (s *UserStream) TimeOfLastNonEmpty() time.Time {
|
func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
@ -118,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
|
||||||
|
|
||||||
// GetSyncPosition returns last sync position which the UserStream was
|
// GetSyncPosition returns last sync position which the UserStream was
|
||||||
// notified about
|
// notified about
|
||||||
func (s *UserStreamListener) GetSyncPosition() types.StreamingToken {
|
func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken {
|
||||||
s.userStream.lock.Lock()
|
s.userStream.lock.Lock()
|
||||||
defer s.userStream.lock.Unlock()
|
defer s.userStream.lock.Unlock()
|
||||||
|
|
||||||
|
@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.StreamingToken {
|
||||||
// sincePos specifies from which point we want to be notified about. If there
|
// sincePos specifies from which point we want to be notified about. If there
|
||||||
// has already been an update after sincePos we'll return a closed channel
|
// has already been an update after sincePos we'll return a closed channel
|
||||||
// immediately.
|
// immediately.
|
||||||
func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} {
|
func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} {
|
||||||
s.userStream.lock.Lock()
|
s.userStream.lock.Lock()
|
||||||
defer s.userStream.lock.Unlock()
|
defer s.userStream.lock.Unlock()
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close cleans up resources used
|
// Close cleans up resources used
|
||||||
func (s *UserStreamListener) Close() {
|
func (s *UserDeviceStreamListener) Close() {
|
||||||
s.userStream.lock.Lock()
|
s.userStream.lock.Lock()
|
||||||
defer s.userStream.lock.Unlock()
|
defer s.userStream.lock.Unlock()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue