diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 756598db..7b9d8acd 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -15,11 +15,10 @@ package routing import ( - "encoding/json" - "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/keyserver/api" @@ -29,37 +28,52 @@ import ( "github.com/matrix-org/util" ) +type crossSigningRequest struct { + api.PerformUploadDeviceKeysRequest + Auth newPasswordAuth `json:"auth"` +} + func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, keyserverAPI api.KeyInternalAPI, device *userapi.Device, accountDB accounts.Database, cfg *config.ClientAPI, ) util.JSONResponse { - uploadReq := &api.PerformUploadDeviceKeysRequest{} + uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} - ctx := req.Context() - defer req.Body.Close() // nolint:errcheck - bodyBytes, err := ioutil.ReadAll(req.Body) - if err != nil { + resErr := httputil.UnmarshalJSONRequest(req, &uploadReq) + if resErr != nil { + return *resErr + } + sessionID := uploadReq.Auth.Session + if sessionID == "" { + sessionID = util.RandomString(sessionIDLength) + } + if uploadReq.Auth.Type != authtypes.LoginTypePassword { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be read: " + err.Error()), + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + ), } } - - if _, err := userInteractiveAuth.Verify(ctx, bodyBytes, device); err != nil { - return *err + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, } - - if err = json.Unmarshal(bodyBytes, &uploadReq); err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be unmarshalled: " + err.Error()), - } + if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { + return *authErr } + AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), uploadReq, uploadRes) + keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) if err := uploadRes.Error; err != nil { switch {