User registration return M_USER_IN_USE when username is already taken (#372)

When registering a new user using POST `/_matrix/client/r0/register`, the server was returning a 500 error when user name was already taken.

I added a check in `completeRegistration` to verify if the username is available before inserting it, and return a 400 `M_USER_IN_USE` error if there is a conflict, as [defined in matrix-doc](https://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-r0-register)

Signed-off-by: Thibaut CHARLES cromfr@gmail.com
main
Thibaut CHARLES 2017-12-19 10:49:42 +01:00 committed by Erik Johnston
parent b835e585c4
commit ec30d143cd
5 changed files with 28 additions and 2 deletions

View File

@ -118,7 +118,8 @@ func (d *Database) SetDisplayName(
} }
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart, plaintextPassword string,
) (*authtypes.Account, error) { ) (*authtypes.Account, error) {
@ -127,6 +128,9 @@ func (d *Database) CreateAccount(
return nil, err return nil, err
} }
if err := d.profiles.insertProfile(ctx, localpart); err != nil { if err := d.profiles.insertProfile(ctx, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, localpart, hash) return d.accounts.insertAccount(ctx, localpart, hash)

View File

@ -103,6 +103,12 @@ func InvalidUsername(msg string) *MatrixError {
return &MatrixError{"M_INVALID_USERNAME", msg} return &MatrixError{"M_INVALID_USERNAME", msg}
} }
// UserInUse is an error returned when the client tries to register an
// username that already exists
func UserInUse(msg string) *MatrixError {
return &MatrixError{"M_USER_IN_USE", msg}
}
// GuestAccessForbidden is an error which is returned when the client is // GuestAccessForbidden is an error which is returned when the client is
// forbidden from accessing a resource as a guest. // forbidden from accessing a resource as a guest.
func GuestAccessForbidden(msg string) *MatrixError { func GuestAccessForbidden(msg string) *MatrixError {

View File

@ -463,6 +463,11 @@ func completeRegistration(
Code: 500, Code: 500,
JSON: jsonerror.Unknown("failed to create account: " + err.Error()), JSON: jsonerror.Unknown("failed to create account: " + err.Error()),
} }
} else if acc == nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
}
} }
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()

View File

@ -69,10 +69,13 @@ func main() {
os.Exit(1) os.Exit(1)
} }
_, err = accountDB.CreateAccount(context.Background(), *username, *password) account, err := accountDB.CreateAccount(context.Background(), *username, *password)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
} else if account == nil {
fmt.Println("Username already exists")
os.Exit(1)
} }
deviceDB, err := devices.NewDatabase(*database, serverName) deviceDB, err := devices.NewDatabase(*database, serverName)

View File

@ -16,6 +16,8 @@ package common
import ( import (
"database/sql" "database/sql"
"github.com/lib/pq"
) )
// A Transaction is something that can be committed or rolledback. // A Transaction is something that can be committed or rolledback.
@ -66,3 +68,9 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
} }
return statement return statement
} }
// IsUniqueConstraintViolationErr returns true if the error is a postgresql unique_violation error
func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}