Auto-generate username if none provided during registration (#470)

* Auto-generate username if none provided during registration

* Remove rogue backtick

* Add appropriate log msg
main
Anant Prakash 2018-05-31 20:06:15 +05:30 committed by Andrew Morgan
parent 05be8d1c99
commit 1f570d0e92
3 changed files with 45 additions and 4 deletions

View File

@ -38,6 +38,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- TODO:
-- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertAccountSQL = "" +
@ -49,13 +51,17 @@ const selectAccountByLocalpartSQL = "" +
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1"
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
@ -72,6 +78,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
return
}
if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
@ -121,3 +130,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
}
return
}
func (s *accountsStatements) selectNewNumericLocalpart(
ctx context.Context,
) (id int64, err error) {
err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
return
}

View File

@ -267,6 +267,13 @@ func (d *Database) GetAccountDataByType(
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err

View File

@ -27,6 +27,7 @@ import (
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
@ -403,6 +404,23 @@ func Register(
sessionID = util.RandomString(sessionIDLength)
}
// Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
}
}
// Auto generate a numeric username if r.Username is empty
if r.Username == "" {
id, err := accountDB.GetNewNumericLocalpart(req.Context())
if err != nil {
return httputil.LogThenError(req, err)
}
r.Username = strconv.FormatInt(id, 10)
}
// If no auth type is specified by the client, send back the list of available flows
if r.Auth.Type == "" {
return util.JSONResponse{