From 614e67280defda4a9156f620f2751e3ef136da81 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 18 Aug 2021 12:07:09 +0100 Subject: [PATCH] Delete device keys/signatures from key server when deleting devices (#1979) * Delete device keys/signatures from key server when deleting device from user API * Move loop to within database transaction * Don't fall over deleting no rows --- keyserver/api/api.go | 13 +++++++++++++ keyserver/internal/internal.go | 8 ++++++++ keyserver/inthttp/client.go | 18 ++++++++++++++++++ keyserver/inthttp/server.go | 11 +++++++++++ keyserver/storage/interface.go | 4 ++++ .../postgres/cross_signing_sigs_table.go | 15 +++++++++++++++ .../storage/postgres/device_keys_table.go | 12 ++++++++++++ keyserver/storage/shared/storage.go | 16 ++++++++++++++++ .../sqlite3/cross_signing_sigs_table.go | 15 +++++++++++++++ keyserver/storage/sqlite3/device_keys_table.go | 12 ++++++++++++ keyserver/storage/tables/interface.go | 2 ++ syncapi/internal/keychange_test.go | 2 ++ userapi/internal/api.go | 12 ++++++++++++ 13 files changed, 140 insertions(+) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 40120236..5a109cc6 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -34,6 +34,7 @@ type KeyInternalAPI interface { PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) + PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) @@ -145,6 +146,18 @@ type PerformUploadKeysResponse struct { OneTimeKeyCounts []OneTimeKeysCount } +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + // KeyError sets a key error field on KeyErrors func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { if r.KeyErrors[userID] == nil { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 47eda179..a546e94b 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -182,6 +182,14 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } +func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) { + if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to delete device keys: %s", err), + } + } +} + func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 15870571..f50789b8 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -30,6 +30,7 @@ const ( InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate" PerformUploadKeysPath = "/keyserver/performUploadKeys" PerformClaimKeysPath = "/keyserver/performClaimKeys" + PerformDeleteKeysPath = "/keyserver/performDeleteKeys" PerformUploadDeviceKeysPath = "/keyserver/performUploadDeviceKeys" PerformUploadDeviceSignaturesPath = "/keyserver/performUploadDeviceSignatures" QueryKeysPath = "/keyserver/queryKeys" @@ -94,6 +95,23 @@ func (h *httpKeyInternalAPI) PerformClaimKeys( } } +func (h *httpKeyInternalAPI) PerformDeleteKeys( + ctx context.Context, + request *api.PerformDeleteKeysRequest, + response *api.PerformDeleteKeysResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") + defer span.Finish() + + apiURL := h.apiURL + PerformClaimKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} + func (h *httpKeyInternalAPI) PerformUploadKeys( ctx context.Context, request *api.PerformUploadKeysRequest, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 475544a5..8d557a76 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -47,6 +47,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformDeleteKeysPath, + httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse { + request := api.PerformDeleteKeysRequest{} + response := api.PerformDeleteKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.PerformDeleteKeys(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformUploadKeysPath, httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse { request := api.PerformUploadKeysRequest{} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index b9db81ad..99842bc5 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -58,6 +58,10 @@ type Database interface { // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 677e7a48..e1185395 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -46,10 +46,14 @@ const upsertCrossSigningSigsForTargetSQL = "" + " VALUES($1, $2, $3, $4, $5)" + " ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)" +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt } func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -63,6 +67,7 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro return s, sqlutil.StatementList{ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, }.Prepare(db) } @@ -101,3 +106,13 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( } return nil } + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index e5f68fd0..5ae0da96 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -62,6 +62,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" @@ -72,6 +75,7 @@ type deviceKeysStatements struct { selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt deleteAllDeviceKeysStmt *sql.Stmt } @@ -98,6 +102,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { return nil, err } + if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { + return nil, err + } if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { return nil, err } @@ -163,6 +170,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx return nil } +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 64ce53ef..5bd8be36 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -158,6 +158,22 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta }) } +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + } + return nil + }) +} + // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index aa702583..9abf5436 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -45,10 +45,14 @@ const upsertCrossSigningSigsForTargetSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + " VALUES($1, $2, $3, $4, $5)" +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt } func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -62,6 +66,7 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) return s, sqlutil.StatementList{ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, }.Prepare(db) } @@ -100,3 +105,13 @@ func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( } return nil } + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index ca7ed9cf..fa1c930d 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" @@ -67,6 +70,7 @@ type deviceKeysStatements struct { selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt deleteAllDeviceKeysStmt *sql.Stmt } @@ -90,12 +94,20 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { + return nil, err + } if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { return nil, err } return s, nil } +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 0649b680..612eeb86 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -39,6 +39,7 @@ type DeviceKeys interface { SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } @@ -62,4 +63,5 @@ type CrossSigningKeys interface { type CrossSigningSigs interface { SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 0c567a96..e52e5556 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -33,6 +33,8 @@ func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { } +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) { +} func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) { } func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index a2bc8ecf..518edef4 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -145,6 +145,18 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if err != nil { return err } + // Ask the keyserver to delete device keys and signatures for those devices + deleteReq := &keyapi.PerformDeleteKeysRequest{ + UserID: req.UserID, + } + for _, keyID := range req.DeviceIDs { + deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) + } + deleteRes := &keyapi.PerformDeleteKeysResponse{} + a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes) + if err := deleteRes.Error; err != nil { + return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) + } // create empty device keys and upload them to delete what was once there and trigger device list changes return a.deviceListUpdate(req.UserID, deletedDeviceIDs) }