Implement claiming one-time keys locally (#1210)
* Add API shape for claiming keys * Implement claiming one-time keys locally Fairly boring, nothing too special going on.main
parent
d76eb1b994
commit
1d72ce8b7a
|
@ -117,3 +117,40 @@ func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type claimKeysRequest struct {
|
||||||
|
TimeoutMS int `json:"timeout"`
|
||||||
|
// The keys to be claimed. A map from user ID, to a map from device ID to algorithm name.
|
||||||
|
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *claimKeysRequest) GetTimeout() time.Duration {
|
||||||
|
if r.TimeoutMS == 0 {
|
||||||
|
return 10 * time.Second
|
||||||
|
}
|
||||||
|
return time.Duration(r.TimeoutMS) * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClaimKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse {
|
||||||
|
var r claimKeysRequest
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
claimRes := api.PerformClaimKeysResponse{}
|
||||||
|
keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
|
||||||
|
OneTimeKeys: r.OneTimeKeys,
|
||||||
|
Timeout: r.GetTimeout(),
|
||||||
|
}, &claimRes)
|
||||||
|
if claimRes.Error != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: map[string]interface{}{
|
||||||
|
"one_time_keys": claimRes.OneTimeKeys,
|
||||||
|
"failures": claimRes.Failures,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -714,4 +714,9 @@ func Setup(
|
||||||
return QueryKeys(req, keyAPI)
|
return QueryKeys(req, keyAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
r0mux.Handle("/keys/claim",
|
||||||
|
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return ClaimKeys(req, keyAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
|
|
||||||
type KeyInternalAPI interface {
|
type KeyInternalAPI interface {
|
||||||
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
||||||
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
||||||
}
|
}
|
||||||
|
@ -102,9 +103,17 @@ func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyEr
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformClaimKeysRequest struct {
|
type PerformClaimKeysRequest struct {
|
||||||
|
// Map of user_id to device_id to algorithm name
|
||||||
|
OneTimeKeys map[string]map[string]string
|
||||||
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformClaimKeysResponse struct {
|
type PerformClaimKeysResponse struct {
|
||||||
|
// Map of user_id to device_id to algorithm:key_id to key JSON
|
||||||
|
OneTimeKeys map[string]map[string]map[string]json.RawMessage
|
||||||
|
// Map of remote server domain to error JSON
|
||||||
|
Failures map[string]interface{}
|
||||||
|
// Set if there was a fatal error processing this action
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
|
||||||
a.uploadDeviceKeys(ctx, req, res)
|
a.uploadDeviceKeys(ctx, req, res)
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
||||||
|
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
||||||
|
res.Failures = make(map[string]interface{})
|
||||||
|
// wrap request map in a top-level by-domain map
|
||||||
|
domainToDeviceKeys := make(map[string]map[string]map[string]string)
|
||||||
|
for userID, val := range req.OneTimeKeys {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
continue // ignore invalid users
|
||||||
|
}
|
||||||
|
nested, ok := domainToDeviceKeys[string(serverName)]
|
||||||
|
if !ok {
|
||||||
|
nested = make(map[string]map[string]string)
|
||||||
|
}
|
||||||
|
nested[userID] = val
|
||||||
|
domainToDeviceKeys[string(serverName)] = nested
|
||||||
|
}
|
||||||
|
// claim local keys
|
||||||
|
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
||||||
|
keys, err := a.DB.ClaimKeys(ctx, local)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mergeInto(res.OneTimeKeys, keys)
|
||||||
|
delete(domainToDeviceKeys, string(a.ThisServer))
|
||||||
|
}
|
||||||
|
// TODO: claim remote keys
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||||
res.Failures = make(map[string]interface{})
|
res.Failures = make(map[string]interface{})
|
||||||
|
@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
|
||||||
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
|
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
|
||||||
// TODO
|
// TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
|
||||||
|
for _, key := range src {
|
||||||
|
_, ok := dst[key.UserID]
|
||||||
|
if !ok {
|
||||||
|
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
_, ok = dst[key.UserID][key.DeviceID]
|
||||||
|
if !ok {
|
||||||
|
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
for keyID, keyJSON := range key.KeyJSON {
|
||||||
|
dst[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -39,4 +39,8 @@ type Database interface {
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
||||||
|
|
||||||
|
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||||
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
|
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,11 +52,19 @@ const selectKeysSQL = "" +
|
||||||
const selectKeysCountSQL = "" +
|
const selectKeysCountSQL = "" +
|
||||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimeKeySQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||||
|
|
||||||
|
const selectKeyByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt *sql.Stmt
|
||||||
|
selectKeyByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimeKeyStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
|
@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
@ -48,3 +49,26 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) e
|
||||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
||||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||||
|
var result []api.OneTimeKeys
|
||||||
|
err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
|
||||||
|
for deviceID, algo := range deviceToAlgo {
|
||||||
|
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if keyJSON != nil {
|
||||||
|
result = append(result, api.OneTimeKeys{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
KeyJSON: keyJSON,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
|
@ -52,11 +52,19 @@ const selectKeysSQL = "" +
|
||||||
const selectKeysCountSQL = "" +
|
const selectKeysCountSQL = "" +
|
||||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimeKeySQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||||
|
|
||||||
|
const selectKeyByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
selectKeysCountStmt *sql.Stmt
|
selectKeysCountStmt *sql.Stmt
|
||||||
|
selectKeyByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimeKeyStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
|
@ -76,6 +84,12 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ package tables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
@ -24,6 +25,9 @@ import (
|
||||||
type OneTimeKeys interface {
|
type OneTimeKeys interface {
|
||||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
|
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||||
|
// Returns an empty map if the key does not exist.
|
||||||
|
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
|
|
|
@ -124,6 +124,7 @@ Should reject keys claiming to belong to a different user
|
||||||
Can query device keys using POST
|
Can query device keys using POST
|
||||||
Can query specific device keys using POST
|
Can query specific device keys using POST
|
||||||
query for user with no keys returns empty key dict
|
query for user with no keys returns empty key dict
|
||||||
|
Can claim one time key using POST
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
Loading…
Reference in New Issue