diff --git a/build/scripts/build-test-lint.sh b/build/scripts/build-test-lint.sh index d2b2b4b1..4b18ca2f 100755 --- a/build/scripts/build-test-lint.sh +++ b/build/scripts/build-test-lint.sh @@ -13,4 +13,4 @@ go build ./cmd/... ./scripts/find-lint.sh echo "Testing..." -go test ./... +go test -v ./... diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 3e9ff053..7e1a40fd 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -58,7 +58,7 @@ type Database interface { // ID. IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) // CompleteSync returns a complete /sync API response for the given user. - CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) + CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 145a3dbf..888f85e0 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -666,10 +666,10 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } func (d *Database) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, + ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int, ) (*types.Response, error) { res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( - ctx, userID, numRecentEventsPerRoom, + ctx, device.UserID, numRecentEventsPerRoom, ) if err != nil { return nil, err diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 6b045852..bb8554f4 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -181,7 +181,7 @@ func TestSyncResponse(t *testing.T) { Name: "CompleteSync limited", DoSync: func() (*types.Response, error) { // limit set to 5 - return db.CompleteSync(ctx, testUserIDA, 5) + return db.CompleteSync(ctx, testUserDeviceA, 5) }, // want the last 5 events WantTimeline: events[len(events)-5:], @@ -193,7 +193,7 @@ func TestSyncResponse(t *testing.T) { { Name: "CompleteSync", DoSync: func() (*types.Response, error) { - return db.CompleteSync(ctx, testUserIDA, len(events)+1) + return db.CompleteSync(ctx, testUserDeviceA, len(events)+1) }, WantTimeline: events, // We want no state at all as that field in /sync is the delta between the token (beginning of time) diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index b3ed5cd0..9b410a0c 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -37,8 +37,8 @@ type Notifier struct { streamLock *sync.Mutex // The latest sync position currPos types.StreamingToken - // A map of user_id => UserStream which can be used to wake a given user's /sync request. - userStreams map[string]*UserStream + // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request. + userDeviceStreams map[string]map[string]*UserDeviceStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } @@ -50,7 +50,7 @@ func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), - userStreams: make(map[string]*UserStream), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), streamLock: &sync.Mutex{}, lastCleanUpTime: time.Now(), } @@ -123,7 +123,7 @@ func (n *Notifier) OnNewEvent( // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos -func (n *Notifier) GetListener(req syncRequest) UserStreamListener { +func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener { // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Incoming events wake requests for a matching room ID @@ -137,7 +137,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) + return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. @@ -173,27 +173,69 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } +// wakeupUsers will wake up the sync strems for all of the devices for all of the +// specified user IDs. func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { - stream := n.fetchUserStream(userID, false) - if stream != nil { + for _, stream := range n.fetchUserStreams(userID) { + if stream == nil { + continue + } stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } } -// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, +// wakeupUserDevice will wake up the sync stream for a specific user device. Other +// device streams will be left alone. +// nolint:unused +func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) { + for userID, deviceID := range userDevices { + if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } +} + +// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true, +// a stream will be made for this device if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream { + _, ok := n.userDeviceStreams[userID] + if !ok { + if !makeIfNotExists { + return nil + } + n.userDeviceStreams[userID] = map[string]*UserDeviceStream{} + } + stream, ok := n.userDeviceStreams[userID][deviceID] + if !ok { + if !makeIfNotExists { + return nil + } + // TODO: Unbounded growth of streams (1 per user) + if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil { + n.userDeviceStreams[userID][deviceID] = stream + } + } + return stream +} + +// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true, // a stream will be made for this user if one doesn't exist and it will be returned. This // function does not wait for data to be available on the stream. // NB: Callers should have locked the mutex before calling this function. -func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { - stream, ok := n.userStreams[userID] - if !ok && makeIfNotExists { - // TODO: Unbounded growth of streams (1 per user) - stream = NewUserStream(userID, n.currPos) - n.userStreams[userID] = stream +func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok { + return []*UserDeviceStream{} } - return stream + streams := []*UserDeviceStream{} + for _, stream := range user { + streams = append(streams, stream) + } + return streams } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -236,9 +278,14 @@ func (n *Notifier) removeEmptyUserStreams() { n.lastCleanUpTime = now deleteBefore := now.Add(-5 * time.Minute) - for key, value := range n.userStreams { - if value.TimeOfLastNonEmpty().Before(deleteBefore) { - delete(n.userStreams, key) + for user, byUser := range n.userDeviceStreams { + for device, stream := range byUser { + if stream.TimeOfLastNonEmpty().Before(deleteBefore) { + delete(n.userDeviceStreams[user], device) + } + if len(n.userDeviceStreams[user]) == 0 { + delete(n.userDeviceStreams, user) + } } } } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 7d979fcc..14ddef20 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -41,9 +41,11 @@ var ( ) var ( - roomID = "!test:localhost" - alice = "@alice:localhost" - bob = "@bob:localhost" + roomID = "!test:localhost" + alice = "@alice:localhost" + aliceDev = "alicedevice" + bob = "@bob:localhost" + bobDev = "bobdev" ) func init() { @@ -107,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { n := NewNotifier(syncPositionBefore) - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } @@ -124,7 +126,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) } @@ -132,7 +134,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -140,6 +142,43 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Wait() } +func TestCorrectStream(t *testing.T) { + n := NewNotifier(syncPositionBefore) + stream := lockedFetchUserStream(n, bob, bobDev) + if stream.UserID != bob { + t.Fatalf("expected user %q, got %q", bob, stream.UserID) + } + if stream.DeviceID != bobDev { + t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID) + } +} + +func TestCorrectStreamWakeup(t *testing.T) { + n := NewNotifier(syncPositionBefore) + awoken := make(chan string) + + streamone := lockedFetchUserStream(n, alice, "one") + streamtwo := lockedFetchUserStream(n, alice, "two") + + go func() { + select { + case <-streamone.signalChannel: + awoken <- "one" + case <-streamtwo.signalChannel: + awoken <- "two" + } + }() + + time.Sleep(1 * time.Second) + + wake := "two" + n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter) + + if result := <-awoken; result != wake { + t.Fatalf("expected to wake %q, got %q", wake, result) + } +} + // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { n := NewNotifier(syncPositionBefore) @@ -150,7 +189,7 @@ func TestNewInviteEventForUser(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } @@ -158,7 +197,7 @@ func TestNewInviteEventForUser(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) @@ -176,7 +215,7 @@ func TestEDUWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %w", err) } @@ -184,7 +223,7 @@ func TestEDUWakeup(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) @@ -202,7 +241,7 @@ func TestMultipleRequestWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(3) poll := func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestMultipleRequestWakeup error: %w", err) } @@ -213,7 +252,7 @@ func TestMultipleRequestWakeup(t *testing.T) { go poll() go poll() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(stream, 3) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -240,24 +279,24 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // Make bob leave the room leaveWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() - bobStream := lockedFetchUserStream(n, bob) + bobStream := lockedFetchUserStream(n, bob, bobDev) waitForBlocking(bobStream, 1) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. var aliceWG sync.WaitGroup - aliceStream := lockedFetchUserStream(n, alice) + aliceStream := lockedFetchUserStream(n, alice, aliceDev) aliceWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) } @@ -267,7 +306,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { go func() { // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) - _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + _, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) if err == nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") } @@ -300,7 +339,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { } // Wait until something is Wait()ing on the user stream. -func waitForBlocking(s *UserStream, numBlocking uint) { +func waitForBlocking(s *UserDeviceStream, numBlocking uint) { for numBlocking != s.NumWaiting() { // This is horrible but I don't want to add a signalling mechanism JUST for testing. time.Sleep(1 * time.Microsecond) @@ -309,16 +348,19 @@ func waitForBlocking(s *UserStream, numBlocking uint) { // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. // A new stream is made if it doesn't exist already. -func lockedFetchUserStream(n *Notifier, userID string) *UserStream { +func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream { n.streamLock.Lock() defer n.streamLock.Unlock() - return n.fetchUserStream(userID, true) + return n.fetchUserDeviceStream(userID, deviceID, true) } -func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { +func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syncRequest { return syncRequest{ - device: authtypes.Device{UserID: userID}, + device: authtypes.Device{ + UserID: userID, + ID: deviceID, + }, timeout: 1 * time.Minute, since: &since, wantFullState: false, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 82ce283b..bd29b333 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -47,7 +47,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype var syncData *types.Response // Extract values from request - userID := device.UserID syncReq, err := newSyncRequest(req, *device) if err != nil { return util.JSONResponse{ @@ -56,10 +55,11 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } logger := util.GetLogger(req.Context()).WithFields(log.Fields{ - "userID": userID, - "since": syncReq.since, - "timeout": syncReq.timeout, - "limit": syncReq.limit, + "userID": device.UserID, + "deviceID": device.ID, + "since": syncReq.since, + "timeout": syncReq.timeout, + "limit": syncReq.limit, }) currPos := rp.notifier.CurrentPosition() @@ -136,7 +136,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) + res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit) } else { res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) } diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index b2eafa3d..ff9a4d00 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -23,12 +23,13 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -// UserStream represents a communication mechanism between the /sync request goroutine +// UserDeviceStream represents a communication mechanism between the /sync request goroutine // and the underlying sync server goroutines. // Goroutines can get a UserStreamListener to wait for updates, and can Broadcast() // updates. -type UserStream struct { - UserID string +type UserDeviceStream struct { + UserID string + DeviceID string // The lock that protects changes to this struct lock sync.Mutex // Closed when there is an update. @@ -41,18 +42,19 @@ type UserStream struct { numWaiting uint } -// UserStreamListener allows a sync request to wait for updates for a user. -type UserStreamListener struct { - userStream *UserStream +// UserDeviceStreamListener allows a sync request to wait for updates for a user. +type UserDeviceStreamListener struct { + userStream *UserDeviceStream // Whether the stream has been closed hasClosed bool } -// NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { - return &UserStream{ +// NewUserDeviceStream creates a new user stream +func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream { + return &UserDeviceStream{ UserID: userID, + DeviceID: deviceID, timeOfLastChannel: time.Now(), pos: currPos, signalChannel: make(chan struct{}), @@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { // GetListener returns UserStreamListener that a sync request can use to wait // for new updates with. // UserStreamListener must be closed -func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { +func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener { s.lock.Lock() defer s.lock.Unlock() s.numWaiting++ // We decrement when UserStreamListener is closed - listener := UserStreamListener{ + listener := UserDeviceStreamListener{ userStream: s, } // Lets be a bit paranoid here and check that Close() is being called - runtime.SetFinalizer(&listener, func(l *UserStreamListener) { + runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) { if !l.hasClosed { l.Close() } @@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.StreamingToken) { +func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.StreamingToken) { // NumWaiting returns the number of goroutines waiting for waiting for updates. // Used for metrics and testing. -func (s *UserStream) NumWaiting() uint { +func (s *UserDeviceStream) NumWaiting() uint { s.lock.Lock() defer s.lock.Unlock() return s.numWaiting @@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint { // TimeOfLastNonEmpty returns the last time that the number of waiting listeners // was non-empty, may be time.Now() if number of waiting listeners is currently // non-empty. -func (s *UserStream) TimeOfLastNonEmpty() time.Time { +func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time { s.lock.Lock() defer s.lock.Unlock() @@ -118,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { // GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { +func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { +func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-c } // Close cleans up resources used -func (s *UserStreamListener) Close() { +func (s *UserDeviceStreamListener) Close() { s.userStream.lock.Lock() defer s.userStream.lock.Unlock()