diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go index ab483622..c50cd1fd 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/filter_table.go @@ -15,8 +15,8 @@ package accounts import ( - "database/sql" "context" + "database/sql" ) const filterSchema = ` @@ -41,13 +41,9 @@ const selectFilterSQL = "" + const insertFilterSQL = "" + "INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" -const findMaxIDSQL = "" + - "SELECT MAX(id) FROM account_filter WHERE localpart = $1" - type filterStatements struct { selectFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt - findMaxIDStmt *sql.Stmt } func (s *filterStatements) prepare(db *sql.DB) (err error) { @@ -61,10 +57,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { return } - if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil { - return - } - return } @@ -77,14 +69,7 @@ func (s *filterStatements) selectFilter( func (s *filterStatements) insertFilter( ctx context.Context, filter string, localpart string, -) (err error) { - _, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart) - return -} - -func (s *filterStatements) findMaxID( - ctx context.Context, localpart string, -) (id string, err error) { - err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id) +) (pos string, err error) { + err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos) return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 4bc570a2..33fbbd86 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -29,7 +29,7 @@ import ( // Database represents an account database type Database struct { - db *sql.DB + db *sql.DB common.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -333,11 +333,7 @@ func (d *Database) GetFilter( func (d *Database) PutFilter( ctx context.Context, localpart, filter string, ) (string, error) { - err := d.filter.insertFilter(ctx, filter, localpart) - if err != nil { - return "", err - } - return d.filter.findMaxID(ctx, localpart) + return d.filter.insertFilter(ctx, filter, localpart) } // CheckAccountAvailability checks if the username/localpart is already present in the database.