diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 72b27c8b..2d09090f 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -40,7 +40,10 @@ type Database struct { threepids threepidStatements serverName gomatrixserverlib.ServerName - createAccountMu sync.Mutex + accountsMu sync.Mutex + profilesMu sync.Mutex + accountDatasMu sync.Mutex + threepidsMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -74,7 +77,15 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = t.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, ac, t, serverName, sync.Mutex{}}, nil + return &Database{ + db: db, + PartitionOffsetStatements: partitions, + accounts: a, + profiles: p, + accountDatas: ac, + threepids: t, + serverName: serverName, + }, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -105,6 +116,8 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() return d.profiles.setAvatarURL(ctx, localpart, avatarURL) } @@ -113,6 +126,8 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() return d.profiles.setDisplayName(ctx, localpart, displayName) } @@ -124,8 +139,12 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er // when the first txn upgrades to a write txn. We also need to lock the account creation else we can // race with CreateAccount // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) @@ -146,8 +165,12 @@ func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err @@ -155,6 +178,8 @@ func (d *Database) CreateAccount( return } +// WARNING! This function assumes that the relevant mutexes have already +// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*api.Account, error) { @@ -196,6 +221,8 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { + d.accountDatasMu.Lock() + defer d.accountDatasMu.Unlock() return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) @@ -247,6 +274,8 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, @@ -270,6 +299,8 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return d.threepids.deleteThreePID(ctx, threepid, medium) }