diff --git a/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go b/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go index ea64896d..0cc432a9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go +++ b/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go @@ -2,6 +2,7 @@ package jsonerror import ( "fmt" + "github.com/matrix-org/util" ) diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 9679fab6..afb1b6a8 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -49,7 +49,7 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.ClientAPI } // SetupSyncServerListeners configures the given mux with sync-server listeners -func SetupSyncServerListeners(servMux *http.ServeMux, httpClient *http.Client, cfg config.Sync, srp sync.RequestPool) { +func SetupSyncServerListeners(servMux *http.ServeMux, httpClient *http.Client, cfg config.Sync, srp *sync.RequestPool) { apiMux := mux.NewRouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux.Handle("/sync", make("sync", util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go index 9c75940c..61dab2fb 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/output_room_events_table.go @@ -31,14 +31,22 @@ CREATE UNIQUE INDEX IF NOT EXISTS event_id_idx ON output_room_events(event_id); ` const insertEventSQL = "" + - "INSERT INTO output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO 
output_room_events (room_id, event_id, event_json, add_state_ids, remove_state_ids) VALUES ($1, $2, $3, $4, $5) RETURNING id" const selectEventsSQL = "" + "SELECT event_json FROM output_room_events WHERE event_id = ANY($1)" +const selectEventsInRangeSQL = "" + + "SELECT event_json FROM output_room_events WHERE id > $1 AND id <= $2" + +const selectMaxIDSQL = "" + + "SELECT MAX(id) FROM output_room_events" + type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxIDStmt *sql.Stmt + selectEventsInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { @@ -52,15 +60,63 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { return } + if s.selectMaxIDStmt, err = db.Prepare(selectMaxIDSQL); err != nil { + return + } + if s.selectEventsInRangeStmt, err = db.Prepare(selectEventsInRangeSQL); err != nil { + return + } return } -// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. -func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) error { - _, err := txn.Stmt(s.insertEventStmt).Exec( +// MaxID returns the ID of the last inserted event in this table. This should only ever be used at startup, as it will +// race with inserting events if it is done afterwards. If there are no inserted events, 0 is returned. +func (s *outputRoomEventsStatements) MaxID() (id int64, err error) { + var nullableID sql.NullInt64 + err = s.selectMaxIDStmt.QueryRow().Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} + +// InRange returns all the events in the range between oldPos exclusive and newPos inclusive. Returns an empty array if +// there are no events between the provided range. 
Returns an error if events are missing in the range. +func (s *outputRoomEventsStatements) InRange(oldPos, newPos int64) ([]gomatrixserverlib.Event, error) { + rows, err := s.selectEventsInRangeStmt.Query(oldPos, newPos) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []gomatrixserverlib.Event + var i int64 + for ; rows.Next(); i++ { + var eventBytes []byte + if err := rows.Scan(&eventBytes); err != nil { + return nil, err + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false) + if err != nil { + return nil, err + } + result = append(result, ev) + } + // Expect one event per position, exclusive of old. eg old=3, new=5, expect 4,5 so 2 events. + wantNum := (newPos - oldPos) + if i != wantNum { + return nil, fmt.Errorf("failed to map all positions to events: (got %d, wanted, %d)", i, wantNum) + } + return result, nil +} + +// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position +// of the inserted event. +func (s *outputRoomEventsStatements) InsertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { + err = txn.Stmt(s.insertEventStmt).QueryRow( event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), - ) - return err + ).Scan(&streamPos) + return } // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go index 39df640b..e63680ae 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/syncserver.go @@ -39,10 +39,13 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { } // WriteEvent into the database. 
It is not safe to call this function from multiple goroutines, as it would create races -// when generating the stream position for this event. Returns an error if there was a problem inserting this event. -func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string) error { - return runTransaction(d.db, func(txn *sql.Tx) error { - if err := d.events.InsertEvent(txn, ev, addStateEventIDs, removeStateEventIDs); err != nil { +// when generating the stream position for this event. Returns the sync stream position for the inserted event. +// Returns an error if there was a problem inserting this event. +func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string) (streamPos int64, returnErr error) { + returnErr = runTransaction(d.db, func(txn *sql.Tx) error { + var err error + streamPos, err = d.events.InsertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) + if err != nil { return err } @@ -56,7 +59,7 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve // However, conflict resolution may result in there being different events being added, or even some removed. 
if len(removeStateEventIDs) == 0 && len(addStateEventIDs) == 1 && addStateEventIDs[0] == ev.EventID() { // common case - if err := d.roomstate.UpdateRoomState(txn, []gomatrixserverlib.Event{*ev}, nil); err != nil { + if err = d.roomstate.UpdateRoomState(txn, []gomatrixserverlib.Event{*ev}, nil); err != nil { return err } return nil @@ -69,6 +72,7 @@ func (d *SyncServerDatabase) WriteEvent(ev *gomatrixserverlib.Event, addStateEve } return d.roomstate.UpdateRoomState(txn, added, removeStateEventIDs) }) + return } // PartitionOffsets implements common.PartitionStorer @@ -81,6 +85,16 @@ func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, o return d.partitions.UpsertPartitionOffset(topic, partition, offset) } +// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. +func (d *SyncServerDatabase) SyncStreamPosition() (int64, error) { + return d.events.MaxID() +} + +// EventsInRange returns all events in the given range, exclusive of oldPos, inclusive of newPos. 
+func (d *SyncServerDatabase) EventsInRange(oldPos, newPos int64) ([]gomatrixserverlib.Event, error) { + return d.events.InRange(oldPos, newPos) +} + func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { txn, err := db.Begin() if err != nil { diff --git a/src/github.com/matrix-org/dendrite/clientapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/clientapi/sync/requestpool.go index 62b16a1f..051e8663 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/clientapi/sync/requestpool.go @@ -3,11 +3,13 @@ package sync import ( "net/http" "strconv" + "sync" "time" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -18,13 +20,26 @@ const defaultSyncTimeout = time.Duration(30) * time.Second type syncRequest struct { userID string timeout time.Duration - since string + since syncStreamPosition wantFullState bool } // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { db *storage.SyncServerDatabase + // The latest sync stream position: guarded by 'cond'. + currPos syncStreamPosition + // A condition variable to notify all waiting goroutines of a new sync stream position + cond *sync.Cond +} + +// NewRequestPool makes a new RequestPool +func NewRequestPool(db *storage.SyncServerDatabase) (*RequestPool, error) { + pos, err := db.SyncStreamPosition() + if err != nil { + return nil, err + } + return &RequestPool{db, syncStreamPosition(pos), sync.NewCond(&sync.Mutex{})}, nil } // OnIncomingSyncRequest is called when a client makes a /sync request. 
This function MUST be @@ -37,7 +52,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request) util.JSONRespons if resErr != nil { return *resErr } - since := req.URL.Query().Get("since") + since, err := getSyncStreamPosition(req.URL.Query().Get("since")) + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(err.Error()), + } + } timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" @@ -54,34 +75,71 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request) util.JSONRespons "timeout": timeout, }).Info("Incoming /sync request") - res, err := rp.currentSyncForUser(syncReq) - if err != nil { - return httputil.LogThenError(req, err) - } - return util.JSONResponse{ - Code: 200, - JSON: res, + // Fork off 2 goroutines: one to do the work, and one to serve as a timeout. + // Whichever returns first is the one we will serve back to the client. + // TODO: Currently this means that cpu work is timed, which may not be what we want long term. + timeoutChan := make(chan struct{}) + timer := time.AfterFunc(timeout, func() { + close(timeoutChan) // signal that the timeout has expired + }) + + done := make(chan util.JSONResponse) + go func() { + syncData, err := rp.currentSyncForUser(syncReq) + timer.Stop() + var res util.JSONResponse + if err != nil { + res = httputil.LogThenError(req, err) + } else { + res = util.JSONResponse{ + Code: 200, + JSON: syncData, + } + } + done <- res + close(done) + }() + + select { + case <-timeoutChan: // timeout fired + return util.JSONResponse{ + Code: 200, + JSON: []struct{}{}, // return empty array for now + } + case res := <-done: // received a response + return res } } -// OnNewEvent is called when a new event is received from the room server -func (rp *RequestPool) OnNewEvent(ev *gomatrixserverlib.Event) { +// OnNewEvent is called when a new event is received from the room server. 
Must only be +// called from a single goroutine, to avoid races between updates which could set the +// current position in the stream incorrectly. +func (rp *RequestPool) OnNewEvent(ev *gomatrixserverlib.Event, pos syncStreamPosition) { + // update the current position in a guard and then notify all /sync streams + rp.cond.L.Lock() + rp.currPos = pos + rp.cond.L.Unlock() + rp.cond.Broadcast() // notify ALL waiting goroutines +} + +func (rp *RequestPool) waitForEvents(req syncRequest) syncStreamPosition { + // In a guard, check if the /sync request should block, and block it until we get a new position + rp.cond.L.Lock() + currentPos := rp.currPos + for req.since == currentPos { + // we need to wait for a new event. + // TODO: This waits for ANY new event, we need to only wait for events which we care about. + rp.cond.Wait() // atomically unlocks and blocks goroutine, then re-acquires lock on unblock + currentPos = rp.currPos + } + rp.cond.L.Unlock() + return currentPos } func (rp *RequestPool) currentSyncForUser(req syncRequest) ([]gomatrixserverlib.Event, error) { - // https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L179 - // Check if we are going to return immediately and if so, calculate the current - // sync for this user and return. - if req.since == "" || req.timeout == time.Duration(0) || req.wantFullState { - return []gomatrixserverlib.Event{}, nil - } - - // TODO: wait for an event which affects this user or one of their rooms, then recheck for new - // sync data. 
- time.Sleep(req.timeout) - - return nil, nil + currentPos := rp.waitForEvents(req) + return rp.db.EventsInRange(int64(req.since), int64(currentPos)) } func getTimeout(timeoutMS string) time.Duration { @@ -95,7 +153,13 @@ func getTimeout(timeoutMS string) time.Duration { return time.Duration(i) * time.Millisecond } -// NewRequestPool makes a new RequestPool -func NewRequestPool(db *storage.SyncServerDatabase) RequestPool { - return RequestPool{db} +func getSyncStreamPosition(since string) (syncStreamPosition, error) { + if since == "" { + return syncStreamPosition(0), nil + } + i, err := strconv.Atoi(since) + if err != nil { + return syncStreamPosition(0), err + } + return syncStreamPosition(i), nil } diff --git a/src/github.com/matrix-org/dendrite/clientapi/sync/syncserver.go b/src/github.com/matrix-org/dendrite/clientapi/sync/syncserver.go index 5454c698..892c163d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/sync/syncserver.go +++ b/src/github.com/matrix-org/dendrite/clientapi/sync/syncserver.go @@ -12,15 +12,18 @@ import ( sarama "gopkg.in/Shopify/sarama.v1" ) +// syncStreamPosition represents the offset in the sync stream a client is at. +type syncStreamPosition int64 + // Server contains all the logic for running a sync server type Server struct { roomServerConsumer *common.ContinualConsumer db *storage.SyncServerDatabase - rp RequestPool + rp *RequestPool } // NewServer creates a new sync server. Call Start() to begin consuming from room servers. -func NewServer(cfg *config.Sync, rp RequestPool, store *storage.SyncServerDatabase) (*Server, error) { +func NewServer(cfg *config.Sync, rp *RequestPool, store *storage.SyncServerDatabase) (*Server, error) { kafkaConsumer, err := sarama.NewConsumer(cfg.KafkaConsumerURIs, nil) if err != nil { return nil, err @@ -46,6 +49,9 @@ func (s *Server) Start() error { return s.roomServerConsumer.Start() } +// onMessage is called when the sync server receives a new event from the room server output log. 
+// It is not safe for this function to be called from multiple goroutines, or else the +// sync stream position may race and be incorrectly calculated. func (s *Server) onMessage(msg *sarama.ConsumerMessage) error { // Parse out the event JSON var output api.OutputRoomEvent @@ -65,7 +71,9 @@ func (s *Server) onMessage(msg *sarama.ConsumerMessage) error { "room_id": ev.RoomID(), }).Info("received event from roomserver") - if err := s.db.WriteEvent(&ev, output.AddsStateEventIDs, output.RemovesStateEventIDs); err != nil { + syncStreamPos, err := s.db.WriteEvent(&ev, output.AddsStateEventIDs, output.RemovesStateEventIDs) + + if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ "event": string(ev.JSON()), @@ -75,7 +83,7 @@ func (s *Server) onMessage(msg *sarama.ConsumerMessage) error { }).Panicf("roomserver output log: write event failure") return nil } - s.rp.OnNewEvent(&ev) + s.rp.OnNewEvent(&ev, syncStreamPosition(syncStreamPos)) return nil } diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-server/main.go index e04797c2..106f8bf5 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-server/main.go @@ -74,7 +74,10 @@ func main() { log.Panicf("startup: failed to create sync server database with data source %s : %s", cfg.DataSource, err) } - rp := sync.NewRequestPool(db) + rp, err := sync.NewRequestPool(db) + if err != nil { + log.Panicf("startup: Failed to create request pool : %s", err) + } server, err := sync.NewServer(cfg, rp, db) if err != nil { diff --git a/src/github.com/matrix-org/dendrite/common/consumers.go b/src/github.com/matrix-org/dendrite/common/consumers.go index 891e080b..caeeabca 100644 --- a/src/github.com/matrix-org/dendrite/common/consumers.go +++ b/src/github.com/matrix-org/dendrite/common/consumers.go @@ -67,7 +67,9 @@ func (c 
*ContinualConsumer) Start() error { } for _, offset := range storedOffsets { // We've already processed events from this partition so advance the offset to where we got to. - offsets[offset.Partition] = offset.Offset + // ConsumePartition will start streaming from the message with the given offset (inclusive), + // so increment by 1 to avoid getting the same message a second time. + offsets[offset.Partition] = 1 + offset.Offset } var partitionConsumers []sarama.PartitionConsumer