diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index dbf6606c..d84f25b1 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -42,7 +42,7 @@ type Database struct { filter filterStatements serverName gomatrixserverlib.ServerName - createGuestAccountMu sync.Mutex + createAccountMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -129,14 +129,14 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + // We need to lock so we sequentially create numeric localparts. If we don't, two calls to + // this function will cause the same number to be selected and one will fail with 'database is locked' + // when the first txn upgrades to a write txn. We also need to lock the account creation else we can + // race with CreateAccount + // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. + d.createAccountMu.Lock() + defer d.createAccountMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - // We need to lock so we sequentially create numeric localparts. If we don't, two calls to - // this function will cause the same number to be selected and one will fail with 'database is locked' - // when the first txn upgrades to a write txn. - // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.createGuestAccountMu.Lock() - defer d.createGuestAccountMu.Unlock() - var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -155,6 +155,9 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { + // Create one account at a time else we can get 'database is locked'. + d.createAccountMu.Lock() + defer d.createAccountMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err