Produce OTK counts in /sync response (#1235)
* Add QueryOneTimeKeys for /sync extensions * Unbreak tests * Produce OTK counts in /sync response * Lintingmain
parent
b5cb1d1534
commit
ffcb6d2ea1
|
@ -31,6 +31,7 @@ type KeyInternalAPI interface {
|
||||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
||||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
|
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
|
||||||
|
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyError is returned if there was a problem performing/querying the server
|
// KeyError is returned if there was a problem performing/querying the server
|
||||||
|
@ -157,3 +158,16 @@ type QueryKeyChangesResponse struct {
|
||||||
// Set if there was a problem handling the request.
|
// Set if there was a problem handling the request.
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryOneTimeKeysRequest struct {
|
||||||
|
// The local user to query OTK counts for
|
||||||
|
UserID string
|
||||||
|
// The device to query OTK counts for
|
||||||
|
DeviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryOneTimeKeysResponse struct {
|
||||||
|
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||||
|
Count OneTimeKeysCount
|
||||||
|
Error *KeyError
|
||||||
|
}
|
||||||
|
|
|
@ -168,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
|
||||||
|
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Count = *count
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||||
res.Failures = make(map[string]interface{})
|
res.Failures = make(map[string]interface{})
|
||||||
|
|
|
@ -31,6 +31,7 @@ const (
|
||||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||||
QueryKeysPath = "/keyserver/queryKeys"
|
QueryKeysPath = "/keyserver/queryKeys"
|
||||||
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||||
|
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -108,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryOneTimeKeysRequest,
|
||||||
|
response *api.QueryOneTimeKeysResponse,
|
||||||
|
) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryOneTimeKeysPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
if err != nil {
|
||||||
|
response.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpKeyInternalAPI) QueryKeyChanges(
|
func (h *httpKeyInternalAPI) QueryKeyChanges(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryKeyChangesRequest,
|
request *api.QueryKeyChangesRequest,
|
||||||
|
|
|
@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryOneTimeKeysPath,
|
||||||
|
httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryOneTimeKeysRequest{}
|
||||||
|
response := api.QueryOneTimeKeysResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
s.QueryOneTimeKeys(req.Context(), &request, &response)
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(QueryKeyChangesPath,
|
internalAPIMux.Handle(QueryKeyChangesPath,
|
||||||
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryKeyChangesRequest{}
|
request := api.QueryKeyChangesRequest{}
|
||||||
|
|
|
@ -29,6 +29,9 @@ type Database interface {
|
||||||
// StoreOneTimeKeys persists the given one-time keys.
|
// StoreOneTimeKeys persists the given one-time keys.
|
||||||
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
||||||
|
|
||||||
|
|
|
@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: deviceID,
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
counts := &api.OneTimeKeysCount{
|
counts := &api.OneTimeKeysCount{
|
||||||
|
|
|
@ -39,6 +39,10 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
|
||||||
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||||
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: deviceID,
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
counts := &api.OneTimeKeysCount{
|
counts := &api.OneTimeKeysCount{
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
type OneTimeKeys interface {
|
type OneTimeKeys interface {
|
||||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||||
// Returns an empty map if the key does not exist.
|
// Returns an empty map if the key does not exist.
|
||||||
|
|
|
@ -29,6 +29,20 @@ import (
|
||||||
|
|
||||||
const DeviceListLogName = "dl"
|
const DeviceListLogName = "dl"
|
||||||
|
|
||||||
|
// DeviceOTKCounts adds one-time key counts to the /sync response
|
||||||
|
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error {
|
||||||
|
var queryRes api.QueryOneTimeKeysResponse
|
||||||
|
keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
}, &queryRes)
|
||||||
|
if queryRes.Error != nil {
|
||||||
|
return queryRes.Error
|
||||||
|
}
|
||||||
|
res.DeviceListsOTKCount = queryRes.Count.KeyCount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
|
@ -36,6 +50,7 @@ func DeviceListCatchup(
|
||||||
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
userID string, res *types.Response, from, to types.StreamingToken,
|
userID string, res *types.Response, from, to types.StreamingToken,
|
||||||
) (hasNew bool, err error) {
|
) (hasNew bool, err error) {
|
||||||
|
|
||||||
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
||||||
newlyJoinedRooms := joinedRooms(res, userID)
|
newlyJoinedRooms := joinedRooms(res, userID)
|
||||||
newlyLeftRooms := leftRooms(res)
|
newlyLeftRooms := leftRooms(res)
|
||||||
|
|
|
@ -38,6 +38,9 @@ func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformCl
|
||||||
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
|
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
|
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
|
||||||
|
}
|
||||||
|
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockCurrentStateAPI struct {
|
type mockCurrentStateAPI struct {
|
||||||
|
|
|
@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
// nolint:gocyclo
|
||||||
res = types.NewResponse()
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) {
|
||||||
|
res := types.NewResponse()
|
||||||
|
|
||||||
since := types.NewStreamToken(0, 0, nil)
|
since := types.NewStreamToken(0, 0, nil)
|
||||||
if req.since != nil {
|
if req.since != nil {
|
||||||
|
@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
|
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
|
||||||
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
|
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos)
|
res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
|
}
|
||||||
|
err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Before we return the sync response, make sure that we take action on
|
// Before we return the sync response, make sure that we take action on
|
||||||
|
@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
// Handle the updates and deletions in the database.
|
// Handle the updates and deletions in the database.
|
||||||
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
|
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(events) > 0 {
|
if len(events) > 0 {
|
||||||
|
@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) appendDeviceLists(
|
func (rp *RequestPool) appendDeviceLists(
|
||||||
|
|
|
@ -393,6 +393,7 @@ type Response struct {
|
||||||
Changed []string `json:"changed,omitempty"`
|
Changed []string `json:"changed,omitempty"`
|
||||||
Left []string `json:"left,omitempty"`
|
Left []string `json:"left,omitempty"`
|
||||||
} `json:"device_lists,omitempty"`
|
} `json:"device_lists,omitempty"`
|
||||||
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponse creates an empty response with initialised maps.
|
// NewResponse creates an empty response with initialised maps.
|
||||||
|
@ -411,6 +412,7 @@ func NewResponse() *Response {
|
||||||
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
||||||
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
|
||||||
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
|
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
|
||||||
|
res.DeviceListsOTKCount = make(map[string]int)
|
||||||
|
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue