Factor out runTransaction to common code (#162)
parent
d3a29b7816
commit
b06d1124f7
|
@ -24,7 +24,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -40,10 +39,16 @@ var UnknownDeviceID = "unknown-device"
|
||||||
// 32 bytes => 256 bits
|
// 32 bytes => 256 bits
|
||||||
var tokenByteLength = 32
|
var tokenByteLength = 32
|
||||||
|
|
||||||
|
// DeviceDatabase represents a device database.
|
||||||
|
type DeviceDatabase interface {
|
||||||
|
// Lookup the device matching the given access token.
|
||||||
|
GetDeviceByAccessToken(token string) (*authtypes.Device, error)
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request
|
// VerifyAccessToken verifies that an access token was supplied in the given HTTP request
|
||||||
// and returns the device it corresponds to. Returns resErr (an error response which can be
|
// and returns the device it corresponds to. Returns resErr (an error response which can be
|
||||||
// sent to the client) if the token is invalid or there was a problem querying the database.
|
// sent to the client) if the token is invalid or there was a problem querying the database.
|
||||||
func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *authtypes.Device, resErr *util.JSONResponse) {
|
func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) {
|
||||||
token, err := extractAccessToken(req)
|
token, err := extractAccessToken(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resErr = &util.JSONResponse{
|
resErr = &util.JSONResponse{
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,7 +54,7 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro
|
||||||
// an error will be returned.
|
// an error will be returned.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
|
func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) {
|
||||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing token for this device
|
// Revoke existing token for this device
|
||||||
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
|
if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil {
|
||||||
|
@ -74,30 +75,10 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a
|
||||||
// If the device doesn't exist, it will not return an error
|
// If the device doesn't exist, it will not return an error
|
||||||
// If something went wrong during the deletion, it will return the SQL error
|
// If something went wrong during the deletion, it will return the SQL error
|
||||||
func (d *Database) RemoveDevice(deviceID string, localpart string) error {
|
func (d *Database) RemoveDevice(deviceID string, localpart string) error {
|
||||||
return runTransaction(d.db, func(txn *sql.Tx) error {
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
|
if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: factor out to common
|
|
||||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
|
||||||
txn, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
panic(r)
|
|
||||||
} else if err != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
} else {
|
|
||||||
err = txn.Commit()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
err = fn(txn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"net/http"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
|
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
|
||||||
func MakeAuthAPI(metricsName string, deviceDB *devices.Database, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
|
func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
|
||||||
h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
||||||
device, resErr := auth.VerifyAccessToken(req, deviceDB)
|
device, resErr := auth.VerifyAccessToken(req, deviceDB)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithTransaction runs a block of code passing in an SQL transaction
|
||||||
|
// If the code returns an error or panics then the transactions is rolledback
|
||||||
|
// Otherwise the transaction is committed.
|
||||||
|
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
|
txn, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
panic(r)
|
||||||
|
} else if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
} else {
|
||||||
|
err = txn.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = fn(txn)
|
||||||
|
return
|
||||||
|
}
|
|
@ -77,7 +77,7 @@ func (d *Database) UpdateRoom(
|
||||||
addHosts []types.JoinedHost,
|
addHosts []types.JoinedHost,
|
||||||
removeHosts []string,
|
removeHosts []string,
|
||||||
) (joinedHosts []types.JoinedHost, err error) {
|
) (joinedHosts []types.JoinedHost, err error) {
|
||||||
err = runTransaction(d.db, func(txn *sql.Tx) error {
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if err = d.insertRoom(txn, roomID); err != nil {
|
if err = d.insertRoom(txn, roomID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -105,22 +105,3 @@ func (d *Database) UpdateRoom(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
|
||||||
txn, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
panic(r)
|
|
||||||
} else if err != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
} else {
|
|
||||||
err = txn.Commit()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
err = fn(txn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
@ -92,7 +92,7 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even
|
||||||
func (d *SyncServerDatabase) WriteEvent(
|
func (d *SyncServerDatabase) WriteEvent(
|
||||||
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
|
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
|
||||||
) (streamPos types.StreamPosition, returnErr error) {
|
) (streamPos types.StreamPosition, returnErr error) {
|
||||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
|
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -162,7 +162,7 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error)
|
||||||
|
|
||||||
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
||||||
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
|
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
|
||||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
// Work out which rooms to return in the response. This is done by getting not only the currently
|
// Work out which rooms to return in the response. This is done by getting not only the currently
|
||||||
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
||||||
// This works out what the 'state' key should be for each room as well as which membership block
|
// This works out what the 'state' key should be for each room as well as which membership block
|
||||||
|
@ -223,7 +223,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
||||||
// a consistent view of the database throughout. This includes extracting the sync stream position.
|
// a consistent view of the database throughout. This includes extracting the sync stream position.
|
||||||
// This does have the unfortunate side-effect that all the matrixy logic resides in this function,
|
// This does have the unfortunate side-effect that all the matrixy logic resides in this function,
|
||||||
// but it's better to not hide the fact that this is being done in a transaction.
|
// but it's better to not hide the fact that this is being done in a transaction.
|
||||||
returnErr = runTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
// Get the current stream position which we will base the sync response on.
|
// Get the current stream position which we will base the sync response on.
|
||||||
id, err := d.events.selectMaxID(txn)
|
id, err := d.events.selectMaxID(txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -479,22 +479,3 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
|
||||||
txn, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
panic(r)
|
|
||||||
} else if err != nil {
|
|
||||||
txn.Rollback()
|
|
||||||
} else {
|
|
||||||
err = txn.Commit()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
err = fn(txn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue