Associate transactions with session IDs instead of device IDs (#789)
This commit is contained in:
parent
5eb63f1d1e
commit
43308d2f3f
9 changed files with 55 additions and 39 deletions
|
@ -21,5 +21,9 @@ type Device struct {
|
|||
// The access_token granted to this device.
|
||||
// This uniquely identifies the device from all other devices and clients.
|
||||
AccessToken string
|
||||
// The unique ID of the session identified by the access token.
|
||||
// Can be used as a secure substitution in places where data needs to be
|
||||
// associated with access tokens.
|
||||
SessionID int64
|
||||
// TODO: display name, last used timestamp, keys, etc
|
||||
}
|
||||
|
|
|
@ -27,11 +27,19 @@ import (
|
|||
)
|
||||
|
||||
const devicesSchema = `
|
||||
-- This sequence is used for automatic allocation of session_id.
|
||||
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||
|
||||
-- Stores data about devices.
|
||||
CREATE TABLE IF NOT EXISTS device_devices (
|
||||
-- The access token granted to this device. This has to be the primary key
|
||||
-- so we can distinguish which device is making a given request.
|
||||
access_token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The auto-allocated unique ID of the session identified by the access token.
|
||||
-- This can be used as a secure substitution of the access token in situations
|
||||
-- where data is associated with access tokens (e.g. transaction storage),
|
||||
-- so we don't have to store users' access tokens everywhere.
|
||||
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
|
||||
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
||||
-- access_tokens will be clobbered based on the device ID for a user.
|
||||
device_id TEXT NOT NULL,
|
||||
|
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
|
|||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
|
||||
" RETURNING session_id"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
|
|||
displayName *string,
|
||||
) (*authtypes.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authtypes.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
var dev authtypes.Device
|
||||
var localpart string
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.AccessToken = accessToken
|
||||
|
|
|
@ -60,18 +60,18 @@ func SendEvent(
|
|||
return *resErr
|
||||
}
|
||||
|
||||
var txnAndDeviceID *api.TransactionID
|
||||
var txnAndSessionID *api.TransactionID
|
||||
if txnID != nil {
|
||||
txnAndDeviceID = &api.TransactionID{
|
||||
txnAndSessionID = &api.TransactionID{
|
||||
TransactionID: *txnID,
|
||||
DeviceID: device.ID,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
}
|
||||
|
||||
// pass the new event to the roomserver and receive the correct event ID
|
||||
// event ID in case of duplicate transaction is discarded
|
||||
eventID, err := producer.SendEvents(
|
||||
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID,
|
||||
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
|
|
|
@ -75,9 +75,9 @@ type InputRoomEvent struct {
|
|||
}
|
||||
|
||||
// TransactionID contains the transaction ID sent by a client when sending an
|
||||
// event, along with the ID of that device.
|
||||
// event, along with the ID of the client session.
|
||||
type TransactionID struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
TransactionID string `json:"id"`
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ type RoomEventDatabase interface {
|
|||
StoreEvent(
|
||||
ctx context.Context,
|
||||
event gomatrixserverlib.Event,
|
||||
txnAndDeviceID *api.TransactionID,
|
||||
txnAndSessionID *api.TransactionID,
|
||||
authEventNIDs []types.EventNID,
|
||||
) (types.RoomNID, types.StateAtEvent, error)
|
||||
// Look up the state entries for a list of string event IDs
|
||||
|
@ -67,7 +67,7 @@ type RoomEventDatabase interface {
|
|||
// Returns an empty string if no such event exists.
|
||||
GetTransactionEventID(
|
||||
ctx context.Context, transactionID string,
|
||||
deviceID string, userID string,
|
||||
sessionID int64, userID string,
|
||||
) (string, error)
|
||||
}
|
||||
|
||||
|
@ -100,7 +100,7 @@ func processRoomEvent(
|
|||
if input.TransactionID != nil {
|
||||
tdID := input.TransactionID
|
||||
eventID, err = db.GetTransactionEventID(
|
||||
ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(),
|
||||
ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(),
|
||||
)
|
||||
// On error OR event with the transaction already processed/processesing
|
||||
if err != nil || eventID != "" {
|
||||
|
|
|
@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) {
|
|||
// StoreEvent implements input.EventDatabase
|
||||
func (d *Database) StoreEvent(
|
||||
ctx context.Context, event gomatrixserverlib.Event,
|
||||
txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||
) (types.RoomNID, types.StateAtEvent, error) {
|
||||
var (
|
||||
roomNID types.RoomNID
|
||||
|
@ -58,10 +58,10 @@ func (d *Database) StoreEvent(
|
|||
err error
|
||||
)
|
||||
|
||||
if txnAndDeviceID != nil {
|
||||
if txnAndSessionID != nil {
|
||||
if err = d.statements.insertTransaction(
|
||||
ctx, txnAndDeviceID.TransactionID,
|
||||
txnAndDeviceID.DeviceID, event.Sender(), event.EventID(),
|
||||
ctx, txnAndSessionID.TransactionID,
|
||||
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
||||
); err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
|
@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate(
|
|||
// GetTransactionEventID implements input.EventDatabase
|
||||
func (d *Database) GetTransactionEventID(
|
||||
ctx context.Context, transactionID string,
|
||||
deviceID string, userID string,
|
||||
sessionID int64, userID string,
|
||||
) (string, error) {
|
||||
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID)
|
||||
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
@ -23,8 +23,8 @@ const transactionsSchema = `
|
|||
CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||
-- The transaction ID of the event.
|
||||
transaction_id TEXT NOT NULL,
|
||||
-- The device ID of the originating transaction.
|
||||
device_id TEXT NOT NULL,
|
||||
-- The session ID of the originating transaction.
|
||||
session_id BIGINT NOT NULL,
|
||||
-- User ID of the sender who authored the event
|
||||
user_id TEXT NOT NULL,
|
||||
-- Event ID corresponding to the transaction
|
||||
|
@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
|||
event_id TEXT NOT NULL,
|
||||
-- A transaction ID is unique for a user and device
|
||||
-- This automatically creates an index.
|
||||
PRIMARY KEY (transaction_id, device_id, user_id)
|
||||
PRIMARY KEY (transaction_id, session_id, user_id)
|
||||
);
|
||||
`
|
||||
const insertTransactionSQL = "" +
|
||||
"INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" +
|
||||
"INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" +
|
||||
" VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectTransactionEventIDSQL = "" +
|
||||
"SELECT event_id FROM roomserver_transactions" +
|
||||
" WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3"
|
||||
" WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3"
|
||||
|
||||
type transactionStatements struct {
|
||||
insertTransactionStmt *sql.Stmt
|
||||
|
@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *transactionStatements) insertTransaction(
|
||||
ctx context.Context,
|
||||
transactionID string,
|
||||
deviceID string,
|
||||
sessionID int64,
|
||||
userID string,
|
||||
eventID string,
|
||||
) (err error) {
|
||||
_, err = s.insertTransactionStmt.ExecContext(
|
||||
ctx, transactionID, deviceID, userID, eventID,
|
||||
ctx, transactionID, sessionID, userID, eventID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction(
|
|||
func (s *transactionStatements) selectTransactionEventID(
|
||||
ctx context.Context,
|
||||
transactionID string,
|
||||
deviceID string,
|
||||
sessionID int64,
|
||||
userID string,
|
||||
) (eventID string, err error) {
|
||||
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
||||
ctx, transactionID, deviceID, userID,
|
||||
ctx, transactionID, sessionID, userID,
|
||||
).Scan(&eventID)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
|
|||
-- if there is no delta.
|
||||
add_state_ids TEXT[],
|
||||
remove_state_ids TEXT[],
|
||||
device_id TEXT, -- The local device that sent the event, if any
|
||||
session_id BIGINT, -- The client session that sent the event, if any
|
||||
transaction_id TEXT -- The transaction id used to send the event, if any
|
||||
);
|
||||
-- for event selection
|
||||
|
@ -63,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
|
|||
|
||||
const insertEventSQL = "" +
|
||||
"INSERT INTO syncapi_output_room_events (" +
|
||||
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" +
|
||||
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
|
||||
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" +
|
||||
"SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||
" ORDER BY id DESC LIMIT $4"
|
||||
|
||||
|
@ -221,9 +221,10 @@ func (s *outputRoomEventsStatements) insertEvent(
|
|||
event *gomatrixserverlib.Event, addState, removeState []string,
|
||||
transactionID *api.TransactionID,
|
||||
) (streamPos int64, err error) {
|
||||
var deviceID, txnID *string
|
||||
var txnID *string
|
||||
var sessionID *int64
|
||||
if transactionID != nil {
|
||||
deviceID = &transactionID.DeviceID
|
||||
sessionID = &transactionID.SessionID
|
||||
txnID = &transactionID.TransactionID
|
||||
}
|
||||
|
||||
|
@ -246,7 +247,7 @@ func (s *outputRoomEventsStatements) insertEvent(
|
|||
containsURL,
|
||||
pq.StringArray(addState),
|
||||
pq.StringArray(removeState),
|
||||
deviceID,
|
||||
sessionID,
|
||||
txnID,
|
||||
).Scan(&streamPos)
|
||||
return
|
||||
|
@ -296,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
|
|||
var (
|
||||
streamPos int64
|
||||
eventBytes []byte
|
||||
deviceID *string
|
||||
sessionID *int64
|
||||
txnID *string
|
||||
transactionID *api.TransactionID
|
||||
)
|
||||
if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil {
|
||||
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Handle redacted events
|
||||
|
@ -309,9 +310,9 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if deviceID != nil && txnID != nil {
|
||||
if sessionID != nil && txnID != nil {
|
||||
transactionID = &api.TransactionID{
|
||||
DeviceID: *deviceID,
|
||||
SessionID: *sessionID,
|
||||
TransactionID: *txnID,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -893,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix
|
|||
for i := 0; i < len(in); i++ {
|
||||
out[i] = in[i].Event
|
||||
if device != nil && in[i].transactionID != nil {
|
||||
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
|
||||
if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
|
||||
err := out[i].SetUnsignedField(
|
||||
"transaction_id", in[i].transactionID.TransactionID,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue