Add transaction ID to events if sending device (#368)
parent
de6529d766
commit
b835e585c4
|
@ -19,6 +19,9 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
@ -86,13 +89,17 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str
|
||||||
// Events lookups a list of event by their event ID.
|
// Events lookups a list of event by their event ID.
|
||||||
// Returns a list of events matching the requested IDs found in the database.
|
// Returns a list of events matching the requested IDs found in the database.
|
||||||
// If an event is not found in the database then it will be omitted from the list.
|
// If an event is not found in the database then it will be omitted from the list.
|
||||||
// Returns an error if there was a problem talking with the database
|
// Returns an error if there was a problem talking with the database.
|
||||||
|
// Does not include any transaction IDs in the returned events.
|
||||||
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
||||||
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
|
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return streamEventsToEvents(streamEvents), nil
|
|
||||||
|
// We don't include a device here as we only include transaction IDs in
|
||||||
|
// incremental syncs.
|
||||||
|
return streamEventsToEvents(nil, streamEvents), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
|
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
|
||||||
|
@ -208,10 +215,14 @@ func (d *SyncServerDatabase) syncStreamPositionTx(
|
||||||
return types.StreamPosition(maxID), nil
|
return types.StreamPosition(maxID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
// IncrementalSync returns all the data needed in order to create an incremental
|
||||||
|
// sync response for the given user. Events returned will include any client
|
||||||
|
// transaction IDs associated with the given device. These transaction IDs come
|
||||||
|
// from when the device sent the event via an API that included a transaction
|
||||||
|
// ID.
|
||||||
func (d *SyncServerDatabase) IncrementalSync(
|
func (d *SyncServerDatabase) IncrementalSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
device authtypes.Device,
|
||||||
fromPos, toPos types.StreamPosition,
|
fromPos, toPos types.StreamPosition,
|
||||||
numRecentEventsPerRoom int,
|
numRecentEventsPerRoom int,
|
||||||
) (*types.Response, error) {
|
) (*types.Response, error) {
|
||||||
|
@ -226,21 +237,21 @@ func (d *SyncServerDatabase) IncrementalSync(
|
||||||
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
||||||
// This works out what the 'state' key should be for each room as well as which membership block
|
// This works out what the 'state' key should be for each room as well as which membership block
|
||||||
// to put the room into.
|
// to put the room into.
|
||||||
deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID)
|
deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res := types.NewResponse(toPos)
|
res := types.NewResponse(toPos)
|
||||||
for _, delta := range deltas {
|
for _, delta := range deltas {
|
||||||
err = d.addRoomDeltaToResponse(ctx, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
|
err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This should be done in getStateDeltas
|
// TODO: This should be done in getStateDeltas
|
||||||
if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil {
|
if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,7 +303,10 @@ func (d *SyncServerDatabase) CompleteSync(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
recentEvents := streamEventsToEvents(recentStreamEvents)
|
|
||||||
|
// We don't include a device here as we don't need to send down
|
||||||
|
// transaction IDs for complete syncs
|
||||||
|
recentEvents := streamEventsToEvents(nil, recentStreamEvents)
|
||||||
|
|
||||||
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
stateEvents = removeDuplicates(stateEvents, recentEvents)
|
||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
|
@ -390,7 +404,9 @@ func (d *SyncServerDatabase) addInvitesToResponse(
|
||||||
|
|
||||||
// addRoomDeltaToResponse adds a room state delta to a sync response
|
// addRoomDeltaToResponse adds a room state delta to a sync response
|
||||||
func (d *SyncServerDatabase) addRoomDeltaToResponse(
|
func (d *SyncServerDatabase) addRoomDeltaToResponse(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context,
|
||||||
|
device *authtypes.Device,
|
||||||
|
txn *sql.Tx,
|
||||||
fromPos, toPos types.StreamPosition,
|
fromPos, toPos types.StreamPosition,
|
||||||
delta stateDelta,
|
delta stateDelta,
|
||||||
numRecentEventsPerRoom int,
|
numRecentEventsPerRoom int,
|
||||||
|
@ -412,7 +428,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
recentEvents := streamEventsToEvents(recentStreamEvents)
|
recentEvents := streamEventsToEvents(device, recentStreamEvents)
|
||||||
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
|
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
|
||||||
|
|
||||||
// Don't bother appending empty room entries
|
// Don't bother appending empty room entries
|
||||||
|
@ -529,7 +545,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatabase) getStateDeltas(
|
func (d *SyncServerDatabase) getStateDeltas(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
|
||||||
fromPos, toPos types.StreamPosition, userID string,
|
fromPos, toPos types.StreamPosition, userID string,
|
||||||
) ([]stateDelta, error) {
|
) ([]stateDelta, error) {
|
||||||
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
||||||
|
@ -578,7 +594,7 @@ func (d *SyncServerDatabase) getStateDeltas(
|
||||||
deltas = append(deltas, stateDelta{
|
deltas = append(deltas, stateDelta{
|
||||||
membership: membership,
|
membership: membership,
|
||||||
membershipPos: ev.streamPosition,
|
membershipPos: ev.streamPosition,
|
||||||
stateEvents: streamEventsToEvents(stateStreamEvents),
|
stateEvents: streamEventsToEvents(device, stateStreamEvents),
|
||||||
roomID: roomID,
|
roomID: roomID,
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
|
@ -594,7 +610,7 @@ func (d *SyncServerDatabase) getStateDeltas(
|
||||||
for _, joinedRoomID := range joinedRoomIDs {
|
for _, joinedRoomID := range joinedRoomIDs {
|
||||||
deltas = append(deltas, stateDelta{
|
deltas = append(deltas, stateDelta{
|
||||||
membership: "join",
|
membership: "join",
|
||||||
stateEvents: streamEventsToEvents(state[joinedRoomID]),
|
stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
|
||||||
roomID: joinedRoomID,
|
roomID: joinedRoomID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -602,10 +618,25 @@ func (d *SyncServerDatabase) getStateDeltas(
|
||||||
return deltas, nil
|
return deltas, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event {
|
// streamEventsToEvents converts streamEvent to Event. If device is non-nil and
|
||||||
|
// matches the streamevent.transactionID device then the transaction ID gets
|
||||||
|
// added to the unsigned section of the output event.
|
||||||
|
func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event {
|
||||||
out := make([]gomatrixserverlib.Event, len(in))
|
out := make([]gomatrixserverlib.Event, len(in))
|
||||||
for i := 0; i < len(in); i++ {
|
for i := 0; i < len(in); i++ {
|
||||||
out[i] = in[i].Event
|
out[i] = in[i].Event
|
||||||
|
if device != nil && in[i].transactionID != nil {
|
||||||
|
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
|
||||||
|
err := out[i].SetUnsignedField(
|
||||||
|
"transaction_id", in[i].transactionID.TransactionID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event_id": out[i].EventID(),
|
||||||
|
}).WithError(err).Warnf("Failed to add transaction ID to event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
|
||||||
|
|
||||||
n.removeEmptyUserStreams()
|
n.removeEmptyUserStreams()
|
||||||
|
|
||||||
return n.fetchUserStream(req.userID, true).GetListener(req.ctx)
|
return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the membership states required to notify users correctly.
|
// Load the membership states required to notify users correctly.
|
||||||
|
|
|
@ -21,6 +21,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -262,7 +264,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
|
||||||
select {
|
select {
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
return types.StreamPosition(0), fmt.Errorf(
|
return types.StreamPosition(0), fmt.Errorf(
|
||||||
"waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since,
|
"waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
|
||||||
)
|
)
|
||||||
case <-listener.GetNotifyChannel(*req.since):
|
case <-listener.GetNotifyChannel(*req.since):
|
||||||
p := listener.GetStreamPosition()
|
p := listener.GetStreamPosition()
|
||||||
|
@ -280,7 +282,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
|
||||||
|
|
||||||
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest {
|
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest {
|
||||||
return syncRequest{
|
return syncRequest{
|
||||||
userID: userID,
|
device: authtypes.Device{UserID: userID},
|
||||||
timeout: 1 * time.Minute,
|
timeout: 1 * time.Minute,
|
||||||
since: &since,
|
since: &since,
|
||||||
wantFullState: false,
|
wantFullState: false,
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
@ -31,7 +33,7 @@ const defaultTimelineLimit = 20
|
||||||
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
|
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
|
||||||
type syncRequest struct {
|
type syncRequest struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userID string
|
device authtypes.Device
|
||||||
limit int
|
limit int
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
since *types.StreamPosition // nil means that no since token was supplied
|
since *types.StreamPosition // nil means that no since token was supplied
|
||||||
|
@ -39,7 +41,7 @@ type syncRequest struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) {
|
||||||
timeout := getTimeout(req.URL.Query().Get("timeout"))
|
timeout := getTimeout(req.URL.Query().Get("timeout"))
|
||||||
fullState := req.URL.Query().Get("full_state")
|
fullState := req.URL.Query().Get("full_state")
|
||||||
wantFullState := fullState != "" && fullState != "false"
|
wantFullState := fullState != "" && fullState != "false"
|
||||||
|
@ -50,7 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
||||||
// TODO: Additional query params: set_presence, filter
|
// TODO: Additional query params: set_presence, filter
|
||||||
return &syncRequest{
|
return &syncRequest{
|
||||||
ctx: req.Context(),
|
ctx: req.Context(),
|
||||||
userID: userID,
|
device: device,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
since: since,
|
since: since,
|
||||||
wantFullState: wantFullState,
|
wantFullState: wantFullState,
|
||||||
|
|
|
@ -48,7 +48,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
// Extract values from request
|
// Extract values from request
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
userID := device.UserID
|
userID := device.UserID
|
||||||
syncReq, err := newSyncRequest(req, userID)
|
syncReq, err := newSyncRequest(req, *device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 400,
|
Code: 400,
|
||||||
|
@ -122,16 +122,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) {
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) {
|
||||||
// TODO: handle ignored users
|
// TODO: handle ignored users
|
||||||
if req.since == nil {
|
if req.since == nil {
|
||||||
res, err = rp.db.CompleteSync(req.ctx, req.userID, req.limit)
|
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
|
||||||
} else {
|
} else {
|
||||||
res, err = rp.db.IncrementalSync(req.ctx, req.userID, *req.since, currentPos, req.limit)
|
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = rp.appendAccountData(res, req.userID, req, currentPos)
|
res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue