Send-to-device support (#1072)

* Groundwork for send-to-device messaging

* Update sample config

* Add unstable routing for now

* Send to device consumer in sync API

* Start the send-to-device consumer

* fix indentation in dendrite-config.yaml

* Create send-to-device database tables, other tweaks

* Add some logic for send-to-device messages, add them into sync stream

* Handle incoming send-to-device messages, count them with EDU stream pos

* Undo changes to test

* pq.Array

* Fix sync

* Logging

* Fix a couple of transaction things, fix client API

* Add send-to-device test, hopefully fix bugs

* Comments

* Refactor a bit

* Fix schema

* Fix queries

* Debug logging

* Fix storing and retrieving of send-to-device messages

* Try to avoid database locks

* Update sync position

* Use latest sync position

* Jiggle about sync a bit

* Fix tests

* Break out the retrieval from the update/delete behaviour

* Comments

* nolint on getResponseWithPDUsForCompleteSync

* Try to line up sync tokens again

* Implement wildcard

* Add all send-to-device tests to whitelist, what could possibly go wrong?

* Only care about wildcard when targeted locally

* Deduplicate transactions

* Handle tokens properly, return immediately if waiting send-to-device messages

* Fix sync

* Update sytest-whitelist

* Fix copyright notice (need to do more of this)

* Comments, copyrights

* Return errors from Do, fix dendritejs

* Review comments

* Comments

* Constructor for TransactionWriter

* defletions

* Update gomatrixserverlib, sytest-blacklist
main
Neil Alexander 2020-06-01 17:50:19 +01:00 committed by GitHub
parent 1f43c24f86
commit a5d822004d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1302 additions and 60 deletions

View File

@ -14,6 +14,7 @@ package producers
import ( import (
"context" "context"
"encoding/json"
"time" "time"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
@ -52,3 +53,28 @@ func (p *EDUServerProducer) SendTyping(
return err return err
} }
// SendToDevice sends a typing event to EDU server
func (p *EDUServerProducer) SendToDevice(
ctx context.Context, sender, userID, deviceID, eventType string,
message interface{},
) error {
js, err := json.Marshal(message)
if err != nil {
return err
}
requestData := api.InputSendToDeviceEvent{
UserID: userID,
DeviceID: deviceID,
SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{
Sender: sender,
Type: eventType,
Content: js,
},
}
request := api.InputSendToDeviceEventRequest{
InputSendToDeviceEvent: requestData,
}
response := api.InputSendToDeviceEventResponse{}
return p.InputAPI.InputSendToDeviceEvent(ctx, &request, &response)
}

View File

@ -274,6 +274,31 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/sendToDevice/{eventType}/{txnID}",
internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := internal.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID)
}),
).Methods(http.MethodPut, http.MethodOptions)
// This is only here because sytest refers to /unstable for this endpoint
// rather than r0. It's an exact duplicate of the above handler.
// TODO: Remove this if/when sytest is fixed!
unstableMux.Handle("/sendToDevice/{eventType}/{txnID}",
internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := internal.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/account/whoami", r0mux.Handle("/account/whoami",
internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Whoami(req, device) return Whoami(req, device)

View File

@ -0,0 +1,70 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/internal/transactions"
"github.com/matrix-org/util"
)
// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}
// sends the device events to the EDU Server
func SendToDevice(
req *http.Request, device *authtypes.Device,
eduProducer *producers.EDUServerProducer,
txnCache *transactions.Cache,
eventType string, txnID *string,
) util.JSONResponse {
if txnID != nil {
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
return *res
}
}
var httpReq struct {
Messages map[string]map[string]json.RawMessage `json:"messages"`
}
resErr := httputil.UnmarshalJSONRequest(req, &httpReq)
if resErr != nil {
return *resErr
}
for userID, byUser := range httpReq.Messages {
for deviceID, message := range byUser {
if err := eduProducer.SendToDevice(
req.Context(), device.UserID, userID, deviceID, eventType, message,
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed")
return jsonerror.InternalServerError()
}
}
}
res := util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
if txnID != nil {
txnCache.AddTransaction(device.AccessToken, *txnID, &res)
}
return res
}

View File

@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs() rsAPI := base.CreateHTTPRoomserverAPIs()
fsAPI := base.CreateHTTPFederationSenderAPIs() fsAPI := base.CreateHTTPFederationSenderAPIs()
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
clientapi.SetupClientAPIComponent( clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, federation, keyRing, base, deviceDB, accountDB, federation, keyRing,

View File

@ -148,7 +148,7 @@ func main() {
&base.Base, keyRing, federation, &base.Base, keyRing, federation,
) )
eduInputAPI := eduserver.SetupEDUServerComponent( eduInputAPI := eduserver.SetupEDUServerComponent(
&base.Base, cache.New(), &base.Base, cache.New(), deviceDB,
) )
asAPI := appservice.SetupAppServiceAPIComponent( asAPI := appservice.SetupAppServiceAPIComponent(
&base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(), &base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(),

View File

@ -29,8 +29,9 @@ func main() {
logrus.WithError(err).Warn("BaseDendrite close failed") logrus.WithError(err).Warn("BaseDendrite close failed")
} }
}() }()
deviceDB := base.CreateDeviceDB()
eduserver.SetupEDUServerComponent(base, cache.New()) eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer)) base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer))

