Add contexts to device database (#233)
* Add contexts to device database * Remove spurious whitespacemain
parent
e28ee27605
commit
238646ee3c
|
@ -7,6 +7,16 @@ export GOGC=400
|
||||||
export GOPATH="$(pwd):$(pwd)/vendor"
|
export GOPATH="$(pwd):$(pwd)/vendor"
|
||||||
export PATH="$PATH:$(pwd)/vendor/bin:$(pwd)/bin"
|
export PATH="$PATH:$(pwd)/vendor/bin:$(pwd)/bin"
|
||||||
|
|
||||||
|
echo "Checking that it builds"
|
||||||
|
gb build
|
||||||
|
|
||||||
|
# Check that all the packages can build.
|
||||||
|
# When `go build` is given multiple packages it won't output anything, and just
|
||||||
|
# checks that everything builds. This seems to do a better job of handling
|
||||||
|
# missing imports than `gb build` does.
|
||||||
|
echo "Double checking it builds..."
|
||||||
|
go build github.com/matrix-org/dendrite/cmd/...
|
||||||
|
|
||||||
echo "Installing lint search engine..."
|
echo "Installing lint search engine..."
|
||||||
go install github.com/alecthomas/gometalinter/
|
go install github.com/alecthomas/gometalinter/
|
||||||
gometalinter --config=linter.json ./... --install
|
gometalinter --config=linter.json ./... --install
|
||||||
|
@ -20,11 +30,5 @@ misspell -error src *.md
|
||||||
echo "Testing..."
|
echo "Testing..."
|
||||||
gb test
|
gb test
|
||||||
|
|
||||||
# Check that all the packages can build.
|
|
||||||
# When `go build` is given multiple packages it won't output anything, and just
|
|
||||||
# checks that everything builds. This seems to do a better job of handling
|
|
||||||
# missing imports than `gb build` does.
|
|
||||||
echo "Double checking it builds..."
|
|
||||||
go build github.com/matrix-org/dendrite/cmd/...
|
|
||||||
|
|
||||||
echo "Done!"
|
echo "Done!"
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -42,7 +43,7 @@ var tokenByteLength = 32
|
||||||
// DeviceDatabase represents a device database.
|
// DeviceDatabase represents a device database.
|
||||||
type DeviceDatabase interface {
|
type DeviceDatabase interface {
|
||||||
// Look up the device matching the given access token.
|
// Look up the device matching the given access token.
|
||||||
GetDeviceByAccessToken(token string) (*authtypes.Device, error)
|
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request
|
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request
|
||||||
|
@ -57,7 +58,7 @@ func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *auth
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
device, err = deviceDB.GetDeviceByAccessToken(token)
|
device, err = deviceDB.GetDeviceByAccessToken(req.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
resErr = &util.JSONResponse{
|
resErr = &util.JSONResponse{
|
||||||
|
|
|
@ -15,10 +15,13 @@
|
||||||
package devices
|
package devices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -84,27 +87,36 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(txn *sql.Tx, id, localpart, accessToken string) (dev *authtypes.Device, err error) {
|
func (s *devicesStatements) insertDevice(
|
||||||
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
if _, err = txn.Stmt(s.insertDeviceStmt).Exec(id, localpart, accessToken, createdTimeMS); err == nil {
|
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
dev = &authtypes.Device{
|
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &authtypes.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: makeUserID(localpart, s.serverName),
|
UserID: makeUserID(localpart, s.serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
}
|
}, nil
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevice(txn *sql.Tx, id, localpart string) error {
|
func (s *devicesStatements) deleteDevice(
|
||||||
_, err := txn.Stmt(s.deleteDeviceStmt).Exec(id, localpart)
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) {
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
|
ctx context.Context, accessToken string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
var dev authtypes.Device
|
var dev authtypes.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
err := s.selectDeviceByTokenStmt.QueryRow(accessToken).Scan(&dev.ID, &localpart)
|
stmt := s.selectDeviceByTokenStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = makeUserID(localpart, s.serverName)
|
dev.UserID = makeUserID(localpart, s.serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package devices
|
package devices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
@ -44,8 +45,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) {
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
return d.devices.selectDeviceByToken(token)
|
ctx context.Context, token string,
|
||||||
|
) (*authtypes.Device, error) {
|
||||||
|
return d.devices.selectDeviceByToken(ctx, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
@ -53,15 +56,17 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
// an error will be returned.
|
// an error will be returned.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID, accessToken string,
|
||||||
|
) (dev *authtypes.Device, returnErr error) {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing token for this device
|
// Revoke existing token for this device
|
||||||
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
|
if err = d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(txn, deviceID, localpart, accessToken)
|
dev, err = d.devices.insertDevice(ctx, txn, deviceID, localpart, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -74,9 +79,11 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
|
||||||
// matching with the given device ID and user ID localpart
|
// matching with the given device ID and user ID localpart
|
||||||
// If the device doesn't exist, it will not return an error
|
// If the device doesn't exist, it will not return an error
|
||||||
// If something went wrong during the deletion, it will return the SQL error
|
// If something went wrong during the deletion, it will return the SQL error
|
||||||
func (d *Database) RemoveDevice(deviceID string, localpart string) error {
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -98,7 +98,9 @@ func Login(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Use the device ID in the request
|
// TODO: Use the device ID in the request
|
||||||
dev, err := deviceDB.CreateDevice(acc.Localpart, auth.UnknownDeviceID, token)
|
dev, err := deviceDB.CreateDevice(
|
||||||
|
req.Context(), acc.Localpart, auth.UnknownDeviceID, token,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 500,
|
Code: 500,
|
||||||
|
|
|
@ -41,7 +41,7 @@ func Logout(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := deviceDB.RemoveDevice(device.ID, localpart); err != nil {
|
if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -135,9 +135,7 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
||||||
req.Context(), accountDB, deviceDB, r.Username, r.Password,
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 501,
|
Code: 501,
|
||||||
|
@ -182,7 +180,7 @@ func completeRegistration(
|
||||||
}
|
}
|
||||||
|
|
||||||
// // TODO: Use the device ID in the request.
|
// // TODO: Use the device ID in the request.
|
||||||
dev, err := deviceDB.CreateDevice(username, auth.UnknownDeviceID, token)
|
dev, err := deviceDB.CreateDevice(ctx, username, auth.UnknownDeviceID, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 500,
|
Code: 500,
|
||||||
|
|
|
@ -86,7 +86,9 @@ func main() {
|
||||||
accessToken = &t
|
accessToken = &t
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := deviceDB.CreateDevice(*username, "create-account-script", *accessToken)
|
device, err := deviceDB.CreateDevice(
|
||||||
|
context.Background(), *username, "create-account-script", *accessToken,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
fmt.Println(err.Error())
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
Loading…
Reference in New Issue