Correctly create new device when device_id is passed to /login (#753)

Fixes https://github.com/matrix-org/dendrite/issues/401

Currently when passing a `device_id` parameter to `/login`, which is [supposed](https://matrix.org/docs/spec/client_server/unstable#post-matrix-client-r0-login) to return a device with that ID set, it instead just generates a random `device_id` and hands that back to you.

The code was already there to do this correctly, it looks like it had just been broken during some change. Hopefully sytest will prevent this from becoming broken again.
main
Andrew Morgan 2019-07-22 15:05:38 +01:00 committed by GitHub
parent bdd1a87d4d
commit 78032b3f4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 25 deletions

View File

@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken(
return &dev, err return &dev, err
} }
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {

View File

@ -84,7 +84,7 @@ func (d *Database) CreateDevice(
if deviceID != nil { if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err return err
} }

View File

@ -18,7 +18,6 @@ import (
"net/http" "net/http"
"context" "context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -44,8 +43,10 @@ type flow struct {
type passwordRequest struct { type passwordRequest struct {
User string `json:"user"` User string `json:"user"`
Password string `json:"password"` Password string `json:"password"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID string `json:"device_id"` DeviceID *string `json:"device_id"`
} }
type loginResponse struct { type loginResponse struct {
@ -110,7 +111,7 @@ func Login(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -134,20 +135,16 @@ func Login(
} }
} }
// check if device exists else create one // getDevice returns a new or existing device
func getDevice( func getDevice(
ctx context.Context, ctx context.Context,
r passwordRequest, r passwordRequest,
deviceDB *devices.Database, deviceDB *devices.Database,
acc *authtypes.Account, acc *authtypes.Account,
localpart, token string, token string,
) (dev *authtypes.Device, err error) { ) (dev *authtypes.Device, err error) {
dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID)
if err == sql.ErrNoRows {
// device doesn't exist, create one
dev, err = deviceDB.CreateDevice( dev, err = deviceDB.CreateDevice(
ctx, acc.Localpart, nil, token, r.InitialDisplayName, ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName,
) )
}
return return
} }

View File

@ -121,7 +121,10 @@ type registerRequest struct {
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID *string `json:"device_id"`
// Prevent this user from logging in // Prevent this user from logging in
InhibitLogin common.WeakBoolean `json:"inhibit_login"` InhibitLogin common.WeakBoolean `json:"inhibit_login"`
@ -626,7 +629,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -646,7 +649,7 @@ func checkAndCompleteFlow(
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password, "", req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -697,10 +700,10 @@ func LegacyRegister(
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,
@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
return nil return nil
} }
// completeRegistration runs some rudimentary checks against the submitted
// input, then if successful creates an account and a newly associated device
// We pass in each individual part of the request here instead of just passing a
// registerRequest, as this function serves requests encoded as both
// registerRequests and legacyRegisterRequests, which share some attributes but
// not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB *accounts.Database,
deviceDB *devices.Database, deviceDB *devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName *string, displayName, deviceID *string,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -773,6 +782,9 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -793,8 +805,7 @@ func completeRegistration(
} }
} }
// TODO: Use the device ID in the request. dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -802,9 +813,6 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{

View File

@ -149,3 +149,5 @@ Typing events appear in incremental sync
Typing events appear in gapped sync Typing events appear in gapped sync
Inbound federation of state requires event_id as a mandatory paramater Inbound federation of state requires event_id as a mandatory paramater
Inbound federation of state_ids requires event_id as a mandatory paramater Inbound federation of state_ids requires event_id as a mandatory paramater
POST /register returns the same device_id as that in the request
POST /login returns the same device_id as that in the request