diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9808d623..30a65271 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -82,7 +82,7 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, eduAPI, keys, federation, + cfg, rsAPI, eduAPI, keyAPI, keys, federation, ) }, )).Methods(http.MethodPut, http.MethodOptions) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 680eaccd..903a2f22 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal/config" + keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -37,6 +38,7 @@ func Send( cfg *config.Dendrite, rsAPI api.RoomserverInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, + keyAPI keyapi.KeyInternalAPI, keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -48,6 +50,7 @@ func Send( federation: federation, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), newEvents: make(map[string]bool), + keyAPI: keyAPI, } var txnEvents struct { @@ -100,6 +103,7 @@ type txnReq struct { context context.Context rsAPI api.RoomserverInternalAPI eduAPI eduserverAPI.EDUServerInputAPI + keyAPI keyapi.KeyInternalAPI keys gomatrixserverlib.JSONVerifier federation txnFederationClient // local cache of events for auth checks, etc - this may include events @@ -308,12 +312,29 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { } } } + case gomatrixserverlib.MDeviceListUpdate: + t.processDeviceListUpdate(e) default: util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") } } } +func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) { + var payload gomatrixserverlib.DeviceListUpdateEvent + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event") + return + } + var inputRes keyapi.InputDeviceListUpdateResponse + t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{ + Event: payload, + }, &inputRes) + if inputRes.Error != nil { + util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") + } +} + func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error { prevEventIDs := e.PrevEventIDs() diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c864b328..c3481a38 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -21,11 +21,14 @@ import ( "time" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type KeyInternalAPI interface { // SetUserAPI assigns a user API to query when extracting device names. SetUserAPI(i userapi.UserInternalAPI) + // InputDeviceListUpdate from a federated server EDU + InputDeviceListUpdate(ctx context.Context, req *InputDeviceListUpdateRequest, res *InputDeviceListUpdateResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) @@ -200,3 +203,11 @@ type QueryDeviceMessagesResponse struct { Devices []DeviceMessage Error *KeyError } + +type InputDeviceListUpdateRequest struct { + Event gomatrixserverlib.DeviceListUpdateEvent +} + +type InputDeviceListUpdateResponse struct { + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 474f30ff..d6e24566 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -38,12 +38,76 @@ type KeyInternalAPI struct { FedClient *gomatrixserverlib.FederationClient UserAPI userapi.UserInternalAPI Producer *producers.KeyChange + // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 + // request to the remote server and race. + // TODO: Put in an LRU cache to bound growth + UserIDToMutex map[string]*sync.Mutex + Mutex *sync.Mutex // protects UserIDToMutex } func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { a.UserAPI = i } +func (a *KeyInternalAPI) mutex(userID string) *sync.Mutex { + a.Mutex.Lock() + defer a.Mutex.Unlock() + if a.UserIDToMutex[userID] == nil { + a.UserIDToMutex[userID] = &sync.Mutex{} + } + return a.UserIDToMutex[userID] +} + +func (a *KeyInternalAPI) InputDeviceListUpdate( + ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse, +) { + mu := a.mutex(req.Event.UserID) + mu.Lock() + defer mu.Unlock() + // check if we have the prev IDs + exists, err := a.DB.PrevIDsExists(ctx, req.Event.UserID, req.Event.PrevID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to check if prev ids exist: %s", err), + } + return + } + + // if we haven't missed anything update the database and notify users + if exists { + keys := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: req.Event.DeviceID, + DisplayName: req.Event.DeviceDisplayName, + KeyJSON: req.Event.Keys, + UserID: req.Event.UserID, + }, + StreamID: req.Event.StreamID, + }, + } + err = a.DB.StoreRemoteDeviceKeys(ctx, keys) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to store remote device keys: %s", err), + } + return + } + // ALWAYS emit key changes when we've been poked over federation just in case + // this poke is important for something. + err = a.Producer.ProduceKeyChanges(keys) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to emit remote device key changes: %s", err), + } + } + return + } + + // if we're missing an ID go and fetch it from the remote HS + +} + func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { if req.Partition < 0 { req.Partition = a.Producer.DefaultPartition() @@ -351,7 +415,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per return } // store the device keys and emit changes - err := a.DB.StoreDeviceKeys(ctx, keysToStore) + err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 93b19051..98200022 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -27,12 +27,13 @@ import ( // HTTP paths for the internal HTTP APIs const ( - PerformUploadKeysPath = "/keyserver/performUploadKeys" - PerformClaimKeysPath = "/keyserver/performClaimKeys" - QueryKeysPath = "/keyserver/queryKeys" - QueryKeyChangesPath = "/keyserver/queryKeyChanges" - QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" - QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" + InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate" + PerformUploadKeysPath = "/keyserver/performUploadKeys" + PerformClaimKeysPath = "/keyserver/performClaimKeys" + QueryKeysPath = "/keyserver/queryKeys" + QueryKeyChangesPath = "/keyserver/queryKeyChanges" + QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" + QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -58,6 +59,20 @@ type httpKeyInternalAPI struct { func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { // no-op: doesn't need it } +func (h *httpKeyInternalAPI) InputDeviceListUpdate( + ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputDeviceListUpdate") + defer span.Finish() + + apiURL := h.apiURL + InputDeviceListUpdatePath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + } +} func (h *httpKeyInternalAPI) PerformClaimKeys( ctx context.Context, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index f0cd3038..7dfaed2e 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -25,6 +25,17 @@ import ( ) func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { + internalAPIMux.Handle(InputDeviceListUpdatePath, + httputil.MakeInternalAPI("inputDeviceListUpdate", func(req *http.Request) util.JSONResponse { + request := api.InputDeviceListUpdateRequest{} + response := api.InputDeviceListUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.InputDeviceListUpdate(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformClaimKeysPath, httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse { request := api.PerformClaimKeysRequest{} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 36bedf34..79d9cec9 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -15,6 +15,8 @@ package keyserver import ( + "sync" + "github.com/Shopify/sarama" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/config" @@ -51,9 +53,11 @@ func NewInternalAPI( DB: db, } return &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, + DB: db, + ThisServer: cfg.Matrix.ServerName, + FedClient: fedClient, + Producer: keyChangeProducer, + Mutex: &sync.Mutex{}, + UserIDToMutex: make(map[string]*sync.Mutex), } } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 11284d86..f67bbf71 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -35,11 +35,18 @@ type Database interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index e1b4e947..d321860d 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" + type deviceKeysStatements struct { db *sql.DB upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { + return nil, err + } return s, nil } @@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn return } +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + // nullable if there are no results + var count sql.NullInt32 + err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index e78ee943..78729774 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -47,7 +47,25 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { + sids := make([]int64, len(prevIDs)) + for i := range prevIDs { + sids[i] = int64(prevIDs[i]) + } + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { // work out the latest stream IDs for each user userIDToStreamID := make(map[string]int) for _, k := range keys { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 900d1238..15d9c775 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" "github.com/matrix-org/dendrite/internal" @@ -53,6 +54,9 @@ const selectBatchDeviceKeysSQL = "" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + type deviceKeysStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -143,6 +147,25 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn return } +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt32 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { for _, key := range keys { diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 949d9dd6..ec1b299f 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -126,15 +126,15 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 2 as this is a 2nd device key }, } - MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) } if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) } if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) } // updating a device sets the next stream ID for that user @@ -148,9 +148,9 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 3 }, } - MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) } // Querying for device keys returns the latest stream IDs diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 65da3310..ac932d56 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -35,6 +35,7 @@ type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index cc33c738..6765fa65 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -44,6 +44,9 @@ func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneT } func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) { +} +func (k *mockKeyAPI) InputDeviceListUpdate(ctx context.Context, req *keyapi.InputDeviceListUpdateRequest, res *keyapi.InputDeviceListUpdateResponse) { + } type mockCurrentStateAPI struct {