From 1f570d0e924e3da832eb26b089762cba3c6b1a76 Mon Sep 17 00:00:00 2001 From: Anant Prakash Date: Thu, 31 May 2018 20:06:15 +0530 Subject: [PATCH] Auto-generate username if none provided during registration (#470) * Auto-generate username if none provided during registration * Remove rogue backtick * Add appropriate log msg --- .../auth/storage/accounts/accounts_table.go | 24 +++++++++++++++---- .../auth/storage/accounts/storage.go | 7 ++++++ .../dendrite/clientapi/routing/register.go | 18 ++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go index 3b4e6bd5..aaf6af39 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go @@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- TODO: -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? ); +-- Create sequence for autogenerated numeric usernames +CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + @@ -49,13 +51,17 @@ const selectAccountByLocalpartSQL = "" + const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1" +const selectNewNumericLocalpartSQL = "" + + "SELECT nextval('numeric_username_seq')" + // TODO: Update password type accountsStatements struct { - insertAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { @@ -72,6 +78,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { return } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } s.serverName = server return } @@ -121,3 +130,10 @@ func (s *accountsStatements) selectAccountByLocalpart( } return } + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, +) (id int64, err error) { + err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 57148273..fc82ec75 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -267,6 +267,13 @@ func (d *Database) GetAccountDataByType( ) } +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx) +} + func hashPassword(plaintext string) (hash string, err error) { hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) return string(hashBytes), err diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 7a3b8686..cb427b71 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -27,6 +27,7 @@ import ( "net/url" "regexp" "sort" + "strconv" "strings" "time" @@ -403,6 +404,23 @@ func Register( sessionID = util.RandomString(sessionIDLength) } + // Don't allow numeric usernames less than MAX_INT64. + if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + } + } + // Auto generate a numeric username if r.Username is empty + if r.Username == "" { + id, err := accountDB.GetNewNumericLocalpart(req.Context()) + if err != nil { + return httputil.LogThenError(req, err) + } + + r.Username = strconv.FormatInt(id, 10) + } + // If no auth type is specified by the client, send back the list of available flows if r.Auth.Type == "" { return util.JSONResponse{