From 13107c6b2ba623b8b59ffb1ee7f69a8a9c194827 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 9 Oct 2017 16:24:38 +0200 Subject: [PATCH] Implement /register/available API (#291) Signed-off-by: MTRNord --- .../auth/storage/accounts/storage.go | 10 +++ .../dendrite/clientapi/routing/routing.go | 4 + .../dendrite/clientapi/writers/register.go | 80 +++++++++++++++---- 3 files changed, 79 insertions(+), 15 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 16474ec6..5449df5c 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -314,3 +314,13 @@ func (d *Database) GetThreePIDsForLocalpart( ) (threepids []authtypes.ThreePID, err error) { return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) } + +// CheckAccountAvailability checks if the username/localpart is already present in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index b32ddf14..00bdd15f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -131,6 +131,10 @@ func Setup( return writers.LegacyRegister(req, accountDB, deviceDB, &cfg) })).Methods("POST", "OPTIONS") + r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + return writers.RegisterAvailable(req, accountDB) + })).Methods("GET") + r0mux.Handle("/directory/room/{roomAlias}", common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars := mux.Vars(req) diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go index 8519c9a1..84227d15 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go @@ -91,24 +91,14 @@ type registerResponse struct { DeviceID string `json:"device_id"` } -// Validate returns an error response if the username/password are invalid -func validate(username, password string) *util.JSONResponse { +// validateUserName returns an error response if the username is invalid +func validateUserName(username string) *util.JSONResponse { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - } - } else if len(username) > maxUsernameLength { + if len(username) > maxUsernameLength { return &util.JSONResponse{ Code: 400, JSON: jsonerror.BadJSON(fmt.Sprintf("'username' >%d characters", maxUsernameLength)), } - } else if len(password) > 0 && len(password) < minPasswordLength { - return &util.JSONResponse{ - Code: 400, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - } } else if !validUsernameRegex.MatchString(username) { return &util.JSONResponse{ Code: 400, @@ -123,6 +113,23 @@ func validate(username, password string) *util.JSONResponse { return nil } +// validatePassword returns an error response if the password is invalid +func validatePassword(password string) *util.JSONResponse { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if len(password) > maxPasswordLength { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), + } + } else if len(password) > 0 && len(password) < minPasswordLength { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + } + } + return nil +} + // Register processes a /register request. http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, @@ -149,7 +156,10 @@ func Register( } } - if resErr = validate(r.Username, r.Password); resErr != nil { + if resErr = validateUserName(r.Username); resErr != nil { + return *resErr + } + if resErr = validatePassword(r.Password); resErr != nil { return *resErr } @@ -209,7 +219,10 @@ func LegacyRegister( if resErr != nil { return *resErr } - if resErr = validate(r.Username, r.Password); resErr != nil { + if resErr = validateUserName(r.Username); resErr != nil { + return *resErr + } + if resErr = validatePassword(r.Password); resErr != nil { return *resErr } @@ -344,3 +357,40 @@ func isValidMacLogin( return hmac.Equal(givenMac, expectedMAC), nil } + +type availableResponse struct { + Available bool `json:"available"` +} + +// RegisterAvailable checks if the username is already taken or invalid +func RegisterAvailable( + req *http.Request, + accountDB *accounts.Database, +) util.JSONResponse { + username := req.URL.Query().Get("username") + + if err := validateUserName(username); err != nil { + return *err + } + + availability, availabilityErr := accountDB.CheckAccountAvailability(req.Context(), username) + if availabilityErr != nil { + return util.JSONResponse{ + Code: 500, + JSON: jsonerror.Unknown("failed to check availability: " + availabilityErr.Error()), + } + } + if !availability { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidUsername("A different user ID has already been registered for this session"), + } + } + + return util.JSONResponse{ + Code: 200, + JSON: availableResponse{ + Available: true, + }, + } +}