Prevent sql scanning into nil value in accounts_table (#479)
* Prevent sql scanning into nil value in accounts_table Signed-off-by: Andrew Morgan <andrewm@matrix.org> * Remove uneccessary logging, null checking * Don't forget to set the localpart * Simplify error checkingmain
parent
af08eea46d
commit
141fd91537
|
@ -22,6 +22,8 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountsSchema = `
|
const accountsSchema = `
|
||||||
|
@ -121,14 +123,26 @@ func (s *accountsStatements) selectPasswordHash(
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (acc *authtypes.Account, err error) {
|
) (*authtypes.Account, error) {
|
||||||
|
var appserviceIDPtr sql.NullString
|
||||||
|
var acc authtypes.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err = stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &acc.AppServiceID)
|
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if appserviceIDPtr.Valid {
|
||||||
|
acc.AppServiceID = appserviceIDPtr.String
|
||||||
|
}
|
||||||
|
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
acc.ServerName = s.serverName
|
acc.ServerName = s.serverName
|
||||||
}
|
|
||||||
return
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
func (s *accountsStatements) selectNewNumericLocalpart(
|
||||||
|
|
Loading…
Reference in New Issue