diff --git a/appservice/appservice.go b/appservice/appservice.go index be5b30e2..68cf52e7 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -16,6 +16,7 @@ package appservice import ( "context" + "errors" "net/http" "sync" "time" @@ -29,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/appservice/workers" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/basecomponent" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/transactions" @@ -117,12 +119,12 @@ func generateAppServiceAccount( ctx := context.Background() // Create an account for the application service - acc, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) + _, err := accountsDB.CreateAccount(ctx, as.SenderLocalpart, "", as.ID) if err != nil { + if errors.Is(err, internal.ErrUserExists) { // This account already exists + return nil + } return err - } else if acc == nil { - // This account already exists - return nil } // Create a dummy device with a dummy token for the application service diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index a860f809..4d1941a2 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -29,6 +29,9 @@ type Database interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error + // CreateAccount makes a new account with the given login name and password, and creates an empty profile + // for this account. If no password is supplied, the account will be a passwordless account. If the + // account already exists, it will return nil, ErrUserExists. CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index 4a183267..4be5dca9 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -138,7 +138,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *authtypes.Account, err error) { @@ -164,7 +164,7 @@ func (d *Database) createAccount( } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if internal.IsUniqueConstraintViolationErr(err) { - return nil, nil + return nil, internal.ErrUserExists } return nil, err } diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 7dec8729..4b436708 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -27,8 +27,8 @@ import ( "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" - // Import the postgres database driver. - _ "github.com/mattn/go-sqlite3" + // Import the sqlite3 database driver. + "github.com/mattn/go-sqlite3" ) // Database represents an account database @@ -148,7 +148,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. +// account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *authtypes.Account, err error) { @@ -172,8 +172,8 @@ func (d *Database) createAccount( } } if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if internal.IsUniqueConstraintViolationErr(err) { - return nil, nil + if errors.Is(err, sqlite3.ErrConstraint) { + return nil, internal.ErrUserExists } return nil, err } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 40a2862f..d356db2c 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -830,15 +830,16 @@ func completeRegistration( acc, err := accountDB.CreateAccount(ctx, username, password, appserviceID) if err != nil { + if errors.Is(err, internal.ErrUserExists) { // user already exists + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + } + } return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.Unknown("failed to create account: " + err.Error()), } - } else if acc == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.UserInUse("Desired user ID is already taken."), - } } // Increment prometheus counter for created users diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index eca9b2fe..9c1e4593 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -69,13 +69,10 @@ func main() { os.Exit(1) } - account, err := accountDB.CreateAccount(context.Background(), *username, *password, "") + _, err = accountDB.CreateAccount(context.Background(), *username, *password, "") if err != nil { fmt.Println(err.Error()) os.Exit(1) - } else if account == nil { - fmt.Println("Username already exists") - os.Exit(1) } deviceDB, err := devices.NewDatabase(*database, nil, serverName) diff --git a/internal/sql.go b/internal/sql.go index 546954bd..e3c10afc 100644 --- a/internal/sql.go +++ b/internal/sql.go @@ -26,6 +26,9 @@ import ( "go.uber.org/atomic" ) +// ErrUserExists is returned if a username already exists in the database. +var ErrUserExists = errors.New("Username already exists") + // A Transaction is something that can be committed or rolledback. type Transaction interface { // Commit the transaction