View File

@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs() rsAPI := base.CreateHTTPRoomserverAPIs()
asAPI := base.CreateHTTPAppServiceAPIs() asAPI := base.CreateHTTPAppServiceAPIs()
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
eduProducer := producers.NewEDUServerProducer(eduInputAPI) eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent( federationapi.SetupFederationAPIComponent(

View File

@ -87,7 +87,7 @@ func main() {
} }
eduInputAPI := eduserver.SetupEDUServerComponent( eduInputAPI := eduserver.SetupEDUServerComponent(
base, cache.New(), base, cache.New(), deviceDB,
) )
if base.EnableHTTPAPIs { if base.EnableHTTPAPIs {
eduInputAPI = base.CreateHTTPEDUServerAPIs() eduInputAPI = base.CreateHTTPEDUServerAPIs()

View File

@ -175,6 +175,7 @@ func main() {
cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db" cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db"
cfg.Kafka.Topics.UserUpdates = "user_updates" cfg.Kafka.Topics.UserUpdates = "user_updates"
cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event" cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event"
cfg.Kafka.Topics.OutputSendToDeviceEvent = "output_send_to_device_event"
cfg.Kafka.Topics.OutputClientData = "output_client_data" cfg.Kafka.Topics.OutputClientData = "output_client_data"
cfg.Kafka.Topics.OutputRoomEvent = "output_room_event" cfg.Kafka.Topics.OutputRoomEvent = "output_room_event"
cfg.Matrix.TrustedIDServers = []string{ cfg.Matrix.TrustedIDServers = []string{
@ -206,7 +207,7 @@ func main() {
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node)
rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation) rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation)
eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New()) eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
asQuery := appservice.SetupAppServiceAPIComponent( asQuery := appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, rsAPI, transactions.New(), base, accountDB, deviceDB, federation, rsAPI, transactions.New(),
) )

View File

@ -104,7 +104,8 @@ kafka:
topics: topics:
output_room_event: roomserverOutput output_room_event: roomserverOutput
output_client_data: clientapiOutput output_client_data: clientapiOutput
output_typing_event: eduServerOutput output_typing_event: eduServerTypingOutput
output_send_to_device_event: eduServerSendToDeviceOutput
user_updates: userUpdates user_updates: userUpdates
# The postgres connection configs for connecting to the databases e.g a postgres:// URI # The postgres connection configs for connecting to the databases e.g a postgres:// URI
@ -137,8 +138,8 @@ listen:
federation_sender: "localhost:7776" federation_sender: "localhost:7776"
appservice_api: "localhost:7777" appservice_api: "localhost:7777"
edu_server: "localhost:7778" edu_server: "localhost:7778"
key_server: "localhost:7779" key_server: "localhost:7779"
server_key_api: "localhost:7780" server_key_api: "localhost:7780"
# The configuration for tracing the dendrite components. # The configuration for tracing the dendrite components.
tracing: tracing:

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
@ -37,6 +41,12 @@ type InputTypingEvent struct {
OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"` OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"`
} }
type InputSendToDeviceEvent struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
gomatrixserverlib.SendToDeviceEvent
}
// InputTypingEventRequest is a request to EDUServerInputAPI // InputTypingEventRequest is a request to EDUServerInputAPI
type InputTypingEventRequest struct { type InputTypingEventRequest struct {
InputTypingEvent InputTypingEvent `json:"input_typing_event"` InputTypingEvent InputTypingEvent `json:"input_typing_event"`
@ -45,6 +55,14 @@ type InputTypingEventRequest struct {
// InputTypingEventResponse is a response to InputTypingEvents // InputTypingEventResponse is a response to InputTypingEvents
type InputTypingEventResponse struct{} type InputTypingEventResponse struct{}
// InputSendToDeviceEventRequest is a request to EDUServerInputAPI
type InputSendToDeviceEventRequest struct {
InputSendToDeviceEvent InputSendToDeviceEvent `json:"input_send_to_device_event"`
}
// InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest
type InputSendToDeviceEventResponse struct{}
// EDUServerInputAPI is used to write events to the typing server. // EDUServerInputAPI is used to write events to the typing server.
type EDUServerInputAPI interface { type EDUServerInputAPI interface {
InputTypingEvent( InputTypingEvent(
@ -52,11 +70,20 @@ type EDUServerInputAPI interface {
request *InputTypingEventRequest, request *InputTypingEventRequest,
response *InputTypingEventResponse, response *InputTypingEventResponse,
) error ) error
InputSendToDeviceEvent(
ctx context.Context,
request *InputSendToDeviceEventRequest,
response *InputSendToDeviceEventResponse,
) error
} }
// EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API. // EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API.
const EDUServerInputTypingEventPath = "/eduserver/input" const EDUServerInputTypingEventPath = "/eduserver/input"
// EDUServerInputSendToDeviceEventPath is the HTTP path for the InputSendToDeviceEvent API.
const EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice"
// NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API. // NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API.
func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) { func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) {
if httpClient == nil { if httpClient == nil {
@ -70,7 +97,7 @@ type httpEDUServerInputAPI struct {
httpClient *http.Client httpClient *http.Client
} }
// InputRoomEvents implements EDUServerInputAPI // InputTypingEvent implements EDUServerInputAPI
func (h *httpEDUServerInputAPI) InputTypingEvent( func (h *httpEDUServerInputAPI) InputTypingEvent(
ctx context.Context, ctx context.Context,
request *InputTypingEventRequest, request *InputTypingEventRequest,
@ -82,3 +109,16 @@ func (h *httpEDUServerInputAPI) InputTypingEvent(
apiURL := h.eduServerURL + EDUServerInputTypingEventPath apiURL := h.eduServerURL + EDUServerInputTypingEventPath
return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// InputSendToDeviceEvent implements EDUServerInputAPI
func (h *httpEDUServerInputAPI) InputSendToDeviceEvent(
ctx context.Context,
request *InputSendToDeviceEventRequest,
response *InputSendToDeviceEventResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputSendToDeviceEvent")
defer span.Finish()
apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath
return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
@ -12,7 +16,11 @@
package api package api
import "time" import (
"time"
"github.com/matrix-org/gomatrixserverlib"
)
// OutputTypingEvent is an entry in typing server output kafka log. // OutputTypingEvent is an entry in typing server output kafka log.
// This contains the event with extra fields used to create 'm.typing' event // This contains the event with extra fields used to create 'm.typing' event
@ -32,3 +40,12 @@ type TypingEvent struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
Typing bool `json:"typing"` Typing bool `json:"typing"`
} }
// OutputSendToDeviceEvent is an entry in the send-to-device output kafka log.
// This contains the full event content, along with the user ID and device ID
// to which it is destined.
type OutputSendToDeviceEvent struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
gomatrixserverlib.SendToDeviceEvent
}

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
@ -109,6 +113,19 @@ func (t *EDUCache) AddTypingUser(
return t.GetLatestSyncPosition() return t.GetLatestSyncPosition()
} }
// AddSendToDeviceMessage increases the sync position for
// send-to-device updates.
// Returns the sync position before update, as the caller
// will use this to record the current stream position
// at the time that the send-to-device message was sent.
func (t *EDUCache) AddSendToDeviceMessage() int64 {
t.Lock()
defer t.Unlock()
latestSyncPosition := t.latestSyncPosition
t.latestSyncPosition++
return latestSyncPosition
}
// addUser with mutex lock & replace the previous timer. // addUser with mutex lock & replace the previous timer.
// Returns the latest typing sync position after update. // Returns the latest typing sync position after update.
func (t *EDUCache) addUser( func (t *EDUCache) addUser(

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
@ -13,6 +17,7 @@
package eduserver package eduserver
import ( import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/eduserver/input" "github.com/matrix-org/dendrite/eduserver/input"
@ -26,11 +31,15 @@ import (
func SetupEDUServerComponent( func SetupEDUServerComponent(
base *basecomponent.BaseDendrite, base *basecomponent.BaseDendrite,
eduCache *cache.EDUCache, eduCache *cache.EDUCache,
deviceDB devices.Database,
) api.EDUServerInputAPI { ) api.EDUServerInputAPI {
inputAPI := &input.EDUServerInputAPI{ inputAPI := &input.EDUServerInputAPI{
Cache: eduCache, Cache: eduCache,
Producer: base.KafkaProducer, DeviceDB: deviceDB,
OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent), Producer: base.KafkaProducer,
OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent),
OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent),
ServerName: base.Cfg.Matrix.ServerName,
} }
inputAPI.SetupHTTP(base.InternalAPIMux) inputAPI.SetupHTTP(base.InternalAPIMux)

View File

@ -1,3 +1,7 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
@ -20,11 +24,13 @@ import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// EDUServerInputAPI implements api.EDUServerInputAPI // EDUServerInputAPI implements api.EDUServerInputAPI
@ -33,8 +39,14 @@ type EDUServerInputAPI struct {
Cache *cache.EDUCache Cache *cache.EDUCache
// The kafka topic to output new typing events to. // The kafka topic to output new typing events to.
OutputTypingEventTopic string OutputTypingEventTopic string
// The kafka topic to output new send to device events to.
OutputSendToDeviceEventTopic string
// kafka producer // kafka producer
Producer sarama.SyncProducer Producer sarama.SyncProducer
// device database
DeviceDB devices.Database
// our server name
ServerName gomatrixserverlib.ServerName
} }
// InputTypingEvent implements api.EDUServerInputAPI // InputTypingEvent implements api.EDUServerInputAPI
@ -54,10 +66,20 @@ func (t *EDUServerInputAPI) InputTypingEvent(
t.Cache.RemoveUser(ite.UserID, ite.RoomID) t.Cache.RemoveUser(ite.UserID, ite.RoomID)
} }
return t.sendEvent(ite) return t.sendTypingEvent(ite)
} }
func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { // InputTypingEvent implements api.EDUServerInputAPI
func (t *EDUServerInputAPI) InputSendToDeviceEvent(
ctx context.Context,
request *api.InputSendToDeviceEventRequest,
response *api.InputSendToDeviceEventResponse,
) error {
ise := &request.InputSendToDeviceEvent
return t.sendToDeviceEvent(ise)
}
func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error {
ev := &api.TypingEvent{ ev := &api.TypingEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
RoomID: ite.RoomID, RoomID: ite.RoomID,
@ -90,6 +112,65 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error {
return err return err
} }
func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error {
devices := []string{}
localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID)
if err != nil {
return err
}
// If the event is targeted locally then we want to expand the wildcard
// out into individual device IDs so that we can send them to each respective
// device. If the event isn't targeted locally then we can't expand the
// wildcard as we don't know about the remote devices, so instead we leave it
// as-is, so that the federation sender can send it on with the wildcard intact.
if domain == t.ServerName && ise.DeviceID == "*" {
devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart)
if err != nil {
return err
}
for _, dev := range devs {
devices = append(devices, dev.ID)
}
} else {
devices = append(devices, ise.DeviceID)
}
for _, device := range devices {
ote := &api.OutputSendToDeviceEvent{
UserID: ise.UserID,
DeviceID: device,
SendToDeviceEvent: ise.SendToDeviceEvent,
}
logrus.WithFields(logrus.Fields{
"user_id": ise.UserID,
"device_id": ise.DeviceID,
"event_type": ise.Type,
}).Info("handling send-to-device message")
eventJSON, err := json.Marshal(ote)
if err != nil {
logrus.WithError(err).Error("sendToDevice failed json.Marshal")
return err
}
m := &sarama.ProducerMessage{
Topic: string(t.OutputSendToDeviceEventTopic),
Key: sarama.StringEncoder(ote.UserID),
Value: sarama.ByteEncoder(eventJSON),
}
_, _, err = t.Producer.SendMessage(m)
if err != nil {
logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage")
return err
}
}
return nil
}
// SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux. // SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux.
func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) { func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) {
internalAPIMux.Handle(api.EDUServerInputTypingEventPath, internalAPIMux.Handle(api.EDUServerInputTypingEventPath,
@ -105,4 +186,17 @@ func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(api.EDUServerInputSendToDeviceEventPath,
internal.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse {
var request api.InputSendToDeviceEventRequest
var response api.InputSendToDeviceEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := t.InputSendToDeviceEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View File

@ -265,6 +265,25 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server")
} }
case gomatrixserverlib.MDirectToDevice:
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
var directPayload gomatrixserverlib.ToDeviceMessage
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events")
continue
}
for userID, byUser := range directPayload.Messages {
for deviceID, message := range byUser {
// TODO: check that the user and the device actually exist here
if err := t.eduProducer.SendToDevice(t.context, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{
"sender": directPayload.Sender,
"user_id": userID,
"device_id": deviceID,
}).Error("Failed to send send-to-device event to edu server")
}
}
}
default: default:
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
} }

View File

@ -77,6 +77,14 @@ func (p *testEDUProducer) InputTypingEvent(
return nil return nil
} }
func (p *testEDUProducer) InputSendToDeviceEvent(
ctx context.Context,
request *eduAPI.InputSendToDeviceEventRequest,
response *eduAPI.InputSendToDeviceEventResponse,
) error {
return nil
}
type testRoomserverAPI struct { type testRoomserverAPI struct {
inputRoomEvents []api.InputRoomEvent inputRoomEvents []api.InputRoomEvent
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse

2
go.mod
View File

@ -18,7 +18,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61 github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/mattn/go-sqlite3 v2.0.2+incompatible github.com/mattn/go-sqlite3 v2.0.2+incompatible

4
go.sum
View File

@ -356,8 +356,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61 h1:3rgoGvj/skUWg+u9E6ycEFs2ZGenEjr28ZtAhAhmZeM= github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf h1:iT2dfJ6JmYNRZBQTXeCNwsZIvfkBbFggzclM8iKnbR0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View File

@ -152,6 +152,8 @@ type Dendrite struct {
OutputClientData Topic `yaml:"output_client_data"` OutputClientData Topic `yaml:"output_client_data"`
// Topic for eduserver/api.OutputTypingEvent events. // Topic for eduserver/api.OutputTypingEvent events.
OutputTypingEvent Topic `yaml:"output_typing_event"` OutputTypingEvent Topic `yaml:"output_typing_event"`
// Topic for eduserver/api.OutputSendToDeviceEvent events.
OutputSendToDeviceEvent Topic `yaml:"output_send_to_device_event"`
// Topic for user updates (profile, presence) // Topic for user updates (profile, presence)
UserUpdates Topic `yaml:"user_updates"` UserUpdates Topic `yaml:"user_updates"`
} }

View File

@ -1,4 +1,6 @@
// Copyright 2017 Vector Creations Ltd // Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -16,9 +18,12 @@ package internal
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"time" "time"
"go.uber.org/atomic"
) )
// A Transaction is something that can be committed or rolledback. // A Transaction is something that can be committed or rolledback.
@ -107,3 +112,60 @@ type DbProperties interface {
MaxOpenConns() int MaxOpenConns() int
ConnMaxLifetime() time.Duration ConnMaxLifetime() time.Duration
} }
// TransactionWriter allows queuing database writes so that you don't
// contend on database locks in, e.g. SQLite. Only one task will run
// at a time on a given TransactionWriter.
type TransactionWriter struct {
running atomic.Bool
todo chan transactionWriterTask
}
func NewTransactionWriter() *TransactionWriter {
return &TransactionWriter{
todo: make(chan transactionWriterTask),
}
}
// transactionWriterTask represents a specific task.
type transactionWriterTask struct {
db *sql.DB
f func(txn *sql.Tx) error
wait chan error
}
// Do queues a task to be run by a TransactionWriter. The function
// provided will be ran within a transaction as supplied by the
// database parameter. This will block until the task is finished.
func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
if w.todo == nil {
return errors.New("not initialised")
}
if !w.running.Load() {
go w.run()
}
task := transactionWriterTask{
db: db,
f: f,
wait: make(chan error, 1),
}
w.todo <- task
return <-task.wait
}
// run processes the tasks for a given transaction writer. Only one
// of these goroutines will run at a time. A transaction will be
// opened using the database object from the task and then this will
// be passed as a parameter to the task function.
func (w *TransactionWriter) run() {
if !w.running.CAS(false, true) {
return
}
defer w.running.Store(false)
for task := range w.todo {
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
return task.f(txn)
})
close(task.wait)
}
}

View File

@ -0,0 +1,113 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
type OutputSendToDeviceEventConsumer struct {
sendToDeviceConsumer *internal.ContinualConsumer
db storage.Database
serverName gomatrixserverlib.ServerName // our server name
notifier *sync.Notifier
}
// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer.
// Call Start() to begin consuming from the EDU server.
func NewOutputSendToDeviceEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database,
) *OutputSendToDeviceEventConsumer {
consumer := internal.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputSendToDeviceEventConsumer{
sendToDeviceConsumer: &consumer,
db: store,
serverName: cfg.Matrix.ServerName,
notifier: n,
}
consumer.ProcessMessage = s.onMessage
return s
}
// Start consuming from EDU api
func (s *OutputSendToDeviceEventConsumer) Start() error {
return s.sendToDeviceConsumer.Start()
}
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var output api.OutputSendToDeviceEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("EDU server output log: message parse failure")
return err
}
_, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return err
}
if domain != s.serverName {
return nil
}
util.GetLogger(context.TODO()).WithFields(log.Fields{
"sender": output.Sender,
"user_id": output.UserID,
"device_id": output.DeviceID,
"event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice()
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
)
if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message")
return err
}
s.notifier.OnNewSendToDevice(
output.UserID,
[]string{output.DeviceID},
types.NewStreamToken(0, streamPos),
)
return nil
}

View File

@ -55,10 +55,12 @@ type Database interface {
// sync response for the given user. Events returned will include any client // sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come // transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction // from when the device sent the event via an API that included a transaction
// ID. // ID. A response object must be provided for IncrementaSync to populate - it
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) // will not create one.
// CompleteSync returns a complete /sync API response for the given user. IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error) // CompleteSync returns a complete /sync API response for the given user. A response object
// must be provided for CompleteSync to populate - it will not create one.
CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions // updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
@ -104,4 +106,26 @@ type Database interface {
StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error)
// AddSendToDevice increases the EDU position in the cache and returns the stream position.
AddSendToDevice() types.StreamPosition
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists:
// - "events": a list of send-to-device events that should be included in the sync
// - "changes": a list of send-to-device events that should be updated in the database by
// CleanSendToDeviceUpdates
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
} }

View File

@ -0,0 +1,171 @@
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const sendToDeviceSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id;
-- Stores send-to-device messages.
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The ID that uniquely identifies this message.
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'),
-- The user ID to send the message to.
user_id TEXT NOT NULL,
-- The device ID to send the message to.
device_id TEXT NOT NULL,
-- The event content JSON.
content TEXT NOT NULL,
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
);
`
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
ORDER BY id DESC
`
const updateSentSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1
WHERE id = ANY($2)
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device WHERE id = ANY($1)
`
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt
}
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() {
var id types.SendToDeviceNID
var userID, deviceID, content string
var sentByToken *string
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event)
}
return events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) {
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
return
}

View File

@ -69,6 +69,10 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
sendToDevice, err := NewPostgresSendToDeviceTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -77,6 +81,8 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
Topology: topology, Topology: topology,
CurrentRoomState: currState, CurrentRoomState: currState,
BackwardExtremities: backwardExtremities, BackwardExtremities: backwardExtremities,
SendToDevice: sendToDevice,
SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return &d, nil return &d, nil

View File

@ -1,3 +1,17 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared package shared
import ( import (
@ -27,6 +41,8 @@ type Database struct {
Topology tables.Topology Topology tables.Topology
CurrentRoomState tables.CurrentRoomState CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
SendToDeviceWriter *internal.TransactionWriter
EDUCache *cache.EDUCache EDUCache *cache.EDUCache
} }
@ -89,6 +105,10 @@ func (d *Database) RemoveTypingUser(
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID)) return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
} }
func (d *Database) AddSendToDevice() types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
}
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn) d.EDUCache.SetTimeoutCallback(fn)
} }
@ -528,14 +548,14 @@ func (d *Database) addEDUDeltaToResponse(
} }
func (d *Database) IncrementalSync( func (d *Database) IncrementalSync(
ctx context.Context, ctx context.Context, res *types.Response,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.StreamingToken, fromPos, toPos types.StreamingToken,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos) nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos) res.NextBatch = nextBatchPos.String()
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
@ -568,12 +588,12 @@ func (d *Database) IncrementalSync(
// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. // to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
// nolint:nakedret
func (d *Database) getResponseWithPDUsForCompleteSync( func (d *Database) getResponseWithPDUsForCompleteSync(
ctx context.Context, ctx context.Context, res *types.Response,
userID string, userID string,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response,
toPos types.StreamingToken, toPos types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
@ -604,7 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
To: toPos.PDUPosition(), To: toPos.PDUPosition(),
} }
res = types.NewResponse(toPos) res.NextBatch = toPos.String()
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -662,14 +682,15 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
} }
succeeded = true succeeded = true
return res, toPos, joinedRoomIDs, err return //res, toPos, joinedRoomIDs, err
} }
func (d *Database) CompleteSync( func (d *Database) CompleteSync(
ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int, ctx context.Context, res *types.Response,
device authtypes.Device, numRecentEventsPerRoom int,
) (*types.Response, error) { ) (*types.Response, error) {
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, device.UserID, numRecentEventsPerRoom, ctx, res, device.UserID, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -1028,6 +1049,115 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil return s, nil
} }
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) AddSendToDeviceEvent(
ctx context.Context, txn *sql.Tx,
userID, deviceID, content string,
) error {
return d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, content,
)
}
func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) {
j, err := json.Marshal(event)
if err != nil {
return streamPos, err
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j),
)
})
if err != nil {
return streamPos, err
}
return streamPos, nil
}
func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,
token types.StreamingToken,
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}
// If there's nothing to do then stop here.
if len(events) == 0 {
return nil, nil, nil, nil
}
// Work out whether we need to update any of the database entries.
toReturn := []types.SendToDeviceEvent{}
toUpdate := []types.SendToDeviceNID{}
toDelete := []types.SendToDeviceNID{}
for _, event := range events {
if event.SentByToken == nil {
// If the event has no sent-by token yet then we haven't attempted to send
// it. Record the current requested sync token in the database.
toUpdate = append(toUpdate, event.ID)
toReturn = append(toReturn, event)
event.SentByToken = &token
} else if token.IsAfter(*event.SentByToken) {
// The event had a sync token, therefore we've sent it before. The current
// sync token is now after the stored one so we can assume that the client
// successfully completed the previous sync (it would re-request it otherwise)
// so we can remove the entry from the database.
toDelete = append(toDelete, event.ID)
} else {
// It looks like the sync is being re-requested, maybe it timed out or
// failed. Re-send any that should have been acknowledged by now.
toReturn = append(toReturn, event)
}
}
return toReturn, toUpdate, toDelete, nil
}
func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID,
token types.StreamingToken,
) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 {
return nil
}
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
}
return nil
})
return
}
// There may be some overlap where events in stateEvents are already in recentEvents, so filter // There may be some overlap where events in stateEvents are already in recentEvents, so filter
// them out so we don't include them twice in the /sync response. They should be in recentEvents // them out so we don't include them twice in the /sync response. They should be in recentEvents
// only, so clients get to the correct state once they have rolled forward. // only, so clients get to the correct state once they have rolled forward.

View File

@ -0,0 +1,172 @@
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const sendToDeviceSchema = `
-- Stores send-to-device messages.
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The ID that uniquely identifies this message.
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The user ID to send the message to.
user_id TEXT NOT NULL,
-- The device ID to send the message to.
device_id TEXT NOT NULL,
-- The event content JSON.
content TEXT NOT NULL,
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
);
`
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
ORDER BY id DESC
`
const updateSentSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1
WHERE id IN ($2)
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device WHERE id IN ($1)
`
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() {
var id types.SendToDeviceNID
var userID, deviceID, content string
var sentByToken *string
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event)
}
return events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1)
params := make([]interface{}, 1+len(nids))
for k, v := range nids {
params[k] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return
}

View File

@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
if err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -103,6 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) {
BackwardExtremities: bwExtrem, BackwardExtremities: bwExtrem,
CurrentRoomState: roomState, CurrentRoomState: roomState,
Topology: topology, Topology: topology,
SendToDevice: sendToDevice,
SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),
} }
return nil return nil

View File

@ -3,6 +3,7 @@ package storage_test
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -157,7 +158,8 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are at the penultimate event from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0),
) )
return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
}, },
WantTimeline: events[len(events)-1:], WantTimeline: events[len(events)-1:],
}, },
@ -169,8 +171,9 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are 10 events behind from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0), positions[len(positions)-11], types.StreamPosition(0),
) )
res := types.NewResponse()
// limit is set to 5 // limit is set to 5
return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
}, },
// want the last 5 events, NOT the last 10. // want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:], WantTimeline: events[len(events)-5:],
@ -180,8 +183,9 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "CompleteSync limited", Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
res := types.NewResponse()
// limit set to 5 // limit set to 5
return db.CompleteSync(ctx, testUserDeviceA, 5) return db.CompleteSync(ctx, res, testUserDeviceA, 5)
}, },
// want the last 5 events // want the last 5 events
WantTimeline: events[len(events)-5:], WantTimeline: events[len(events)-5:],
@ -193,7 +197,8 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "CompleteSync", Name: "CompleteSync",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
return db.CompleteSync(ctx, testUserDeviceA, len(events)+1) res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
}, },
WantTimeline: events, WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time) // We want no state at all as that field in /sync is the delta between the token (beginning of time)
@ -234,7 +239,8 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0),
) )
res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false) res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil { if err != nil {
t.Fatalf("failed to IncrementalSync with latest token") t.Fatalf("failed to IncrementalSync with latest token")
} }
@ -512,6 +518,89 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
} }
} }
func TestSendToDeviceBehaviour(t *testing.T) {
//t.Parallel()
db := MustCreateDatabase(t)
// At this point there should be no messages. We haven't sent anything
// yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
if err != nil {
return
}
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob",
Type: "m.type",
Content: json.RawMessage("{}"),
})
if err != nil {
t.Fatal(err)
}
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("fifth call should have no updates")
}
}
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) { if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))

View File

@ -1,3 +1,17 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables package tables
import ( import (
@ -94,3 +108,28 @@ type BackwardsExtremities interface {
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
} }
// SendToDevice tracks send-to-device messages which are sent to individual
// clients. Each message gets inserted into this table at the point that we
// receive it from the EDU server.
//
// We're supposed to try and do our best to deliver send-to-device messages
// once, but the only way that we can really guarantee that they have been
// delivered is if the client successfully requests the next sync as given
// in the next_batch. Each time the device syncs, we will request all of the
// updates that either haven't been sent yet, along with all updates that we
// *have* sent but we haven't confirmed to have been received yet. If it's the
// first time we're sending a given update then we update the table to say
// what the "since" parameter was when we tried to send it.
//
// When the client syncs again, if their "since" parameter is *later* than
// the recorded one, we drop the entry from the DB as it's "sent". If the
// sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
}

View File

@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent(
} }
} }
func (n *Notifier) OnNewSendToDevice(
userID string, deviceIDs []string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos)
}
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos
@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
// wakeupUserDevice will wake up the sync stream for a specific user device. Other // wakeupUserDevice will wake up the sync stream for a specific user device. Other
// device streams will be left alone. // device streams will be left alone.
// nolint:unused // nolint:unused
func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) { func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
for userID, deviceID := range userDevices { for _, deviceID := range deviceIDs {
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
} }

View File

@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
wake := "two" wake := "two"
n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter) n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter)
if result := <-awoken; result != wake { if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result) t.Fatalf("expected to wake %q, got %q", wake, result)

View File

@ -1,4 +1,6 @@
// Copyright 2017 Vector Creations Ltd // Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,6 +17,7 @@
package sync package sync
import ( import (
"context"
"net/http" "net/http"
"time" "time"
@ -54,17 +57,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
JSON: jsonerror.Unknown(err.Error()), JSON: jsonerror.Unknown(err.Error()),
} }
} }
logger := util.GetLogger(req.Context()).WithFields(log.Fields{ logger := util.GetLogger(req.Context()).WithFields(log.Fields{
"userID": device.UserID, "user_id": device.UserID,
"deviceID": device.ID, "device_id": device.ID,
"since": syncReq.since, "since": syncReq.since,
"timeout": syncReq.timeout, "timeout": syncReq.timeout,
"limit": syncReq.limit, "limit": syncReq.limit,
}) })
currPos := rp.notifier.CurrentPosition() currPos := rp.notifier.CurrentPosition()
if shouldReturnImmediately(syncReq) { if rp.shouldReturnImmediately(syncReq) {
syncData, err = rp.currentSyncForUser(*syncReq, currPos) syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil { if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed") logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -116,7 +120,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// response. This ensures that we don't waste the hard work // response. This ensures that we don't waste the hard work
// of calculating the sync only to get timed out before we // of calculating the sync only to get timed out before we
// can respond // can respond
syncData, err = rp.currentSyncForUser(*syncReq, currPos) syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil { if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed") logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -134,19 +137,59 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
// TODO: handle ignored users res = types.NewResponse()
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit) since := types.NewStreamToken(0, 0)
} else { if req.since != nil {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) since = *req.since
} }
// See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since)
if err != nil {
return nil, err
}
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
} else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
}
if err != nil { if err != nil {
return return
} }
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
if err != nil {
return
}
// Before we return the sync response, make sure that we take action on
// any send-to-device database updates or deletions that we need to do.
// Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
if err != nil {
return
}
}
if len(events) > 0 {
// Add the updates into the sync response.
for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
}
// Get the next_batch from the sync response and increase the
// EDU counter.
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.Positions[1]++
res.NextBatch = pos.String()
}
}
return return
} }
@ -238,6 +281,10 @@ func (rp *RequestPool) appendAccountData(
// shouldReturnImmediately returns whether the /sync request is an initial sync, // shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should // or timeout=0, or full_state=true, in any of the cases the request should
// return immediately. // return immediately.
func shouldReturnImmediately(syncReq *syncRequest) bool { func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState {
return true
}
waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)
return werr == nil && waiting
} }

View File

@ -78,7 +78,14 @@ func SetupSyncAPIComponent(
base.Cfg, base.KafkaConsumer, notifier, syncDB, base.Cfg, base.KafkaConsumer, notifier, syncDB,
) )
if err = typingConsumer.Start(); err != nil { if err = typingConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start typing server consumer") logrus.WithError(err).Panicf("failed to start typing consumer")
}
sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = sendToDeviceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
} }
routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg) routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg)

View File

@ -296,13 +296,14 @@ type Response struct {
Invite map[string]InviteResponse `json:"invite"` Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"` Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"` } `json:"rooms"`
ToDevice struct {
Events []gomatrixserverlib.SendToDeviceEvent `json:"events"`
} `json:"to_device"`
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(token StreamingToken) *Response { func NewResponse() *Response {
res := Response{ res := Response{}
NextBatch: token.String(),
}
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse) res.Rooms.Join = make(map[string]JoinResponse)
@ -315,6 +316,7 @@ func NewResponse(token StreamingToken) *Response {
// This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse. // This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse.
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
return &res return &res
} }
@ -326,7 +328,8 @@ func (r *Response) IsEmpty() bool {
len(r.Rooms.Invite) == 0 && len(r.Rooms.Invite) == 0 &&
len(r.Rooms.Leave) == 0 && len(r.Rooms.Leave) == 0 &&
len(r.AccountData.Events) == 0 && len(r.AccountData.Events) == 0 &&
len(r.Presence.Events) == 0 len(r.Presence.Events) == 0 &&
len(r.ToDevice.Events) == 0
} }
// JoinResponse represents a /sync response for a room which is under the 'join' key. // JoinResponse represents a /sync response for a room which is under the 'join' key.
@ -393,3 +396,13 @@ func NewLeaveResponse() *LeaveResponse {
res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0)
return &res return &res
} }
type SendToDeviceNID int
type SendToDeviceEvent struct {
gomatrixserverlib.SendToDeviceEvent
ID SendToDeviceNID
UserID string
DeviceID string
SentByToken *StreamingToken
}

View File

@ -39,3 +39,8 @@ Ignore invite in incremental sync
# Blacklisted because this test calls /r0/events which we don't implement # Blacklisted because this test calls /r0/events which we don't implement
New room members see their own join event New room members see their own join event
Existing members see new members' join events Existing members see new members' join events
# Blacklisted because the federation work for these hasn't been finished yet.
Can recv device messages over federation
Device messages over federation wake up /sync
Wildcard device messages over federation wake up /sync

View File

@ -289,3 +289,15 @@ Existing members see new members' join events
Inbound federation can receive events Inbound federation can receive events
Inbound federation can receive redacted events Inbound federation can receive redacted events
Can logout current device Can logout current device
Can send a message directly to a device using PUT /sendToDevice
Can recv a device message using /sync
Can recv device messages until they are acknowledged
Device messages with the same txn_id are deduplicated
Device messages wake up /sync
# TODO: separate PR for: Can recv device messages over federation
# TODO: separate PR for: Device messages over federation wake up /sync
Can send messages with a wildcard device id
Can send messages with a wildcard device id to two devices
Wildcard device messages wake up /sync
# TODO: separate PR for: Wildcard device messages over federation wake up /sync
Can send a to-device message to two users which both receive it using /sync