From 141fd9153766f9c683aa69108160f1ea0d75e5fe Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 29 Jun 2018 03:55:29 -0700 Subject: [PATCH] Prevent sql scanning into nil value in accounts_table (#479) * Prevent sql scanning into nil value in accounts_table Signed-off-by: Andrew Morgan * Remove uneccessary logging, null checking * Don't forget to set the localpart * Simplify error checking --- .../auth/storage/accounts/accounts_table.go | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go index aaf6af39..4ed54f95 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" ) const accountsSchema = ` @@ -121,14 +123,26 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (acc *authtypes.Account, err error) { +) (*authtypes.Account, error) { + var appserviceIDPtr sql.NullString + var acc authtypes.Account + stmt := s.selectAccountByLocalpartStmt - err = stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &acc.AppServiceID) - if err == nil { - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err } - return + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil } func (s *accountsStatements) selectNewNumericLocalpart(