Add context to the syncapi database (#234)
parent
238646ee3c
commit
856bc5b52e
|
@ -15,6 +15,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -264,13 +265,13 @@ func (m *monolith) setupProducers() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *monolith) setupNotifiers() {
|
func (m *monolith) setupNotifiers() {
|
||||||
pos, err := m.syncAPIDB.SyncStreamPosition()
|
pos, err := m.syncAPIDB.SyncStreamPosition(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.syncAPINotifier = syncapi_sync.NewNotifier(syncapi_types.StreamPosition(pos))
|
m.syncAPINotifier = syncapi_sync.NewNotifier(syncapi_types.StreamPosition(pos))
|
||||||
if err = m.syncAPINotifier.Load(m.syncAPIDB); err != nil {
|
if err = m.syncAPINotifier.Load(context.Background(), m.syncAPIDB); err != nil {
|
||||||
log.Panicf("startup: failed to set up notifier: %s", err)
|
log.Panicf("startup: failed to set up notifier: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -67,13 +68,13 @@ func main() {
|
||||||
log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err)
|
log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pos, err := db.SyncStreamPosition()
|
pos, err := db.SyncStreamPosition(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
log.Panicf("startup: failed to get latest sync stream position : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n := sync.NewNotifier(types.StreamPosition(pos))
|
n := sync.NewNotifier(types.StreamPosition(pos))
|
||||||
if err = n.Load(db); err != nil {
|
if err = n.Load(context.Background(), db); err != nil {
|
||||||
log.Panicf("startup: failed to set up notifier: %s", err)
|
log.Panicf("startup: failed to set up notifier: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package consumers
|
package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
@ -77,7 +78,9 @@ func (s *OutputClientData) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
"room_id": output.RoomID,
|
"room_id": output.RoomID,
|
||||||
}).Info("received data from client API server")
|
}).Info("received data from client API server")
|
||||||
|
|
||||||
syncStreamPos, err := s.db.UpsertAccountData(string(msg.Key), output.RoomID, output.Type)
|
syncStreamPos, err := s.db.UpsertAccountData(
|
||||||
|
context.TODO(), string(msg.Key), output.RoomID, output.Type,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"type": output.Type,
|
"type": output.Type,
|
||||||
|
|
|
@ -122,7 +122,11 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
syncStreamPos, err := s.db.WriteEvent(
|
syncStreamPos, err := s.db.WriteEvent(
|
||||||
&ev, addsStateEvents, output.NewRoomEvent.AddsStateEventIDs, output.NewRoomEvent.RemovesStateEventIDs,
|
context.TODO(),
|
||||||
|
&ev,
|
||||||
|
addsStateEvents,
|
||||||
|
output.NewRoomEvent.AddsStateEventIDs,
|
||||||
|
output.NewRoomEvent.RemovesStateEventIDs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -157,7 +161,7 @@ func (s *OutputRoomEvent) lookupStateEvents(
|
||||||
// Check if this is re-adding a state events that we previously processed
|
// Check if this is re-adding a state events that we previously processed
|
||||||
// If we have previously received a state event it may still be in
|
// If we have previously received a state event it may still be in
|
||||||
// our event database.
|
// our event database.
|
||||||
result, err := s.db.Events(addsStateEventIDs)
|
result, err := s.db.Events(context.TODO(), addsStateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -205,7 +209,9 @@ func (s *OutputRoomEvent) updateStateEvent(event gomatrixserverlib.Event) (gomat
|
||||||
stateKey = *event.StateKey()
|
stateKey = *event.StateKey()
|
||||||
}
|
}
|
||||||
|
|
||||||
prevEvent, err := s.db.GetStateEvent(event.Type(), event.RoomID(), stateKey)
|
prevEvent, err := s.db.GetStateEvent(
|
||||||
|
context.TODO(), event.Type(), event.RoomID(), stateKey,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return event, err
|
return event, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
@ -71,14 +72,18 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
pos types.StreamPosition, userID string, roomID string, dataType string,
|
ctx context.Context,
|
||||||
|
pos types.StreamPosition,
|
||||||
|
userID, roomID, dataType string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.insertAccountDataStmt.Exec(pos, userID, roomID, dataType)
|
_, err = s.insertAccountDataStmt.ExecContext(ctx, pos, userID, roomID, dataType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataInRange(
|
func (s *accountDataStatements) selectAccountDataInRange(
|
||||||
userID string, oldPos types.StreamPosition, newPos types.StreamPosition,
|
ctx context.Context,
|
||||||
|
userID string,
|
||||||
|
oldPos, newPos types.StreamPosition,
|
||||||
) (data map[string][]string, err error) {
|
) (data map[string][]string, err error) {
|
||||||
data = make(map[string][]string)
|
data = make(map[string][]string)
|
||||||
|
|
||||||
|
@ -89,7 +94,7 @@ func (s *accountDataStatements) selectAccountDataInRange(
|
||||||
oldPos--
|
oldPos--
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := s.selectAccountDataInRangeStmt.Query(userID, oldPos, newPos)
|
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -114,8 +115,10 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
|
// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
|
||||||
func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, error) {
|
func (s *currentRoomStateStatements) selectJoinedUsers(
|
||||||
rows, err := s.selectJoinedUsersStmt.Query()
|
ctx context.Context,
|
||||||
|
) (map[string][]string, error) {
|
||||||
|
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -136,8 +139,11 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
|
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
|
||||||
rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
|
ctx context.Context, txn *sql.Tx, userID, membership string,
|
||||||
|
) ([]string, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -155,8 +161,11 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us
|
||||||
}
|
}
|
||||||
|
|
||||||
// CurrentState returns all the current state events for the given room.
|
// CurrentState returns all the current state events for the given room.
|
||||||
func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
|
func (s *currentRoomStateStatements) selectCurrentState(
|
||||||
rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID)
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectCurrentStateStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -165,22 +174,37 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri
|
||||||
return rowsToEvents(rows)
|
return rowsToEvents(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error {
|
func (s *currentRoomStateStatements) deleteRoomStateByEventID(
|
||||||
_, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID)
|
ctx context.Context, txn *sql.Tx, eventID string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, eventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) upsertRoomState(
|
func (s *currentRoomStateStatements) upsertRoomState(
|
||||||
txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
event gomatrixserverlib.Event, membership *string, addedAt int64,
|
||||||
) error {
|
) error {
|
||||||
_, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec(
|
stmt := common.TxStmt(txn, s.upsertRoomStateStmt)
|
||||||
event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt,
|
_, err := stmt.ExecContext(
|
||||||
|
ctx,
|
||||||
|
event.RoomID(),
|
||||||
|
event.EventID(),
|
||||||
|
event.Type(),
|
||||||
|
*event.StateKey(),
|
||||||
|
event.JSON(),
|
||||||
|
membership,
|
||||||
|
addedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
func (s *currentRoomStateStatements) selectEventsWithEventIDs(
|
||||||
rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) ([]streamEvent, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -205,11 +229,18 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) selectStateEvent(evType string, roomID string, stateKey string) (*gomatrixserverlib.Event, error) {
|
func (s *currentRoomStateStatements) selectStateEvent(
|
||||||
|
ctx context.Context, evType string, roomID string, stateKey string,
|
||||||
|
) (*gomatrixserverlib.Event, error) {
|
||||||
|
stmt := s.selectStateEventStmt
|
||||||
var res []byte
|
var res []byte
|
||||||
if err := s.selectStateEventStmt.QueryRow(evType, roomID, stateKey).Scan(&res); err == sql.ErrNoRows {
|
err := stmt.QueryRowContext(ctx, evType, roomID, stateKey).Scan(&res)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(res, false)
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(res, false)
|
||||||
return &ev, err
|
return &ev, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
@ -104,9 +105,11 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
|
||||||
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
|
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
|
||||||
// two positions, only the most recent state is returned.
|
// two positions, only the most recent state is returned.
|
||||||
func (s *outputRoomEventsStatements) selectStateInRange(
|
func (s *outputRoomEventsStatements) selectStateInRange(
|
||||||
txn *sql.Tx, oldPos, newPos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
|
||||||
) (map[string]map[string]bool, map[string]streamEvent, error) {
|
) (map[string]map[string]bool, map[string]streamEvent, error) {
|
||||||
rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos)
|
stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, oldPos, newPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -167,9 +170,12 @@ func (s *outputRoomEventsStatements) selectStateInRange(
|
||||||
// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
|
// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
|
||||||
// then this function should only ever be used at startup, as it will race with inserting events if it is
|
// then this function should only ever be used at startup, as it will race with inserting events if it is
|
||||||
// done afterwards. If there are no inserted events, 0 is returned.
|
// done afterwards. If there are no inserted events, 0 is returned.
|
||||||
func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) {
|
func (s *outputRoomEventsStatements) selectMaxID(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
) (id int64, err error) {
|
||||||
var nullableID sql.NullInt64
|
var nullableID sql.NullInt64
|
||||||
err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID)
|
stmt := common.TxStmt(txn, s.selectMaxIDStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
||||||
if nullableID.Valid {
|
if nullableID.Valid {
|
||||||
id = nullableID.Int64
|
id = nullableID.Int64
|
||||||
}
|
}
|
||||||
|
@ -178,18 +184,29 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err
|
||||||
|
|
||||||
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
|
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
|
||||||
// of the inserted event.
|
// of the inserted event.
|
||||||
func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) {
|
func (s *outputRoomEventsStatements) insertEvent(
|
||||||
err = common.TxStmt(txn, s.insertEventStmt).QueryRow(
|
ctx context.Context, txn *sql.Tx,
|
||||||
event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState),
|
event *gomatrixserverlib.Event, addState, removeState []string,
|
||||||
|
) (streamPos int64, err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.insertEventStmt)
|
||||||
|
err = stmt.QueryRowContext(
|
||||||
|
ctx,
|
||||||
|
event.RoomID(),
|
||||||
|
event.EventID(),
|
||||||
|
event.JSON(),
|
||||||
|
pq.StringArray(addState),
|
||||||
|
pq.StringArray(removeState),
|
||||||
).Scan(&streamPos)
|
).Scan(&streamPos)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
|
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
|
||||||
func (s *outputRoomEventsStatements) selectRecentEvents(
|
func (s *outputRoomEventsStatements) selectRecentEvents(
|
||||||
_ *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomID string, fromPos, toPos types.StreamPosition, limit int,
|
||||||
) ([]streamEvent, error) {
|
) ([]streamEvent, error) {
|
||||||
rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit)
|
stmt := common.TxStmt(txn, s.selectRecentEventsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -205,8 +222,11 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
|
||||||
|
|
||||||
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
|
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
|
||||||
// from the database.
|
// from the database.
|
||||||
func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
func (s *outputRoomEventsStatements) selectEvents(
|
||||||
rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs))
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) ([]streamEvent, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectEventsStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
|
@ -75,16 +76,16 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
||||||
func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error) {
|
func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
|
||||||
return d.roomstate.selectJoinedUsers()
|
return d.roomstate.selectJoinedUsers(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Events lookups a list of event by their event ID.
|
// Events lookups a list of event by their event ID.
|
||||||
// Returns a list of events matching the requested IDs found in the database.
|
// Returns a list of events matching the requested IDs found in the database.
|
||||||
// If an event is not found in the database then it will be omitted from the list.
|
// If an event is not found in the database then it will be omitted from the list.
|
||||||
// Returns an error if there was a problem talking with the database
|
// Returns an error if there was a problem talking with the database
|
||||||
func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
||||||
streamEvents, err := d.events.selectEvents(nil, eventIDs)
|
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -95,11 +96,14 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even
|
||||||
// when generating the stream position for this event. Returns the sync stream position for the inserted event.
|
// when generating the stream position for this event. Returns the sync stream position for the inserted event.
|
||||||
// Returns an error if there was a problem inserting this event.
|
// Returns an error if there was a problem inserting this event.
|
||||||
func (d *SyncServerDatabase) WriteEvent(
|
func (d *SyncServerDatabase) WriteEvent(
|
||||||
ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string,
|
ctx context.Context,
|
||||||
|
ev *gomatrixserverlib.Event,
|
||||||
|
addStateEvents []gomatrixserverlib.Event,
|
||||||
|
addStateEventIDs, removeStateEventIDs []string,
|
||||||
) (streamPos types.StreamPosition, returnErr error) {
|
) (streamPos types.StreamPosition, returnErr error) {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs)
|
pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -110,17 +114,20 @@ func (d *SyncServerDatabase) WriteEvent(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.updateRoomState(txn, removeStateEventIDs, addStateEvents, streamPos)
|
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatabase) updateRoomState(
|
func (d *SyncServerDatabase) updateRoomState(
|
||||||
txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, streamPos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
removedEventIDs []string,
|
||||||
|
addedEvents []gomatrixserverlib.Event,
|
||||||
|
streamPos types.StreamPosition,
|
||||||
) error {
|
) error {
|
||||||
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
||||||
for _, eventID := range removedEventIDs {
|
for _, eventID := range removedEventIDs {
|
||||||
if err := d.roomstate.deleteRoomStateByEventID(txn, eventID); err != nil {
|
if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -138,7 +145,7 @@ func (d *SyncServerDatabase) updateRoomState(
|
||||||
}
|
}
|
||||||
membership = &value
|
membership = &value
|
||||||
}
|
}
|
||||||
if err := d.roomstate.upsertRoomState(txn, event, membership, int64(streamPos)); err != nil {
|
if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -149,8 +156,10 @@ func (d *SyncServerDatabase) updateRoomState(
|
||||||
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
func (d *SyncServerDatabase) GetStateEvent(evType string, roomID string, stateKey string) (*gomatrixserverlib.Event, error) {
|
func (d *SyncServerDatabase) GetStateEvent(
|
||||||
return d.roomstate.selectStateEvent(evType, roomID, stateKey)
|
ctx context.Context, evType, roomID, stateKey string,
|
||||||
|
) (*gomatrixserverlib.Event, error) {
|
||||||
|
return d.roomstate.selectStateEvent(ctx, evType, roomID, stateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartitionOffsets implements common.PartitionStorer
|
// PartitionOffsets implements common.PartitionStorer
|
||||||
|
@ -164,8 +173,8 @@ func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, o
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
|
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
|
||||||
func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) {
|
func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
|
||||||
id, err := d.events.selectMaxID(nil)
|
id, err := d.events.selectMaxID(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.StreamPosition(0), err
|
return types.StreamPosition(0), err
|
||||||
}
|
}
|
||||||
|
@ -173,13 +182,18 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
// IncrementalSync returns all the data needed in order to create an incremental sync response.
|
||||||
func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
|
func (d *SyncServerDatabase) IncrementalSync(
|
||||||
|
ctx context.Context,
|
||||||
|
userID string,
|
||||||
|
fromPos, toPos types.StreamPosition,
|
||||||
|
numRecentEventsPerRoom int,
|
||||||
|
) (res *types.Response, returnErr error) {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
// Work out which rooms to return in the response. This is done by getting not only the currently
|
// Work out which rooms to return in the response. This is done by getting not only the currently
|
||||||
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
|
||||||
// This works out what the 'state' key should be for each room as well as which membership block
|
// This works out what the 'state' key should be for each room as well as which membership block
|
||||||
// to put the room into.
|
// to put the room into.
|
||||||
deltas, err := d.getStateDeltas(txn, fromPos, toPos, userID)
|
deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -196,7 +210,9 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types
|
||||||
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
||||||
endPos = delta.membershipPos
|
endPos = delta.membershipPos
|
||||||
}
|
}
|
||||||
recentStreamEvents, err := d.events.selectRecentEvents(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom)
|
recentStreamEvents, err := d.events.selectRecentEvents(
|
||||||
|
ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -224,27 +240,29 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This should be done in getStateDeltas
|
// TODO: This should be done in getStateDeltas
|
||||||
return d.addInvitesToResponse(txn, userID, res)
|
return d.addInvitesToResponse(ctx, txn, userID, res)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompleteSync a complete /sync API response for the given user.
|
// CompleteSync a complete /sync API response for the given user.
|
||||||
func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom int) (res *types.Response, returnErr error) {
|
func (d *SyncServerDatabase) CompleteSync(
|
||||||
|
ctx context.Context, userID string, numRecentEventsPerRoom int,
|
||||||
|
) (res *types.Response, returnErr error) {
|
||||||
// This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
|
// This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
|
||||||
// a consistent view of the database throughout. This includes extracting the sync stream position.
|
// a consistent view of the database throughout. This includes extracting the sync stream position.
|
||||||
// This does have the unfortunate side-effect that all the matrixy logic resides in this function,
|
// This does have the unfortunate side-effect that all the matrixy logic resides in this function,
|
||||||
// but it's better to not hide the fact that this is being done in a transaction.
|
// but it's better to not hide the fact that this is being done in a transaction.
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
// Get the current stream position which we will base the sync response on.
|
// Get the current stream position which we will base the sync response on.
|
||||||
id, err := d.events.selectMaxID(txn)
|
id, err := d.events.selectMaxID(ctx, txn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pos := types.StreamPosition(id)
|
pos := types.StreamPosition(id)
|
||||||
|
|
||||||
// Extract room state and recent events for all rooms the user is joined to.
|
// Extract room state and recent events for all rooms the user is joined to.
|
||||||
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join")
|
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -252,14 +270,14 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
||||||
// Build up a /sync response. Add joined rooms.
|
// Build up a /sync response. Add joined rooms.
|
||||||
res = types.NewResponse(pos)
|
res = types.NewResponse(pos)
|
||||||
for _, roomID := range roomIDs {
|
for _, roomID := range roomIDs {
|
||||||
stateEvents, err := d.roomstate.selectCurrentState(txn, roomID)
|
stateEvents, err := d.roomstate.selectCurrentState(ctx, txn, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||||
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
||||||
recentStreamEvents, err := d.events.selectRecentEvents(
|
recentStreamEvents, err := d.events.selectRecentEvents(
|
||||||
txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom,
|
ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -274,7 +292,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
||||||
res.Rooms.Join[roomID] = *jr
|
res.Rooms.Join[roomID] = *jr
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.addInvitesToResponse(txn, userID, res)
|
return d.addInvitesToResponse(ctx, txn, userID, res)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -285,9 +303,9 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom
|
||||||
// If no data is retrieved, returns an empty map
|
// If no data is retrieved, returns an empty map
|
||||||
// If there was an issue with the retrieval, returns an error
|
// If there was an issue with the retrieval, returns an error
|
||||||
func (d *SyncServerDatabase) GetAccountDataInRange(
|
func (d *SyncServerDatabase) GetAccountDataInRange(
|
||||||
userID string, oldPos types.StreamPosition, newPos types.StreamPosition,
|
ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
|
||||||
) (map[string][]string, error) {
|
) (map[string][]string, error) {
|
||||||
return d.accountData.selectAccountDataInRange(userID, oldPos, newPos)
|
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
||||||
|
@ -296,19 +314,22 @@ func (d *SyncServerDatabase) GetAccountDataInRange(
|
||||||
// If no data with the given type, user ID and room ID exists in the database,
|
// If no data with the given type, user ID and room ID exists in the database,
|
||||||
// creates a new row, else update the existing one
|
// creates a new row, else update the existing one
|
||||||
// Returns an error if there was an issue with the upsert
|
// Returns an error if there was an issue with the upsert
|
||||||
func (d *SyncServerDatabase) UpsertAccountData(userID string, roomID string, dataType string) (types.StreamPosition, error) {
|
func (d *SyncServerDatabase) UpsertAccountData(
|
||||||
pos, err := d.SyncStreamPosition()
|
ctx context.Context, userID, roomID, dataType string,
|
||||||
|
) (types.StreamPosition, error) {
|
||||||
|
pos, err := d.SyncStreamPosition(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pos, err
|
return pos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.accountData.insertAccountData(pos, userID, roomID, dataType)
|
err = d.accountData.insertAccountData(ctx, pos, userID, roomID, dataType)
|
||||||
return pos, err
|
return pos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, res *types.Response) error {
|
func (d *SyncServerDatabase) addInvitesToResponse(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID string, res *types.Response) error {
|
||||||
// Add invites - TODO: This will break over federation as they won't be in the current state table according to Mark.
|
// Add invites - TODO: This will break over federation as they won't be in the current state table according to Mark.
|
||||||
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "invite")
|
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "invite")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -322,7 +343,11 @@ func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, re
|
||||||
|
|
||||||
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
|
||||||
// Returns a map of room ID to list of events.
|
// Returns a map of room ID to list of events.
|
||||||
func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) {
|
func (d *SyncServerDatabase) fetchStateEvents(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
roomIDToEventIDSet map[string]map[string]bool,
|
||||||
|
eventIDToEvent map[string]streamEvent,
|
||||||
|
) (map[string][]streamEvent, error) {
|
||||||
stateBetween := make(map[string][]streamEvent)
|
stateBetween := make(map[string][]streamEvent)
|
||||||
missingEvents := make(map[string][]string)
|
missingEvents := make(map[string][]string)
|
||||||
for roomID, ids := range roomIDToEventIDSet {
|
for roomID, ids := range roomIDToEventIDSet {
|
||||||
|
@ -350,7 +375,7 @@ func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet ma
|
||||||
for _, missingEvIDs := range missingEvents {
|
for _, missingEvIDs := range missingEvents {
|
||||||
allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
|
allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
|
||||||
}
|
}
|
||||||
evs, err := d.fetchMissingStateEvents(txn, allMissingEventIDs)
|
evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -363,10 +388,12 @@ func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet ma
|
||||||
return stateBetween, nil
|
return stateBetween, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
func (d *SyncServerDatabase) fetchMissingStateEvents(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) ([]streamEvent, error) {
|
||||||
// Fetch from the events table first so we pick up the stream ID for the
|
// Fetch from the events table first so we pick up the stream ID for the
|
||||||
// event.
|
// event.
|
||||||
events, err := d.events.selectEvents(txn, eventIDs)
|
events, err := d.events.selectEvents(ctx, txn, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -388,7 +415,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []str
|
||||||
// If they are missing from the events table then they should be state
|
// If they are missing from the events table then they should be state
|
||||||
// events that we received from outside the main event stream.
|
// events that we received from outside the main event stream.
|
||||||
// These should be in the room state table.
|
// These should be in the room state table.
|
||||||
stateEvents, err := d.roomstate.selectEventsWithEventIDs(txn, missing)
|
stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -402,7 +429,10 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []str
|
||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string) ([]stateDelta, error) {
|
func (d *SyncServerDatabase) getStateDeltas(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
fromPos, toPos types.StreamPosition, userID string,
|
||||||
|
) ([]stateDelta, error) {
|
||||||
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
|
||||||
// - Get membership list changes for this user in this sync response
|
// - Get membership list changes for this user in this sync response
|
||||||
// - For each room which has membership list changes:
|
// - For each room which has membership list changes:
|
||||||
|
@ -414,11 +444,11 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
||||||
var deltas []stateDelta
|
var deltas []stateDelta
|
||||||
|
|
||||||
// get all the state events ever between these two positions
|
// get all the state events ever between these two positions
|
||||||
stateNeeded, eventMap, err := d.events.selectStateInRange(txn, fromPos, toPos)
|
stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
state, err := d.fetchStateEvents(txn, stateNeeded, eventMap)
|
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -434,7 +464,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
||||||
if membership == "join" {
|
if membership == "join" {
|
||||||
// send full room state down instead of a delta
|
// send full room state down instead of a delta
|
||||||
var allState []gomatrixserverlib.Event
|
var allState []gomatrixserverlib.Event
|
||||||
allState, err = d.roomstate.selectCurrentState(txn, roomID)
|
allState, err = d.roomstate.selectCurrentState(ctx, txn, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -458,7 +488,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add in currently joined rooms
|
// Add in currently joined rooms
|
||||||
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join")
|
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package sync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
@ -131,8 +132,8 @@ func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the membership states required to notify users correctly.
|
// Load the membership states required to notify users correctly.
|
||||||
func (n *Notifier) Load(db *storage.SyncServerDatabase) error {
|
func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error {
|
||||||
roomToUsers, err := db.AllJoinedUsersInRooms()
|
roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,9 +108,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (*types.Response, error) {
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (*types.Response, error) {
|
||||||
// TODO: handle ignored users
|
// TODO: handle ignored users
|
||||||
if req.since == types.StreamPosition(0) {
|
if req.since == types.StreamPosition(0) {
|
||||||
return rp.db.CompleteSync(req.userID, req.limit)
|
return rp.db.CompleteSync(req.ctx, req.userID, req.limit)
|
||||||
}
|
}
|
||||||
return rp.db.IncrementalSync(req.userID, req.since, currentPos, req.limit)
|
return rp.db.IncrementalSync(req.ctx, req.userID, req.since, currentPos, req.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *RequestPool) appendAccountData(
|
func (rp *RequestPool) appendAccountData(
|
||||||
|
@ -145,7 +145,7 @@ func (rp *RequestPool) appendAccountData(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync is not initial, get all account data since the latest sync
|
// Sync is not initial, get all account data since the latest sync
|
||||||
dataTypes, err := rp.db.GetAccountDataInRange(userID, req.since, currentPos)
|
dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since, currentPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue