diff --git a/clientapi/producers/eduserver.go b/clientapi/producers/eduserver.go index 30c40fb7..102c1fad 100644 --- a/clientapi/producers/eduserver.go +++ b/clientapi/producers/eduserver.go @@ -14,6 +14,7 @@ package producers import ( "context" + "encoding/json" "time" "github.com/matrix-org/dendrite/eduserver/api" @@ -52,3 +53,28 @@ func (p *EDUServerProducer) SendTyping( return err } + +// SendToDevice sends a typing event to EDU server +func (p *EDUServerProducer) SendToDevice( + ctx context.Context, sender, userID, deviceID, eventType string, + message interface{}, +) error { + js, err := json.Marshal(message) + if err != nil { + return err + } + requestData := api.InputSendToDeviceEvent{ + UserID: userID, + DeviceID: deviceID, + SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{ + Sender: sender, + Type: eventType, + Content: js, + }, + } + request := api.InputSendToDeviceEventRequest{ + InputSendToDeviceEvent: requestData, + } + response := api.InputSendToDeviceEventResponse{} + return p.InputAPI.InputSendToDeviceEvent(ctx, &request, &response) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 934d9f06..83e399ac 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -274,6 +274,31 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + // This is only here because sytest refers to /unstable for this endpoint + // rather than r0. It's an exact duplicate of the above handler. + // TODO: Remove this if/when sytest is fixed! + unstableMux.Handle("/sendToDevice/{eventType}/{txnID}", + internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := internal.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID) + }), + ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/account/whoami", internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { return Whoami(req, device) diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go new file mode 100644 index 00000000..5d3060d7 --- /dev/null +++ b/clientapi/routing/sendtodevice.go @@ -0,0 +1,70 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/util" +) + +// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId} +// sends the device events to the EDU Server +func SendToDevice( + req *http.Request, device *authtypes.Device, + eduProducer *producers.EDUServerProducer, + txnCache *transactions.Cache, + eventType string, txnID *string, +) util.JSONResponse { + if txnID != nil { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + var httpReq struct { + Messages map[string]map[string]json.RawMessage `json:"messages"` + } + resErr := httputil.UnmarshalJSONRequest(req, &httpReq) + if resErr != nil { + return *resErr + } + + for userID, byUser := range httpReq.Messages { + for deviceID, message := range byUser { + if err := eduProducer.SendToDevice( + req.Context(), device.UserID, userID, deviceID, eventType, message, + ); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed") + return jsonerror.InternalServerError() + } + } + } + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + return res +} diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index e06adf8f..f919243d 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -39,7 +39,7 @@ func main() { rsAPI := base.CreateHTTPRoomserverAPIs() fsAPI := base.CreateHTTPFederationSenderAPIs() rsAPI.SetFederationSenderAPI(fsAPI) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) clientapi.SetupClientAPIComponent( base, deviceDB, accountDB, federation, keyRing, diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index fc56b9bb..e9d01fd9 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -148,7 +148,7 @@ func main() { &base.Base, keyRing, federation, ) eduInputAPI := eduserver.SetupEDUServerComponent( - &base.Base, cache.New(), + &base.Base, cache.New(), deviceDB, ) asAPI := appservice.SetupAppServiceAPIComponent( &base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(), diff --git a/cmd/dendrite-edu-server/main.go b/cmd/dendrite-edu-server/main.go index 66e17e57..ca0460f8 100644 --- a/cmd/dendrite-edu-server/main.go +++ b/cmd/dendrite-edu-server/main.go @@ -29,8 +29,9 @@ func main() { logrus.WithError(err).Warn("BaseDendrite close failed") } }() + deviceDB := base.CreateDeviceDB() - eduserver.SetupEDUServerComponent(base, cache.New()) + eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index 5425d117..af63b549 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -39,7 +39,7 @@ func main() { rsAPI := base.CreateHTTPRoomserverAPIs() asAPI := base.CreateHTTPAppServiceAPIs() rsAPI.SetFederationSenderAPI(fsAPI) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) eduProducer := producers.NewEDUServerProducer(eduInputAPI) federationapi.SetupFederationAPIComponent( diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 8367cd9d..ef114ccd 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -87,7 +87,7 @@ func main() { } eduInputAPI := eduserver.SetupEDUServerComponent( - base, cache.New(), + base, cache.New(), deviceDB, ) if base.EnableHTTPAPIs { eduInputAPI = base.CreateHTTPEDUServerAPIs() diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 45f23d9a..9a02e71e 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -175,6 +175,7 @@ func main() { cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db" cfg.Kafka.Topics.UserUpdates = "user_updates" cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event" + cfg.Kafka.Topics.OutputSendToDeviceEvent = "output_send_to_device_event" cfg.Kafka.Topics.OutputClientData = "output_client_data" cfg.Kafka.Topics.OutputRoomEvent = "output_room_event" cfg.Matrix.TrustedIDServers = []string{ @@ -206,7 +207,7 @@ func main() { p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node) rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation) - eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) + eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB) asQuery := appservice.SetupAppServiceAPIComponent( base, accountDB, deviceDB, federation, rsAPI, transactions.New(), ) diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 1802b8b7..a5b29597 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -104,7 +104,8 @@ kafka: topics: output_room_event: roomserverOutput output_client_data: clientapiOutput - output_typing_event: eduServerOutput + output_typing_event: eduServerTypingOutput + output_send_to_device_event: eduServerSendToDeviceOutput user_updates: userUpdates # The postgres connection configs for connecting to the databases e.g a postgres:// URI @@ -137,8 +138,8 @@ listen: federation_sender: "localhost:7776" appservice_api: "localhost:7777" edu_server: "localhost:7778" - key_server: "localhost:7779" - server_key_api: "localhost:7780" + key_server: "localhost:7779" + server_key_api: "localhost:7780" # The configuration for tracing the dendrite components. tracing: diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 8b5b6d76..fa7f30cb 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -37,6 +41,12 @@ type InputTypingEvent struct { OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"` } +type InputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} + // InputTypingEventRequest is a request to EDUServerInputAPI type InputTypingEventRequest struct { InputTypingEvent InputTypingEvent `json:"input_typing_event"` @@ -45,6 +55,14 @@ type InputTypingEventRequest struct { // InputTypingEventResponse is a response to InputTypingEvents type InputTypingEventResponse struct{} +// InputSendToDeviceEventRequest is a request to EDUServerInputAPI +type InputSendToDeviceEventRequest struct { + InputSendToDeviceEvent InputSendToDeviceEvent `json:"input_send_to_device_event"` +} + +// InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest +type InputSendToDeviceEventResponse struct{} + // EDUServerInputAPI is used to write events to the typing server. type EDUServerInputAPI interface { InputTypingEvent( @@ -52,11 +70,20 @@ type EDUServerInputAPI interface { request *InputTypingEventRequest, response *InputTypingEventResponse, ) error + + InputSendToDeviceEvent( + ctx context.Context, + request *InputSendToDeviceEventRequest, + response *InputSendToDeviceEventResponse, + ) error } // EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API. const EDUServerInputTypingEventPath = "/eduserver/input" +// EDUServerInputSendToDeviceEventPath is the HTTP path for the InputSendToDeviceEvent API. +const EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" + // NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API. func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) { if httpClient == nil { @@ -70,7 +97,7 @@ type httpEDUServerInputAPI struct { httpClient *http.Client } -// InputRoomEvents implements EDUServerInputAPI +// InputTypingEvent implements EDUServerInputAPI func (h *httpEDUServerInputAPI) InputTypingEvent( ctx context.Context, request *InputTypingEventRequest, @@ -82,3 +109,16 @@ func (h *httpEDUServerInputAPI) InputTypingEvent( apiURL := h.eduServerURL + EDUServerInputTypingEventPath return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// InputSendToDeviceEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *InputSendToDeviceEventRequest, + response *InputSendToDeviceEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputSendToDeviceEvent") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath + return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/eduserver/api/output.go b/eduserver/api/output.go index 8696acf4..e6ded841 100644 --- a/eduserver/api/output.go +++ b/eduserver/api/output.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -12,7 +16,11 @@ package api -import "time" +import ( + "time" + + "github.com/matrix-org/gomatrixserverlib" +) // OutputTypingEvent is an entry in typing server output kafka log. // This contains the event with extra fields used to create 'm.typing' event @@ -32,3 +40,12 @@ type TypingEvent struct { UserID string `json:"user_id"` Typing bool `json:"typing"` } + +// OutputSendToDeviceEvent is an entry in the send-to-device output kafka log. +// This contains the full event content, along with the user ID and device ID +// to which it is destined. +type OutputSendToDeviceEvent struct { + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + gomatrixserverlib.SendToDeviceEvent +} diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 46f7a2b1..dd535a6d 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -109,6 +113,19 @@ func (t *EDUCache) AddTypingUser( return t.GetLatestSyncPosition() } +// AddSendToDeviceMessage increases the sync position for +// send-to-device updates. +// Returns the sync position before update, as the caller +// will use this to record the current stream position +// at the time that the send-to-device message was sent. +func (t *EDUCache) AddSendToDeviceMessage() int64 { + t.Lock() + defer t.Unlock() + latestSyncPosition := t.latestSyncPosition + t.latestSyncPosition++ + return latestSyncPosition +} + // addUser with mutex lock & replace the previous timer. // Returns the latest typing sync position after update. func (t *EDUCache) addUser( diff --git a/eduserver/cache/cache_test.go b/eduserver/cache/cache_test.go index d1b2f8bd..c7d01879 100644 --- a/eduserver/cache/cache_test.go +++ b/eduserver/cache/cache_test.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index 14fbd332..6f664eb6 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -13,6 +17,7 @@ package eduserver import ( + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/input" @@ -26,11 +31,15 @@ import ( func SetupEDUServerComponent( base *basecomponent.BaseDendrite, eduCache *cache.EDUCache, + deviceDB devices.Database, ) api.EDUServerInputAPI { inputAPI := &input.EDUServerInputAPI{ - Cache: eduCache, - Producer: base.KafkaProducer, - OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), + Cache: eduCache, + DeviceDB: deviceDB, + Producer: base.KafkaProducer, + OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), + OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent), + ServerName: base.Cfg.Matrix.ServerName, } inputAPI.SetupHTTP(base.InternalAPIMux) diff --git a/eduserver/input/input.go b/eduserver/input/input.go index 73777e32..4e305195 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -1,3 +1,7 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -20,11 +24,13 @@ import ( "github.com/Shopify/sarama" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // EDUServerInputAPI implements api.EDUServerInputAPI @@ -33,8 +39,14 @@ type EDUServerInputAPI struct { Cache *cache.EDUCache // The kafka topic to output new typing events to. OutputTypingEventTopic string + // The kafka topic to output new send to device events to. + OutputSendToDeviceEventTopic string // kafka producer Producer sarama.SyncProducer + // device database + DeviceDB devices.Database + // our server name + ServerName gomatrixserverlib.ServerName } // InputTypingEvent implements api.EDUServerInputAPI @@ -54,10 +66,20 @@ func (t *EDUServerInputAPI) InputTypingEvent( t.Cache.RemoveUser(ite.UserID, ite.RoomID) } - return t.sendEvent(ite) + return t.sendTypingEvent(ite) } -func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { +// InputTypingEvent implements api.EDUServerInputAPI +func (t *EDUServerInputAPI) InputSendToDeviceEvent( + ctx context.Context, + request *api.InputSendToDeviceEventRequest, + response *api.InputSendToDeviceEventResponse, +) error { + ise := &request.InputSendToDeviceEvent + return t.sendToDeviceEvent(ise) +} + +func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, RoomID: ite.RoomID, @@ -90,6 +112,65 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { return err } +func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error { + devices := []string{} + localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID) + if err != nil { + return err + } + + // If the event is targeted locally then we want to expand the wildcard + // out into individual device IDs so that we can send them to each respective + // device. If the event isn't targeted locally then we can't expand the + // wildcard as we don't know about the remote devices, so instead we leave it + // as-is, so that the federation sender can send it on with the wildcard intact. + if domain == t.ServerName && ise.DeviceID == "*" { + devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart) + if err != nil { + return err + } + for _, dev := range devs { + devices = append(devices, dev.ID) + } + } else { + devices = append(devices, ise.DeviceID) + } + + for _, device := range devices { + ote := &api.OutputSendToDeviceEvent{ + UserID: ise.UserID, + DeviceID: device, + SendToDeviceEvent: ise.SendToDeviceEvent, + } + + logrus.WithFields(logrus.Fields{ + "user_id": ise.UserID, + "device_id": ise.DeviceID, + "event_type": ise.Type, + }).Info("handling send-to-device message") + + eventJSON, err := json.Marshal(ote) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed json.Marshal") + return err + } + + m := &sarama.ProducerMessage{ + Topic: string(t.OutputSendToDeviceEventTopic), + Key: sarama.StringEncoder(ote.UserID), + Value: sarama.ByteEncoder(eventJSON), + } + + _, _, err = t.Producer.SendMessage(m) + if err != nil { + logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage") + return err + } + } + + return nil +} + // SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux. func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) { internalAPIMux.Handle(api.EDUServerInputTypingEventPath, @@ -105,4 +186,17 @@ func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(api.EDUServerInputSendToDeviceEventPath, + internal.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse { + var request api.InputSendToDeviceEventRequest + var response api.InputSendToDeviceEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputSendToDeviceEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b514af0a..74b4c014 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -265,6 +265,25 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events") + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.eduProducer.SendToDevice(t.context, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to edu server") + } + } + } default: util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index cb8aec6f..3e28a347 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -77,6 +77,14 @@ func (p *testEDUProducer) InputTypingEvent( return nil } +func (p *testEDUProducer) InputSendToDeviceEvent( + ctx context.Context, + request *eduAPI.InputSendToDeviceEventRequest, + response *eduAPI.InputSendToDeviceEventResponse, +) error { + return nil +} + type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse diff --git a/go.mod b/go.mod index 4365ea50..cc60e1a2 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible diff --git a/go.sum b/go.sum index c08cfa5d..6d9c2725 100644 --- a/go.sum +++ b/go.sum @@ -356,8 +356,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61 h1:3rgoGvj/skUWg+u9E6ycEFs2ZGenEjr28ZtAhAhmZeM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf h1:iT2dfJ6JmYNRZBQTXeCNwsZIvfkBbFggzclM8iKnbR0= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/internal/config/config.go b/internal/config/config.go index 2a95069a..a20cc0ea 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -152,6 +152,8 @@ type Dendrite struct { OutputClientData Topic `yaml:"output_client_data"` // Topic for eduserver/api.OutputTypingEvent events. OutputTypingEvent Topic `yaml:"output_typing_event"` + // Topic for eduserver/api.OutputSendToDeviceEvent events. + OutputSendToDeviceEvent Topic `yaml:"output_send_to_device_event"` // Topic for user updates (profile, presence) UserUpdates Topic `yaml:"user_updates"` } diff --git a/internal/sql.go b/internal/sql.go index d6a5a308..546954bd 100644 --- a/internal/sql.go +++ b/internal/sql.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,9 +18,12 @@ package internal import ( "database/sql" + "errors" "fmt" "runtime" "time" + + "go.uber.org/atomic" ) // A Transaction is something that can be committed or rolledback. @@ -107,3 +112,60 @@ type DbProperties interface { MaxOpenConns() int ConnMaxLifetime() time.Duration } + +// TransactionWriter allows queuing database writes so that you don't +// contend on database locks in, e.g. SQLite. Only one task will run +// at a time on a given TransactionWriter. +type TransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +func NewTransactionWriter() *TransactionWriter { + return &TransactionWriter{ + todo: make(chan transactionWriterTask), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTask struct { + db *sql.DB + f func(txn *sql.Tx) error + wait chan error +} + +// Do queues a task to be run by a TransactionWriter. The function +// provided will be ran within a transaction as supplied by the +// database parameter. This will block until the task is finished. +func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *TransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + close(task.wait) + } +} diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go new file mode 100644 index 00000000..48701803 --- /dev/null +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -0,0 +1,113 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. +type OutputSendToDeviceEventConsumer struct { + sendToDeviceConsumer *internal.ContinualConsumer + db storage.Database + serverName gomatrixserverlib.ServerName // our server name + notifier *sync.Notifier +} + +// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputSendToDeviceEventConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store storage.Database, +) *OutputSendToDeviceEventConsumer { + + consumer := internal.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputSendToDeviceEventConsumer{ + sendToDeviceConsumer: &consumer, + db: store, + serverName: cfg.Matrix.ServerName, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from EDU api +func (s *OutputSendToDeviceEventConsumer) Start() error { + return s.sendToDeviceConsumer.Start() +} + +func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + return err + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + return err + } + if domain != s.serverName { + return nil + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos := s.db.AddSendToDevice() + + _, err = s.db.StoreNewSendForDeviceMessage( + context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + log.WithError(err).Errorf("failed to store send-to-device message") + return err + } + + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.NewStreamToken(0, streamPos), + ) + + return nil +} diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver_typing.go similarity index 100% rename from syncapi/consumers/eduserver.go rename to syncapi/consumers/eduserver_typing.go diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 7e1a40fd..566e5d58 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -55,10 +55,12 @@ type Database interface { // sync response for the given user. Events returned will include any client // transaction IDs associated with the given device. These transaction IDs come // from when the device sent the event via an API that included a transaction - // ID. - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) - // CompleteSync returns a complete /sync API response for the given user. - CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) + // ID. A response object must be provided for IncrementaSync to populate - it + // will not create one. + IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + // CompleteSync returns a complete /sync API response for the given user. A response object + // must be provided for CompleteSync to populate - it will not create one. + CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes @@ -104,4 +106,26 @@ type Database interface { StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) + // AddSendToDevice increases the EDU position in the cache and returns the stream position. + AddSendToDevice() types.StreamPosition + // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: + // - "events": a list of send-to-device events that should be included in the sync + // - "changes": a list of send-to-device events that should be updated in the database by + // CleanSendToDeviceUpdates + // - "deletions": a list of send-to-device events which have been confirmed as sent and + // can be deleted altogether by CleanSendToDeviceUpdates + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) + // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. + StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) + // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the + // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows + // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after + // starting to wait for an incremental sync with timeout). + // The token supplied should be the current requested sync token, e.g. from the "since" + // parameter. + CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) + // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. + SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) } diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go new file mode 100644 index 00000000..335a05ef --- /dev/null +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -0,0 +1,171 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id; + +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'), + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id = ANY($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id = ANY($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + updateSentSendToDeviceMessagesStmt *sql.Stmt + deleteSendToDeviceMessagesStmt *sql.Stmt +} + +func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + _, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index dc73350a..8a8f964a 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -69,6 +69,10 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (* if err != nil { return nil, err } + sendToDevice, err := NewPostgresSendToDeviceTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -77,6 +81,8 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (* Topology: topology, CurrentRoomState: currState, BackwardExtremities: backwardExtremities, + SendToDevice: sendToDevice, + SendToDeviceWriter: internal.NewTransactionWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 888f85e0..497c043a 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package shared import ( @@ -27,6 +41,8 @@ type Database struct { Topology tables.Topology CurrentRoomState tables.CurrentRoomState BackwardExtremities tables.BackwardsExtremities + SendToDevice tables.SendToDevice + SendToDeviceWriter *internal.TransactionWriter EDUCache *cache.EDUCache } @@ -89,6 +105,10 @@ func (d *Database) RemoveTypingUser( return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID)) } +func (d *Database) AddSendToDevice() types.StreamPosition { + return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage()) +} + func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { d.EDUCache.SetTimeoutCallback(fn) } @@ -528,14 +548,14 @@ func (d *Database) addEDUDeltaToResponse( } func (d *Database) IncrementalSync( - ctx context.Context, + ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { nextBatchPos := fromPos.WithUpdates(toPos) - res := types.NewResponse(nextBatchPos) + res.NextBatch = nextBatchPos.String() var joinedRoomIDs []string var err error @@ -568,12 +588,12 @@ func (d *Database) IncrementalSync( // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed // to it. It returns toPos and joinedRoomIDs for use of adding EDUs. +// nolint:nakedret func (d *Database) getResponseWithPDUsForCompleteSync( - ctx context.Context, + ctx context.Context, res *types.Response, userID string, numRecentEventsPerRoom int, ) ( - res *types.Response, toPos types.StreamingToken, joinedRoomIDs []string, err error, @@ -604,7 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( To: toPos.PDUPosition(), } - res = types.NewResponse(toPos) + res.NextBatch = toPos.String() // Extract room state and recent events for all rooms the user is joined to. joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -662,14 +682,15 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } succeeded = true - return res, toPos, joinedRoomIDs, err + return //res, toPos, joinedRoomIDs, err } func (d *Database) CompleteSync( - ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int, + ctx context.Context, res *types.Response, + device authtypes.Device, numRecentEventsPerRoom int, ) (*types.Response, error) { - res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( - ctx, device.UserID, numRecentEventsPerRoom, + toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( + ctx, res, device.UserID, numRecentEventsPerRoom, ) if err != nil { return nil, err @@ -1028,6 +1049,115 @@ func (d *Database) currentStateStreamEventsForRoom( return s, nil } +func (d *Database) SendToDeviceUpdatesWaiting( + ctx context.Context, userID, deviceID string, +) (bool, error) { + count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (d *Database) AddSendToDeviceEvent( + ctx context.Context, txn *sql.Tx, + userID, deviceID, content string, +) error { + return d.SendToDevice.InsertSendToDeviceMessage( + ctx, txn, userID, deviceID, content, + ) +} + +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, +) (types.StreamPosition, error) { + j, err := json.Marshal(event) + if err != nil { + return streamPos, err + } + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + return d.AddSendToDeviceEvent( + ctx, txn, userID, deviceID, string(j), + ) + }) + if err != nil { + return streamPos, err + } + return streamPos, nil +} + +func (d *Database) SendToDeviceUpdatesForSync( + ctx context.Context, + userID, deviceID string, + token types.StreamingToken, +) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { + // First of all, get our send-to-device updates for this user. + events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) + if err != nil { + return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } + + // If there's nothing to do then stop here. + if len(events) == 0 { + return nil, nil, nil, nil + } + + // Work out whether we need to update any of the database entries. + toReturn := []types.SendToDeviceEvent{} + toUpdate := []types.SendToDeviceNID{} + toDelete := []types.SendToDeviceNID{} + for _, event := range events { + if event.SentByToken == nil { + // If the event has no sent-by token yet then we haven't attempted to send + // it. Record the current requested sync token in the database. + toUpdate = append(toUpdate, event.ID) + toReturn = append(toReturn, event) + event.SentByToken = &token + } else if token.IsAfter(*event.SentByToken) { + // The event had a sync token, therefore we've sent it before. The current + // sync token is now after the stored one so we can assume that the client + // successfully completed the previous sync (it would re-request it otherwise) + // so we can remove the entry from the database. + toDelete = append(toDelete, event.ID) + } else { + // It looks like the sync is being re-requested, maybe it timed out or + // failed. Re-send any that should have been acknowledged by now. + toReturn = append(toReturn, event) + } + } + + return toReturn, toUpdate, toDelete, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + toUpdate, toDelete []types.SendToDeviceNID, + token types.StreamingToken, +) (err error) { + if len(toUpdate) == 0 && len(toDelete) == 0 { + return nil + } + // If we need to write to the database then we'll ask the SendToDeviceWriter to + // do that for us. It'll guarantee that we don't lock the table for writes in + // more than one place. + err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + // Delete any send-to-device messages marked for deletion. + if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { + return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) + } + + // Now update any outstanding send-to-device messages with the new sync token. + if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil { + return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err) + } + + return nil + }) + return +} + // There may be some overlap where events in stateEvents are already in recentEvents, so filter // them out so we don't include them twice in the /sync response. They should be in recentEvents // only, so clients get to the correct state once they have rolled forward. diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go new file mode 100644 index 00000000..0d03f23e --- /dev/null +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -0,0 +1,172 @@ +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const sendToDeviceSchema = ` +-- Stores send-to-device messages. +CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( + -- The ID that uniquely identifies this message. + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The user ID to send the message to. + user_id TEXT NOT NULL, + -- The device ID to send the message to. + device_id TEXT NOT NULL, + -- The event content JSON. + content TEXT NOT NULL, + -- The token that was supplied to the /sync at the time that this + -- message was included in a sync response, or NULL if we haven't + -- included it in a /sync response yet. + sent_by_token TEXT +); +` + +const insertSendToDeviceMessageSQL = ` + INSERT INTO syncapi_send_to_device (user_id, device_id, content) + VALUES ($1, $2, $3) +` + +const countSendToDeviceMessagesSQL = ` + SELECT COUNT(*) + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 +` + +const selectSendToDeviceMessagesSQL = ` + SELECT id, user_id, device_id, content, sent_by_token + FROM syncapi_send_to_device + WHERE user_id = $1 AND device_id = $2 + ORDER BY id DESC +` + +const updateSentSendToDeviceMessagesSQL = ` + UPDATE syncapi_send_to_device SET sent_by_token = $1 + WHERE id IN ($2) +` + +const deleteSendToDeviceMessagesSQL = ` + DELETE FROM syncapi_send_to_device WHERE id IN ($1) +` + +type sendToDeviceStatements struct { + insertSendToDeviceMessageStmt *sql.Stmt + selectSendToDeviceMessagesStmt *sql.Stmt + countSendToDeviceMessagesStmt *sql.Stmt +} + +func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{} + _, err := db.Exec(sendToDeviceSchema) + if err != nil { + return nil, err + } + if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil { + return nil, err + } + if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + return nil, err + } + if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *sendToDeviceStatements) InsertSendToDeviceMessage( + ctx context.Context, txn *sql.Tx, userID, deviceID, content string, +) (err error) { + _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return +} + +func (s *sendToDeviceStatements) CountSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (count int, err error) { + row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID) + if err = row.Scan(&count); err != nil { + return + } + return count, nil +} + +func (s *sendToDeviceStatements) SelectSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (events []types.SendToDeviceEvent, err error) { + rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") + + for rows.Next() { + var id types.SendToDeviceNID + var userID, deviceID, content string + var sentByToken *string + if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { + return + } + event := types.SendToDeviceEvent{ + ID: id, + UserID: userID, + DeviceID: deviceID, + } + if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + return + } + if sentByToken != nil { + if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil { + event.SentByToken = &token + } + } + events = append(events, event) + } + + return events, rows.Err() +} + +func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + params[0] = token + for k, v := range nids { + params[k+1] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} + +func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( + ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, +) (err error) { + query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1) + params := make([]interface{}, 1+len(nids)) + for k, v := range nids { + params[k] = v + } + _, err = txn.ExecContext(ctx, query, params...) + return +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 8ab1d404..5ba07617 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + sendToDevice, err := NewSqliteSendToDeviceTable(d.db) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -103,6 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) { BackwardExtremities: bwExtrem, CurrentRoomState: roomState, Topology: topology, + SendToDevice: sendToDevice, + SendToDeviceWriter: internal.NewTransactionWriter(), EDUCache: cache.New(), } return nil diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index bb8554f4..4661ede4 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "crypto/ed25519" + "encoding/json" "fmt" "testing" "time" @@ -157,7 +158,8 @@ func TestSyncResponse(t *testing.T) { from := types.NewStreamToken( // pretend we are at the penultimate event positions[len(positions)-2], types.StreamPosition(0), ) - return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + res := types.NewResponse() + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) }, WantTimeline: events[len(events)-1:], }, @@ -169,8 +171,9 @@ func TestSyncResponse(t *testing.T) { from := types.NewStreamToken( // pretend we are 10 events behind positions[len(positions)-11], types.StreamPosition(0), ) + res := types.NewResponse() // limit is set to 5 - return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) }, // want the last 5 events, NOT the last 10. WantTimeline: events[len(events)-5:], @@ -180,8 +183,9 @@ func TestSyncResponse(t *testing.T) { { Name: "CompleteSync limited", DoSync: func() (*types.Response, error) { + res := types.NewResponse() // limit set to 5 - return db.CompleteSync(ctx, testUserDeviceA, 5) + return db.CompleteSync(ctx, res, testUserDeviceA, 5) }, // want the last 5 events WantTimeline: events[len(events)-5:], @@ -193,7 +197,8 @@ func TestSyncResponse(t *testing.T) { { Name: "CompleteSync", DoSync: func() (*types.Response, error) { - return db.CompleteSync(ctx, testUserDeviceA, len(events)+1) + res := types.NewResponse() + return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) }, WantTimeline: events, // We want no state at all as that field in /sync is the delta between the token (beginning of time) @@ -234,7 +239,8 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) { positions[len(positions)-2], types.StreamPosition(0), ) - res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) + res := types.NewResponse() + res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) if err != nil { t.Fatalf("failed to IncrementalSync with latest token") } @@ -512,6 +518,89 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { } } +func TestSendToDeviceBehaviour(t *testing.T) { + //t.Parallel() + db := MustCreateDatabase(t) + + // At this point there should be no messages. We haven't sent anything + // yet. + events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("first call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) + if err != nil { + return + } + + // Try sending a message. + streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ + Sender: "bob", + Type: "m.type", + Content: json.RawMessage("{}"), + }) + if err != nil { + t.Fatal(err) + } + + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update and the send-to-device update will be updated + // in the database to reflect that this was the sync position we sent the message at. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { + t.Fatal("second call should have one update") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("third call should have one update still") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) + if err != nil { + return + } + + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will not be sent again. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { + t.Fatal("fourth call should have no updates") + } + err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) + if err != nil { + return + } + + // At this point we should still have no updates, because no new updates have been + // sent. + events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { + t.Fatal("fifth call should have no updates") + } +} + func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index bc3b6941..0b7d1595 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package tables import ( @@ -94,3 +108,28 @@ type BackwardsExtremities interface { // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) } + +// SendToDevice tracks send-to-device messages which are sent to individual +// clients. Each message gets inserted into this table at the point that we +// receive it from the EDU server. +// +// We're supposed to try and do our best to deliver send-to-device messages +// once, but the only way that we can really guarantee that they have been +// delivered is if the client successfully requests the next sync as given +// in the next_batch. Each time the device syncs, we will request all of the +// updates that either haven't been sent yet, along with all updates that we +// *have* sent but we haven't confirmed to have been received yet. If it's the +// first time we're sending a given update then we update the table to say +// what the "since" parameter was when we tried to send it. +// +// When the client syncs again, if their "since" parameter is *later* than +// the recorded one, we drop the entry from the DB as it's "sent". If the +// sync parameter isn't later then we will keep including the updates in the +// sync response, as the client is seemingly trying to repeat the same /sync. +type SendToDevice interface { + InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) + SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) + UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) + DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) + CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 9b410a0c..325e7535 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent( } } +func (n *Notifier) OnNewSendToDevice( + userID string, deviceIDs []string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos + + n.wakeupUserDevice(userID, deviceIDs, latestPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos @@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { // wakeupUserDevice will wake up the sync stream for a specific user device. Other // device streams will be left alone. // nolint:unused -func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) { - for userID, deviceID := range userDevices { +func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) { + for _, deviceID := range deviceIDs { if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 14ddef20..13231557 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) { time.Sleep(1 * time.Second) wake := "two" - n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter) + n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter) if result := <-awoken; result != wake { t.Fatalf("expected to wake %q, got %q", wake, result) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index bd29b333..8b93cad4 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -1,4 +1,6 @@ // Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +17,7 @@ package sync import ( + "context" "net/http" "time" @@ -54,17 +57,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype JSON: jsonerror.Unknown(err.Error()), } } + logger := util.GetLogger(req.Context()).WithFields(log.Fields{ - "userID": device.UserID, - "deviceID": device.ID, - "since": syncReq.since, - "timeout": syncReq.timeout, - "limit": syncReq.limit, + "user_id": device.UserID, + "device_id": device.ID, + "since": syncReq.since, + "timeout": syncReq.timeout, + "limit": syncReq.limit, }) currPos := rp.notifier.CurrentPosition() - if shouldReturnImmediately(syncReq) { + if rp.shouldReturnImmediately(syncReq) { syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -116,7 +120,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // response. This ensures that we don't waste the hard work // of calculating the sync only to get timed out before we // can respond - syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { logger.WithError(err).Error("rp.currentSyncForUser failed") @@ -134,19 +137,59 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { - // TODO: handle ignored users - if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit) - } else { - res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) + res = types.NewResponse() + + since := types.NewStreamToken(0, 0) + if req.since != nil { + since = *req.since } + // See if we have any new tasks to do for the send-to-device messaging. + events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since) + if err != nil { + return nil, err + } + + // TODO: handle ignored users + if req.since == nil { + res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) + } else { + res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) + } if err != nil { return } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) + if err != nil { + return + } + + // Before we return the sync response, make sure that we take action on + // any send-to-device database updates or deletions that we need to do. + // Then add the updates into the sync response. + if len(updates) > 0 || len(deletions) > 0 { + // Handle the updates and deletions in the database. + err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) + if err != nil { + return + } + } + if len(events) > 0 { + // Add the updates into the sync response. + for _, event := range events { + res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) + } + + // Get the next_batch from the sync response and increase the + // EDU counter. + if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil { + pos.Positions[1]++ + res.NextBatch = pos.String() + } + } + return } @@ -238,6 +281,10 @@ func (rp *RequestPool) appendAccountData( // shouldReturnImmediately returns whether the /sync request is an initial sync, // or timeout=0, or full_state=true, in any of the cases the request should // return immediately. -func shouldReturnImmediately(syncReq *syncRequest) bool { - return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState +func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool { + if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState { + return true + } + waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID) + return werr == nil && waiting } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 9251f618..762f4e9d 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -78,7 +78,14 @@ func SetupSyncAPIComponent( base.Cfg, base.KafkaConsumer, notifier, syncDB, ) if err = typingConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start typing server consumer") + logrus.WithError(err).Panicf("failed to start typing consumer") + } + + sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( + base.Cfg, base.KafkaConsumer, notifier, syncDB, + ) + if err = sendToDeviceConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start send-to-device consumer") } routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg) diff --git a/syncapi/types/types.go b/syncapi/types/types.go index caa1b3ad..c1f09fba 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -296,13 +296,14 @@ type Response struct { Invite map[string]InviteResponse `json:"invite"` Leave map[string]LeaveResponse `json:"leave"` } `json:"rooms"` + ToDevice struct { + Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` + } `json:"to_device"` } // NewResponse creates an empty response with initialised maps. -func NewResponse(token StreamingToken) *Response { - res := Response{ - NextBatch: token.String(), - } +func NewResponse() *Response { + res := Response{} // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. res.Rooms.Join = make(map[string]JoinResponse) @@ -315,6 +316,7 @@ func NewResponse(token StreamingToken) *Response { // This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse. res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) + res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) return &res } @@ -326,7 +328,8 @@ func (r *Response) IsEmpty() bool { len(r.Rooms.Invite) == 0 && len(r.Rooms.Leave) == 0 && len(r.AccountData.Events) == 0 && - len(r.Presence.Events) == 0 + len(r.Presence.Events) == 0 && + len(r.ToDevice.Events) == 0 } // JoinResponse represents a /sync response for a room which is under the 'join' key. @@ -393,3 +396,13 @@ func NewLeaveResponse() *LeaveResponse { res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) return &res } + +type SendToDeviceNID int + +type SendToDeviceEvent struct { + gomatrixserverlib.SendToDeviceEvent + ID SendToDeviceNID + UserID string + DeviceID string + SentByToken *StreamingToken +} diff --git a/sytest-blacklist b/sytest-blacklist index caad2545..1efc207f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -39,3 +39,8 @@ Ignore invite in incremental sync # Blacklisted because this test calls /r0/events which we don't implement New room members see their own join event Existing members see new members' join events + +# Blacklisted because the federation work for these hasn't been finished yet. +Can recv device messages over federation +Device messages over federation wake up /sync +Wildcard device messages over federation wake up /sync diff --git a/sytest-whitelist b/sytest-whitelist index d4e6be9a..6236b28e 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -289,3 +289,15 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +# TODO: separate PR for: Can recv device messages over federation +# TODO: separate PR for: Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +# TODO: separate PR for: Wildcard device messages over federation wake up /sync +Can send a to-device message to two users which both receive it using /sync