diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 25503692..2a8650db 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -300,6 +300,11 @@ sync_api: max_idle_conns: 2 conn_max_lifetime: -1 + # This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + # real_ip_header: X-Real-IP + # Configuration for the User API. user_api: internal_api: diff --git a/internal/config/config_syncapi.go b/internal/config/config_syncapi.go index 0a96e41c..fc08f738 100644 --- a/internal/config/config_syncapi.go +++ b/internal/config/config_syncapi.go @@ -7,6 +7,8 @@ type SyncAPI struct { ExternalAPI ExternalAPIOptions `yaml:"external_api"` Database DatabaseOptions `yaml:"database"` + + RealIPHeader string `yaml:"real_ip_header"` } func (c *SyncAPI) Defaults() { diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index cbf8b726..265d6ee6 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -457,6 +457,9 @@ func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.Pe func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error { return nil } +func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { + return nil +} func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { return nil } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 8a79737a..61f8c46f 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -19,10 +19,14 @@ package sync import ( "context" "fmt" + "net" "net/http" + "strings" + "sync" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" @@ -37,18 +41,62 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { db storage.Database + cfg *config.SyncAPI userAPI userapi.UserInternalAPI notifier *Notifier keyAPI keyapi.KeyInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI + lastseen sync.Map } // NewRequestPool makes a new RequestPool func NewRequestPool( - db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, + db storage.Database, cfg *config.SyncAPI, n *Notifier, + userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) *RequestPool { - return &RequestPool{db, userAPI, n, keyAPI, rsAPI} + rp := &RequestPool{db, cfg, userAPI, n, keyAPI, rsAPI, sync.Map{}} + go rp.cleanLastSeen() + return rp +} + +func (rp *RequestPool) cleanLastSeen() { + for { + rp.lastseen.Range(func(key interface{}, _ interface{}) bool { + rp.lastseen.Delete(key) + return true + }) + time.Sleep(time.Minute) + } +} + +func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) { + if _, ok := rp.lastseen.LoadOrStore(device.UserID+device.ID, struct{}{}); ok { + return + } + + remoteAddr := req.RemoteAddr + if rp.cfg.RealIPHeader != "" { + if header := req.Header.Get(rp.cfg.RealIPHeader); header != "" { + // TODO: Maybe this isn't great but it will satisfy both X-Real-IP + // and X-Forwarded-For (which can be a list where the real client + // address is the first listed address). Make more intelligent? + addresses := strings.Split(header, ",") + if ip := net.ParseIP(addresses[0]); ip != nil { + remoteAddr = addresses[0] + } + } + } + + lsreq := &userapi.PerformLastSeenUpdateRequest{ + UserID: device.UserID, + DeviceID: device.ID, + RemoteAddr: remoteAddr, + } + lsres := &userapi.PerformLastSeenUpdateResponse{} + go rp.userAPI.PerformLastSeenUpdate(req.Context(), lsreq, lsres) // nolint:errcheck + + rp.lastseen.Store(device.UserID+device.ID, time.Now()) } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be @@ -74,6 +122,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. "limit": syncReq.limit, }) + rp.updateLastSeen(req, device) + currPos := rp.notifier.CurrentPosition() if rp.shouldReturnImmediately(syncReq) { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 393a7aa5..7e277ba1 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -61,7 +61,7 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, rsAPI) + requestPool := sync.NewRequestPool(syncDB, cfg, notifier, userAPI, keyAPI, rsAPI) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)), diff --git a/userapi/api/api.go b/userapi/api/api.go index 6c3f3c69..809ba047 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -29,6 +29,7 @@ type UserInternalAPI interface { PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -183,6 +184,17 @@ type PerformPasswordUpdateResponse struct { Account *Account } +// PerformLastSeenUpdateRequest is the request for PerformLastSeenUpdate. +type PerformLastSeenUpdateRequest struct { + UserID string + DeviceID string + RemoteAddr string +} + +// PerformLastSeenUpdateResponse is the response for PerformLastSeenUpdate. +type PerformLastSeenUpdateResponse struct { +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 81d00241..3b5f4978 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -172,6 +172,21 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er return nil } +func (a *UserInternalAPI) PerformLastSeenUpdate( + ctx context.Context, + req *api.PerformLastSeenUpdateRequest, + res *api.PerformLastSeenUpdateResponse, +) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) + } + return nil +} + func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4d9dcc41..680e4cb5 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -32,6 +32,7 @@ const ( PerformAccountCreationPath = "/userapi/performAccountCreation" PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" @@ -119,6 +120,18 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformLastSeenUpdate( + ctx context.Context, + req *api.PerformLastSeenUpdateRequest, + res *api.PerformLastSeenUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen") + defer span.Finish() + + apiURL := h.apiURL + PerformLastSeenUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") defer span.Finish() diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 81e936e5..e495e353 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -65,6 +65,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformLastSeenUpdatePath, + httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformLastSeenUpdateRequest{} + response := api.PerformLastSeenUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceUpdatePath, httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceUpdateRequest{} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9953ba06..95fe99f3 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -33,9 +33,9 @@ type Database interface { // Returns the device on success. CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) - UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index cc554fe7..7de9f5f9 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -95,7 +95,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -310,9 +310,9 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error { +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) return err } diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index e318b260..6dd18b09 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -205,8 +205,8 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr) + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index cdfe2bb9..955d8ac7 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -80,7 +80,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" type devicesStatements struct { db *sql.DB @@ -314,9 +314,9 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error { +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) return err } diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 25888eae..2eefb3f3 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -207,8 +207,8 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr) + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) }