Handle inbound federation E2E key queries/claims (#1215)

* Handle inbound /keys/claim and /keys/query requests

* Add display names to device key responses

* Linting
main
Kegsay 2020-07-22 17:04:57 +01:00 committed by GitHub
parent 1e71fd645e
commit 541a23f712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 321 additions and 35 deletions

View File

@ -186,7 +186,7 @@ func main() {
ServerKeyAPI: serverKeyAPI,
StateAPI: stateAPI,
UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation),
KeyAPI: keyserver.NewInternalAPI(base.Base.Cfg, federation, userAPI),
ExtPublicRoomsProvider: provider,
}
monolith.AddAllPublicRoutes(base.Base.PublicAPIMux)

View File

@ -141,7 +141,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
StateAPI: stateAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation),
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,

View File

@ -30,10 +30,11 @@ func main() {
keyRing := serverKeyAPI.KeyRing()
fsAPI := base.FederationSenderHTTPClient()
rsAPI := base.RoomserverHTTPClient()
keyAPI := base.KeyServerHTTPClient()
federationapi.AddPublicRoutes(
base.PublicAPIMux, base.Cfg, userAPI, federation, keyRing,
rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(),
rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI,
)
base.SetupAndServeHTTP(string(base.Cfg.Bind.FederationAPI), string(base.Cfg.Listen.FederationAPI))

View File

@ -24,7 +24,7 @@ func main() {
base := setup.NewBaseDendrite(cfg, "KeyServer", true)
defer base.Close() // nolint: errcheck
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient())
intAPI := keyserver.NewInternalAPI(base.Cfg, base.CreateFederationClient(), base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI)

View File

@ -119,7 +119,7 @@ func main() {
rsImpl.SetFederationSenderAPI(fsAPI)
stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer)
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation)
keyAPI := keyserver.NewInternalAPI(base.Cfg, federation, userAPI)
monolith := setup.Monolith{
Config: base.Cfg,

View File

@ -233,7 +233,7 @@ func main() {
RoomserverAPI: rsAPI,
StateAPI: stateAPI,
UserAPI: userAPI,
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation),
KeyAPI: keyserver.NewInternalAPI(base.Cfg, federation, userAPI),
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider,
}

View File

@ -20,6 +20,7 @@ import (
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -38,11 +39,12 @@ func AddPublicRoutes(
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
eduAPI eduserverAPI.EDUServerInputAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
keyAPI keyserverAPI.KeyInternalAPI,
) {
routing.Setup(
router, cfg, rsAPI,
eduAPI, federationSenderAPI, keyRing,
federation, userAPI, stateAPI,
federation, userAPI, stateAPI, keyAPI,
)
}

View File

@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
fsAPI := base.FederationSenderHTTPClient()
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil)
federationapi.AddPublicRoutes(base.PublicAPIMux, cfg, nil, nil, keyRing, nil, fsAPI, nil, nil, nil)
httputil.SetupHTTPAPI(
base.BaseMux,
base.PublicAPIMux,

View File

@ -19,12 +19,106 @@ import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"golang.org/x/crypto/ed25519"
)
type queryKeysRequest struct {
DeviceKeys map[string][]string `json:"device_keys"`
}
// QueryDeviceKeys returns device keys for users on this server.
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query
func QueryDeviceKeys(
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
var qkr queryKeysRequest
err := json.Unmarshal(request.Content(), &qkr)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// make sure we only query users on our domain
for userID := range qkr.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
delete(qkr.DeviceKeys, userID)
continue // ignore invalid users
}
if serverName != thisServer {
delete(qkr.DeviceKeys, userID)
continue
}
}
var queryRes api.QueryKeysResponse
keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys,
}, &queryRes)
if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: 200,
JSON: struct {
DeviceKeys interface{} `json:"device_keys"`
}{queryRes.DeviceKeys},
}
}
type claimOTKsRequest struct {
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
}
// ClaimOneTimeKeys claims OTKs for users on this server.
// https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim
func ClaimOneTimeKeys(
httpReq *http.Request, request *gomatrixserverlib.FederationRequest, keyAPI api.KeyInternalAPI, thisServer gomatrixserverlib.ServerName,
) util.JSONResponse {
var cor claimOTKsRequest
err := json.Unmarshal(request.Content(), &cor)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// make sure we only claim users on our domain
for userID := range cor.OneTimeKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
delete(cor.OneTimeKeys, userID)
continue // ignore invalid users
}
if serverName != thisServer {
delete(cor.OneTimeKeys, userID)
continue
}
}
var claimRes api.PerformClaimKeysResponse
keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys,
}, &claimRes)
if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: 200,
JSON: struct {
OneTimeKeys interface{} `json:"one_time_keys"`
}{claimRes.OneTimeKeys},
}
}
// LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.Dendrite) util.JSONResponse {

View File

@ -24,6 +24,7 @@ import (
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/httputil"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -54,6 +55,7 @@ func Setup(
federation *gomatrixserverlib.FederationClient,
userAPI userapi.UserInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
keyAPI keyserverAPI.KeyInternalAPI,
) {
v2keysmux := publicAPIMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := publicAPIMux.PathPrefix(pathPrefixV1Federation).Subrouter()
@ -299,4 +301,18 @@ func Setup(
return GetPostPublicRooms(req, rsAPI, stateAPI)
}),
).Methods(http.MethodGet)
v1fedmux.Handle("/user/keys/claim", httputil.MakeFedAPI(
"federation_keys_claim", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
v1fedmux.Handle("/user/keys/query", httputil.MakeFedAPI(
"federation_keys_query", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
}

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/mattn/go-sqlite3 v2.0.2+incompatible

2
go.sum
View File

@ -423,6 +423,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View File

@ -73,7 +73,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
federationapi.AddPublicRoutes(
publicMux, m.Config, m.UserAPI, m.FedClient,
m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI,
m.EDUInternalAPI, m.StateAPI,
m.EDUInternalAPI, m.StateAPI, m.KeyAPI,
)
mediaapi.AddPublicRoutes(publicMux, m.Config, m.UserAPI, m.Client)
syncapi.AddPublicRoutes(

View File

@ -24,7 +24,9 @@ import (
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@ -33,6 +35,7 @@ type KeyInternalAPI struct {
DB storage.Database
ThisServer gomatrixserverlib.ServerName
FedClient *gomatrixserverlib.FederationClient
UserAPI userapi.UserInternalAPI
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
@ -66,11 +69,25 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
}
}
mergeInto(res.OneTimeKeys, keys)
util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
for _, key := range keys {
_, ok := res.OneTimeKeys[key.UserID]
if !ok {
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
}
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
if !ok {
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
}
for keyID, keyJSON := range key.KeyJSON {
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
delete(domainToDeviceKeys, string(a.ThisServer))
}
// claim remote keys
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
}
func (a *KeyInternalAPI) claimRemoteKeys(
@ -82,6 +99,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
wg.Add(len(domainToDeviceKeys))
// mutex for failures
var failMu sync.Mutex
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
// fan out
for d, k := range domainToDeviceKeys {
@ -91,6 +109,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
failMu.Lock()
res.Failures[domain] = map[string]interface{}{
"message": err.Error(),
@ -108,6 +127,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
close(resultCh)
}()
keysClaimed := 0
for result := range resultCh {
for userID, nest := range result.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
@ -119,10 +139,12 @@ func (a *KeyInternalAPI) claimRemoteKeys(
continue
}
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
keysClaimed++
}
}
}
}
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
@ -145,13 +167,28 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
return
}
// pull out display names after we have the keys so we handle wildcards correctly
var dids []string
for _, dk := range deviceKeys {
dids = append(dids, dk.DeviceID)
}
var queryRes userapi.QueryDeviceInfosResponse
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
DeviceIDs: dids,
}, &queryRes)
if err != nil {
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, dk := range deviceKeys {
// inject an empty 'unsigned' key which should be used for display names
// (but not via this API? unsure when they should be added)
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct{}{})
// inject display name if known
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
}
} else {
@ -298,19 +335,3 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
// TODO
}
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
for _, key := range src {
_, ok := dst[key.UserID]
if !ok {
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
}
_, ok = dst[key.UserID][key.DeviceID]
if !ok {
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
}
for keyID, keyJSON := range key.KeyJSON {
dst[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
}

View File

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/inthttp"
"github.com/matrix-org/dendrite/keyserver/storage"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
@ -33,7 +34,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI {
func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI) api.KeyInternalAPI {
db, err := storage.NewDatabase(
string(cfg.Database.E2EKey),
cfg.DbProperties(),
@ -45,5 +46,6 @@ func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.Federatio
DB: db,
ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient,
UserAPI: userAPI,
}
}

View File

@ -122,9 +122,11 @@ User can invite local user to room with version 1
Can upload device keys
Should reject keys claiming to belong to a different user
Can query device keys using POST
Can query remote device keys using POST
Can query specific device keys using POST
query for user with no keys returns empty key dict
Can claim one time key using POST
Can claim remote one time key using POST
Can add account data
Can add account data to room
Can get account data without syncing

View File

@ -30,6 +30,7 @@ type UserInternalAPI interface {
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
}
// InputAccountDataRequest is the request for InputAccountData
@ -44,6 +45,19 @@ type InputAccountDataRequest struct {
type InputAccountDataResponse struct {
}
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
type QueryDeviceInfosRequest struct {
DeviceIDs []string
}
// QueryDeviceInfosResponse is the response to QueryDeviceInfos
type QueryDeviceInfosResponse struct {
DeviceInfo map[string]struct {
DisplayName string
UserID string
}
}
// QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct {
AccessToken string

View File

@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
return nil
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
res.DeviceInfo = make(map[string]struct {
DisplayName string
UserID string
})
for _, d := range devices {
res.DeviceInfo[d.ID] = struct {
DisplayName string
UserID string
}{
DisplayName: d.DisplayName,
UserID: d.UserID,
}
}
return nil
}
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {

View File

@ -35,6 +35,7 @@ const (
QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryDeviceInfos(
ctx context.Context,
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
defer span.Finish()
apiURL := h.apiURL + QueryDeviceInfosPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryAccessToken(
ctx context.Context,
request *api.QueryAccessTokenRequest,

View File

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/util"
)
// nolint: gocyclo
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
request := api.QueryDeviceInfosRequest{}
response := api.QueryDeviceInfosResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View File

@ -24,6 +24,7 @@ type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,

View File

@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server
return
}
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {

View File

@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,

View File

@ -20,6 +20,7 @@ import (
"strings"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
type devicesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@ -79,6 +83,7 @@ type devicesStatements struct {
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server
return
}
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
}
return &dev, err
}
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
iDeviceIDs[i] = deviceIDs[i]
}
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}

View File

@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,