Add in devices_table to store device information (#118)
parent
e6835660b0
commit
445dce14ae
|
@ -16,6 +16,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -40,9 +41,16 @@ func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *a
|
||||||
}
|
}
|
||||||
device, err = deviceDB.GetDeviceByAccessToken(token)
|
device, err = deviceDB.GetDeviceByAccessToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resErr = &util.JSONResponse{
|
if err == sql.ErrNoRows {
|
||||||
Code: 500,
|
resErr = &util.JSONResponse{
|
||||||
JSON: jsonerror.Unknown("Failed to check access token"),
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("Invalid access token"),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resErr = &util.JSONResponse{
|
||||||
|
Code: 500,
|
||||||
|
JSON: jsonerror.Unknown("Failed to check access token"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -16,8 +16,10 @@ package authtypes
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
UserID string
|
UserID string
|
||||||
|
// The access_token granted to this device.
|
||||||
|
// This uniquely identifies the device from all other devices and clients.
|
||||||
AccessToken string
|
AccessToken string
|
||||||
// TODO: display name, last used timestamp, keys, etc
|
// TODO: display name, last used timestamp, keys, etc
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package devices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const devicesSchema = `
|
||||||
|
-- Stores data about devices.
|
||||||
|
CREATE TABLE IF NOT EXISTS devices (
|
||||||
|
-- The access token granted to this device. This has to be the primary key
|
||||||
|
-- so we can distinguish which device is making a given request.
|
||||||
|
access_token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
||||||
|
-- access_tokens will be clobbered based on the device ID for a user.
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
|
||||||
|
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
||||||
|
-- migration to different domain names easier.
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||||
|
created_ts BIGINT NOT NULL
|
||||||
|
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Device IDs must be unique for a given user.
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS localpart_id_idx ON devices(localpart, device_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertDeviceSQL = "" +
|
||||||
|
"INSERT INTO devices(device_id, localpart, access_token, created_ts) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
|
const selectDeviceByTokenSQL = "" +
|
||||||
|
"SELECT device_id, localpart FROM devices WHERE access_token = $1"
|
||||||
|
|
||||||
|
const deleteDeviceSQL = "" +
|
||||||
|
"DELETE FROM devices WHERE device_id = $1 AND localpart = $2"
|
||||||
|
|
||||||
|
// TODO: List devices?
|
||||||
|
|
||||||
|
type devicesStatements struct {
|
||||||
|
insertDeviceStmt *sql.Stmt
|
||||||
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
deleteDeviceStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
_, err = db.Exec(devicesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||||
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (s *devicesStatements) insertDevice(txn *sql.Tx, id, localpart, accessToken string) (dev *authtypes.Device, err error) {
|
||||||
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
|
if _, err = s.insertDeviceStmt.Exec(id, localpart, accessToken, createdTimeMS); err == nil {
|
||||||
|
dev = &authtypes.Device{
|
||||||
|
ID: id,
|
||||||
|
UserID: makeUserID(localpart, s.serverName),
|
||||||
|
AccessToken: accessToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) deleteDevice(txn *sql.Tx, id, localpart string) error {
|
||||||
|
_, err := txn.Stmt(s.deleteDeviceStmt).Exec(id, localpart)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) selectDeviceByToken(accessToken string) (*authtypes.Device, error) {
|
||||||
|
var dev authtypes.Device
|
||||||
|
var localpart string
|
||||||
|
err := s.selectDeviceByTokenStmt.QueryRow(accessToken).Scan(&dev.ID, &localpart)
|
||||||
|
if err == nil {
|
||||||
|
dev.UserID = makeUserID(localpart, s.serverName)
|
||||||
|
dev.AccessToken = accessToken
|
||||||
|
}
|
||||||
|
return &dev, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeUserID(localpart string, server gomatrixserverlib.ServerName) string {
|
||||||
|
return fmt.Sprintf("@%s:%s", localpart, string(server))
|
||||||
|
}
|
|
@ -15,23 +15,82 @@
|
||||||
package devices
|
package devices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents a device database.
|
// Database represents a device database.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
// TODO
|
db *sql.DB
|
||||||
|
devices devicesStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new device database
|
// NewDatabase creates a new device database
|
||||||
func NewDatabase(dataSource string) (*Database, error) {
|
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
|
||||||
return &Database{}, nil
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := devicesStatements{}
|
||||||
|
if err = d.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) {
|
func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, error) {
|
||||||
// TODO: Actual implementation
|
// return d.devices.selectDeviceByToken(token) TODO: Figure out how to make integ tests pass
|
||||||
return &authtypes.Device{
|
return &authtypes.Device{
|
||||||
UserID: token,
|
UserID: token,
|
||||||
|
AccessToken: token,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with a newly generated token.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(localpart, deviceID string) (dev *authtypes.Device, returnErr error) {
|
||||||
|
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing token for this device
|
||||||
|
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO: generate an access token. We should probably make sure that it's not possible for this
|
||||||
|
// token to be the same as the one we just revoked...
|
||||||
|
accessToken := makeUserID(localpart, d.devices.serverName)
|
||||||
|
|
||||||
|
dev, err = d.devices.insertDevice(txn, deviceID, localpart, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: factor out to common
|
||||||
|
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
|
txn, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
panic(r)
|
||||||
|
} else if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
} else {
|
||||||
|
err = txn.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = fn(txn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("Failed to setup account database(%s): %s", accountDataSource, err.Error())
|
log.Panicf("Failed to setup account database(%s): %s", accountDataSource, err.Error())
|
||||||
}
|
}
|
||||||
deviceDB, err := devices.NewDatabase(accountDataSource)
|
deviceDB, err := devices.NewDatabase(accountDataSource, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("Failed to setup device database(%s): %s", accountDataSource, err.Error())
|
log.Panicf("Failed to setup device database(%s): %s", accountDataSource, err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,9 @@ func loadConfig(configPath string) (*config.Sync, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// check required fields
|
// check required fields
|
||||||
|
if cfg.ServerName == "" {
|
||||||
|
log.Fatalf("'server_name' must be supplied in %s", configPath)
|
||||||
|
}
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +77,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: DO NOT USE THIS DATA SOURCE (it's the sync one, not devices!)
|
// TODO: DO NOT USE THIS DATA SOURCE (it's the sync one, not devices!)
|
||||||
deviceDB, err := devices.NewDatabase(cfg.DataSource)
|
deviceDB, err := devices.NewDatabase(cfg.DataSource, cfg.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.DataSource, err)
|
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.DataSource, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,6 +83,7 @@ func getLastRequestError() error {
|
||||||
var syncServerConfigFileContents = (`consumer_uris: ["` + kafkaURI + `"]
|
var syncServerConfigFileContents = (`consumer_uris: ["` + kafkaURI + `"]
|
||||||
roomserver_topic: "` + inputTopic + `"
|
roomserver_topic: "` + inputTopic + `"
|
||||||
database: "` + testDatabase + `"
|
database: "` + testDatabase + `"
|
||||||
|
server_name: "localhost"
|
||||||
`)
|
`)
|
||||||
|
|
||||||
func defaulting(value, defaultValue string) string {
|
func defaulting(value, defaultValue string) string {
|
||||||
|
|
|
@ -14,6 +14,10 @@
|
||||||
|
|
||||||
package config
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
// Sync contains the config information necessary to spin up a sync-server process.
|
// Sync contains the config information necessary to spin up a sync-server process.
|
||||||
type Sync struct {
|
type Sync struct {
|
||||||
// The topic for events which are written by the room server output log.
|
// The topic for events which are written by the room server output log.
|
||||||
|
@ -22,4 +26,6 @@ type Sync struct {
|
||||||
KafkaConsumerURIs []string `yaml:"consumer_uris"`
|
KafkaConsumerURIs []string `yaml:"consumer_uris"`
|
||||||
// The postgres connection config for connecting to the database e.g a postgres:// URI
|
// The postgres connection config for connecting to the database e.g a postgres:// URI
|
||||||
DataSource string `yaml:"database"`
|
DataSource string `yaml:"database"`
|
||||||
|
// The server_name of the running process e.g "localhost"
|
||||||
|
ServerName gomatrixserverlib.ServerName `yaml:"server_name"`
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue