diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index e935805f..70b81bbc 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/mscs" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" @@ -148,6 +149,12 @@ func main() { base.PublicMediaAPIMux, ) + if len(base.Cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } + } + // Expose the matrix APIs directly rather than putting them under a /api path. go func() { base.SetupAndServeHTTP( diff --git a/federationsender/api/api.go b/federationsender/api/api.go index b0522516..a4d15f1f 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -48,6 +48,7 @@ type FederationSenderInternalAPI interface { // Query the server names of the joined hosts in a room. // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice // containing only the server names (without information for membership events). + // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom( ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, @@ -104,6 +105,7 @@ type PerformJoinRequest struct { } type PerformJoinResponse struct { + JoinedVia gomatrixserverlib.ServerName LastError *gomatrix.HTTPError } diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index a7484476..45f33ff7 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -105,6 +105,7 @@ func (r *FederationSenderInternalAPI) PerformJoin( } // We're all good. + response.JoinedVia = serverName return } diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go index 253400a2..8ba228d1 100644 --- a/federationsender/internal/query.go +++ b/federationsender/internal/query.go @@ -4,7 +4,6 @@ import ( "context" "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/gomatrixserverlib" ) // QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI @@ -13,17 +12,11 @@ func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHosts(ctx, request.RoomID) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) if err != nil { return } - - response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(joinedHosts)) - for _, host := range joinedHosts { - response.ServerNames = append(response.ServerNames, host.ServerName) - } - - // TODO: remove duplicates? + response.ServerNames = joinedHosts return } diff --git a/internal/config/config.go b/internal/config/config.go index 9d9e2414..b8b12d0c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -66,6 +66,8 @@ type Dendrite struct { SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + MSCs MSCs `yaml:"mscs"` + // The config for tracing the dendrite servers. Tracing struct { // Set to true to enable tracer hooks. If false, no tracing is set up. @@ -306,6 +308,7 @@ func (c *Dendrite) Defaults() { c.SyncAPI.Defaults() c.UserAPI.Defaults() c.AppServiceAPI.Defaults() + c.MSCs.Defaults() c.Wiring() } @@ -319,7 +322,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.EDUServer, &c.FederationAPI, &c.FederationSender, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SigningKeyServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, + &c.AppServiceAPI, &c.MSCs, } { c.Verify(configErrs, isMonolith) } @@ -337,6 +340,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived c.AppServiceAPI.Derived = &c.Derived diff --git a/internal/config/config_mscs.go b/internal/config/config_mscs.go new file mode 100644 index 00000000..776d0b64 --- /dev/null +++ b/internal/config/config_mscs.go @@ -0,0 +1,19 @@ +package config + +type MSCs struct { + Matrix *Global `yaml:"-"` + + // The MSCs to enable, currently only `msc2836` is supported. + MSCs []string `yaml:"mscs"` + + Database DatabaseOptions `yaml:"database"` +} + +func (c *MSCs) Defaults() { + c.Database.Defaults() + c.Database.ConnectionString = "file:mscs.db" +} + +func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) { + checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) +} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 00000000..223282a2 --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,74 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package hooks exposes places in Dendrite where custom code can be executed, useful for MSCs. +// Hooks can only be run in monolith mode. +package hooks + +import "sync" + +const ( + // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent + // It is run when a new event is persisted in the roomserver. + // Usage: + // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... }) + KindNewEventPersisted = "new_event_persisted" + // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent + // It is run before a new event is processed by the roomserver. This hook can be used + // to modify the event before it is persisted by adding data to `unsigned`. + // Usage: + // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { + // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + // _ = ev.SetUnsignedField("key", "val") + // }) + KindNewEventReceived = "new_event_received" +) + +var ( + hookMap = make(map[string][]func(interface{})) + hookMu = sync.Mutex{} + enabled = false +) + +// Enable all hooks. This may slow down the server slightly. Required for MSCs to work. +func Enable() { + enabled = true +} + +// Run any hooks +func Run(kind string, data interface{}) { + if !enabled { + return + } + cbs := callbacks(kind) + for _, cb := range cbs { + cb(data) + } +} + +// Attach a hook +func Attach(kind string, callback func(interface{})) { + if !enabled { + return + } + hookMu.Lock() + defer hookMu.Unlock() + hookMap[kind] = append(hookMap[kind], callback) +} + +func callbacks(kind string) []func(interface{}) { + hookMu.Lock() + defer hookMu.Unlock() + return hookMap[kind] +} diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go new file mode 100644 index 00000000..865bc311 --- /dev/null +++ b/internal/mscs/msc2836/msc2836.go @@ -0,0 +1,530 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package msc2836 'Threading' implements https://github.com/matrix-org/matrix-doc/pull/2836 +package msc2836 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + fs "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +const ( + constRelType = "m.reference" + constRoomIDKey = "relationship_room_id" + constRoomServers = "relationship_servers" +) + +type EventRelationshipRequest struct { + EventID string `json:"event_id"` + MaxDepth int `json:"max_depth"` + MaxBreadth int `json:"max_breadth"` + Limit int `json:"limit"` + DepthFirst bool `json:"depth_first"` + RecentFirst bool `json:"recent_first"` + IncludeParent bool `json:"include_parent"` + IncludeChildren bool `json:"include_children"` + Direction string `json:"direction"` + Batch string `json:"batch"` + AutoJoin bool `json:"auto_join"` +} + +func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { + var relation EventRelationshipRequest + relation.Defaults() + if err := json.NewDecoder(body).Decode(&relation); err != nil { + return nil, err + } + return &relation, nil +} + +func (r *EventRelationshipRequest) Defaults() { + r.Limit = 100 + r.MaxBreadth = 10 + r.MaxDepth = 3 + r.DepthFirst = false + r.RecentFirst = true + r.IncludeParent = false + r.IncludeChildren = false + r.Direction = "down" +} + +type EventRelationshipResponse struct { + Events []gomatrixserverlib.ClientEvent `json:"events"` + NextBatch string `json:"next_batch"` + Limited bool `json:"limited"` +} + +// Enable this MSC +// nolint:gocyclo +func Enable( + base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, + userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, +) error { + db, err := NewDatabase(&base.Cfg.MSCs.Database) + if err != nil { + return fmt.Errorf("Cannot enable MSC2836: %w", err) + } + hooks.Enable() + hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { + he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + hookErr := db.StoreRelation(context.Background(), he) + if hookErr != nil { + util.GetLogger(context.Background()).WithError(hookErr).Error( + "failed to StoreRelation", + ) + } + }) + hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { + he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + ctx := context.Background() + // we only inject metadata for events our server sends + userID := he.Sender() + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return + } + if domain != base.Cfg.Global.ServerName { + return + } + // if this event has an m.relationship, add on the room_id and servers to unsigned + parent, child, relType := parentChildEventIDs(he) + if parent == "" || child == "" || relType == "" { + return + } + event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID) + if !joinedToRoom { + return + } + err = he.SetUnsignedField(constRoomIDKey, event.RoomID()) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") + return + } + + var servers []gomatrixserverlib.ServerName + if fsAPI != nil { + var res fs.QueryJoinedHostServerNamesInRoomResponse + err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: event.RoomID(), + }, &res) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom") + return + } + servers = res.ServerNames + } else { + servers = []gomatrixserverlib.ServerName{ + base.Cfg.Global.ServerName, + } + } + err = he.SetUnsignedField(constRoomServers, servers) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") + return + } + }) + + base.PublicClientAPIMux.Handle("/unstable/event_relationships", + httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)), + ).Methods(http.MethodPost, http.MethodOptions) + + base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( + "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), base.Cfg.Global.ServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + return federatedEventRelationship(req.Context(), fedReq, db, rsAPI) + }, + )).Methods(http.MethodPost, http.MethodOptions) + return nil +} + +type reqCtx struct { + ctx context.Context + rsAPI roomserver.RoomserverInternalAPI + db Database + req *EventRelationshipRequest + userID string + isFederatedRequest bool +} + +func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { + return func(req *http.Request, device *userapi.Device) util.JSONResponse { + relation, err := NewEventRelationshipRequest(req.Body) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + rc := reqCtx{ + ctx: req.Context(), + req: relation, + userID: device.UserID, + rsAPI: rsAPI, + isFederatedRequest: false, + db: db, + } + res, resErr := rc.process() + if resErr != nil { + return *resErr + } + + return util.JSONResponse{ + Code: 200, + JSON: res, + } + } +} + +func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse { + relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + rc := reqCtx{ + ctx: ctx, + req: relation, + userID: "", + rsAPI: rsAPI, + isFederatedRequest: true, + db: db, + } + res, resErr := rc.process() + if resErr != nil { + return *resErr + } + + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + +func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { + var res EventRelationshipResponse + var returnEvents []*gomatrixserverlib.HeaderedEvent + // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. + // We should have the event being referenced so don't give any claimed room ID / servers + event := rc.getEventIfVisible(rc.req.EventID, "", nil) + if event == nil { + return nil, &util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + } + } + + // Retrieve the event. Add it to response array. + returnEvents = append(returnEvents, event) + + if rc.req.IncludeParent { + if parentEvent := rc.includeParent(event); parentEvent != nil { + returnEvents = append(returnEvents, parentEvent) + } + } + + if rc.req.IncludeChildren { + remaining := rc.req.Limit - len(returnEvents) + if remaining > 0 { + children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst) + if resErr != nil { + return nil, resErr + } + returnEvents = append(returnEvents, children...) + } + } + + remaining := rc.req.Limit - len(returnEvents) + var walkLimited bool + if remaining > 0 { + included := make(map[string]bool, len(returnEvents)) + for _, ev := range returnEvents { + included[ev.EventID()] = true + } + var events []*gomatrixserverlib.HeaderedEvent + events, walkLimited = walkThread( + rc.ctx, rc.db, rc, included, remaining, + ) + returnEvents = append(returnEvents, events...) + } + res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) + for i, ev := range returnEvents { + res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll) + } + res.Limited = remaining == 0 || walkLimited + return &res, nil +} + +// If include_parent: true and there is a valid m.relationship field in the event, +// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. +func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { + parentID, _, _ := parentChildEventIDs(event) + if parentID == "" { + return nil + } + claimedRoomID, claimedServers := roomIDAndServers(event) + return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers) +} + +// If include_children: true, lookup all events which have event_id as an m.relationship +// Apply history visibility checks to all these events and add the ones which pass into the response array, +// honouring the recent_first flag and the limit. +func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { + children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var childEvents []*gomatrixserverlib.HeaderedEvent + for _, child := range children { + // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers. + childEvent := rc.getEventIfVisible(child.EventID, "", nil) + if childEvent != nil { + childEvents = append(childEvents, childEvent) + } + } + if len(childEvents) > limit { + return childEvents[:limit], nil + } + return childEvents, nil +} + +// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, +// honouring the limit, max_depth and max_breadth values according to the following rules +// nolint: unparam +func walkThread( + ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, +) ([]*gomatrixserverlib.HeaderedEvent, bool) { + if rc.req.Direction != "down" { + util.GetLogger(ctx).Error("not implemented: direction=up") + return nil, false + } + var result []*gomatrixserverlib.HeaderedEvent + eventWalker := walker{ + ctx: ctx, + req: rc.req, + db: db, + fn: func(wi *walkInfo) bool { + // If already processed event, skip. + if included[wi.EventID] { + return false + } + + // If the response array is >= limit, stop. + if len(result) >= limit { + return true + } + + // Process the event. + // TODO: Include edge information: room ID and servers + event := rc.getEventIfVisible(wi.EventID, "", nil) + if event != nil { + result = append(result, event) + } + included[wi.EventID] = true + return false + }, + } + limited, err := eventWalker.WalkFrom(rc.req.EventID) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID) + } + return result, limited +} + +func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent { + event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID) + if event != nil && joinedToRoom { + return event + } + // either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled + if !rc.req.AutoJoin { + return nil + } + // if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says + if rc.isFederatedRequest { + return nil + } + roomID := claimedRoomID + var servers []gomatrixserverlib.ServerName + if event != nil { + roomID = event.RoomID() + } + for _, s := range claimedServers { + servers = append(servers, gomatrixserverlib.ServerName(s)) + } + var joinRes roomserver.PerformJoinResponse + rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{ + UserID: rc.userID, + Content: map[string]interface{}{}, + RoomIDOrAlias: roomID, + ServerNames: servers, + }, &joinRes) + if joinRes.Error != nil { + util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room") + return nil + } + if event != nil { + return event + } + // TODO: hit /event_relationships on the server we joined via + util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO") + return nil +} + +func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) { + var queryEventsRes roomserver.QueryEventsByIDResponse + err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ + EventIDs: []string{eventID}, + }, &queryEventsRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") + return nil, false + } + if len(queryEventsRes.Events) == 0 { + util.GetLogger(ctx).Infof("event does not exist") + return nil, false // event does not exist + } + event := queryEventsRes.Events[0] + + // Allow events if the member is in the room + // TODO: This does not honour history_visibility + // TODO: This does not honour m.room.create content + var queryMembershipRes roomserver.QueryMembershipForUserResponse + err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ + RoomID: event.RoomID(), + UserID: userID, + }, &queryMembershipRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") + return nil, false + } + return event, queryMembershipRes.IsInRoom +} + +type walkInfo struct { + eventInfo + SiblingNumber int + Depth int +} + +type walker struct { + ctx context.Context + req *EventRelationshipRequest + db Database + fn func(wi *walkInfo) bool // callback invoked for each event walked, return true to terminate the walk +} + +// WalkFrom the event ID given +func (w *walker) WalkFrom(eventID string) (limited bool, err error) { + children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + var next *walkInfo + toWalk := w.addChildren(nil, children, 1) + next, toWalk = w.nextChild(toWalk) + for next != nil { + stop := w.fn(next) + if stop { + return true, nil + } + // find the children's children + children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + toWalk = w.addChildren(toWalk, children, next.Depth+1) + next, toWalk = w.nextChild(toWalk) + } + + return false, nil +} + +// addChildren adds an event's children to the to walk data structure +func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChildren int) []walkInfo { + // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip. + if len(children) > w.req.MaxBreadth { + children = children[:w.req.MaxBreadth] + } + // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip. + if depthOfChildren > w.req.MaxDepth { + return toWalk + } + + if w.req.DepthFirst { + // the slice is a stack so push them in reverse order so we pop them in the correct order + // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3 + for i := len(children) - 1; i >= 0; i-- { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } else { + // the slice is a queue so push them in normal order to we dequeue them in the correct order + // e.g [1,2,3] => 1, [2, 3] => 2 , [3] => 3, [] + for i := range children { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } + return toWalk +} + +func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { + if len(toWalk) == 0 { + return nil, nil + } + var child walkInfo + if w.req.DepthFirst { + // toWalk is a stack so pop the child off + child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1] + return &child, toWalk + } + // toWalk is a queue so shift the child off + child, toWalk = toWalk[0], toWalk[1:] + return &child, toWalk +} diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go new file mode 100644 index 00000000..cbf8b726 --- /dev/null +++ b/internal/mscs/msc2836/msc2836_test.go @@ -0,0 +1,574 @@ +package msc2836_test + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/mscs/msc2836" + "github.com/matrix-org/dendrite/internal/setup" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + client = &http.Client{ + Timeout: 10 * time.Second, + } +) + +// Basic sanity check of MSC2836 logic. Injects a thread that looks like: +// A +// | +// B +// / \ +// C D +// /|\ +// E F G +// | +// H +// And makes sure POST /event_relationships works with various parameters +func TestMSC2836(t *testing.T) { + alice := "@alice:localhost" + bob := "@bob:localhost" + charlie := "@charlie:localhost" + roomIDA := "!alice:localhost" + roomIDB := "!bob:localhost" + roomIDC := "!charlie:localhost" + // give access tokens to all three users + nopUserAPI := &testUserAPI{ + accessTokens: make(map[string]userapi.Device), + } + nopUserAPI.accessTokens["alice"] = userapi.Device{ + AccessToken: "alice", + DisplayName: "Alice", + UserID: alice, + } + nopUserAPI.accessTokens["bob"] = userapi.Device{ + AccessToken: "bob", + DisplayName: "Bob", + UserID: bob, + } + nopUserAPI.accessTokens["charlie"] = userapi.Device{ + AccessToken: "charlie", + DisplayName: "Charles", + UserID: charlie, + } + eventA := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[A] Do you know shelties?", + }, + }) + eventB := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[B] I <3 shelties", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventA.EventID(), + }, + }, + }) + eventC := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[C] like so much", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventB.EventID(), + }, + }, + }) + eventD := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[D] but what are shelties???", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventB.EventID(), + }, + }, + }) + eventE := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[E] seriously???", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventF := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDC, + Sender: charlie, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[F] omg how do you not know what shelties are", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventG := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[G] looked it up, it's a sheltered person?", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventH := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[H] it's a dog!!!!!", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventE.EventID(), + }, + }, + }) + // make everyone joined to each other's rooms + nopRsAPI := &testRoomserverAPI{ + userToJoinedRooms: map[string][]string{ + alice: []string{roomIDA, roomIDB, roomIDC}, + bob: []string{roomIDA, roomIDB, roomIDC}, + charlie: []string{roomIDA, roomIDB, roomIDC}, + }, + events: map[string]*gomatrixserverlib.HeaderedEvent{ + eventA.EventID(): eventA, + eventB.EventID(): eventB, + eventC.EventID(): eventC, + eventD.EventID(): eventD, + eventE.EventID(): eventE, + eventF.EventID(): eventF, + eventG.EventID(): eventG, + eventH.EventID(): eventH, + }, + } + router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{ + eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH, + }) + cancel := runServer(t, router) + defer cancel() + + t.Run("returns 403 on invalid event IDs", func(t *testing.T) { + _ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{ + "event_id": "$invalid", + })) + }) + t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) { + nopUserAPI.accessTokens["frank"] = userapi.Device{ + AccessToken: "frank", + DisplayName: "Frank Not In Room", + UserID: "@frank:localhost", + } + _ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) + }) + t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { + nopUserAPI.accessTokens["frank2"] = userapi.Device{ + AccessToken: "frank2", + DisplayName: "Frank2 Not In Room", + UserID: "@frank2:localhost", + } + // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB + nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} + body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) + assertContains(t, body, []string{eventB.EventID()}) + }) + t.Run("returns the parent if include_parent is true", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "include_parent": true, + "limit": 2, + })) + assertContains(t, body, []string{eventB.EventID(), eventA.EventID()}) + }) + t.Run("returns the children in the right order if include_children is true", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": true, + "limit": 4, + })) + assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": false, + "limit": 4, + })) + assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) + t.Run("walks the graph depth first", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": true, + "limit": 6, + })) + // Oldest first so: + // A + // | + // B1 + // / \ + // C2 D3 + // /| \ + // 4E 6F G + // | + // 5H + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": true, + "limit": 6, + })) + // Recent first so: + // A + // | + // B1 + // / \ + // C D2 + // /| \ + // E5 F4 G3 + // | + // H6 + assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()}) + }) + t.Run("walks the graph breadth first", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 6, + })) + // Oldest first so: + // A + // | + // B1 + // / \ + // C2 D3 + // /| \ + // E4 F5 G6 + // | + // H + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": false, + "limit": 6, + })) + // Recent first so: + // A + // | + // B1 + // / \ + // C3 D2 + // /| \ + // E6 F5 G4 + // | + // H + assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) + }) + t.Run("caps via max_breadth", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_breadth": 2, + "limit": 10, + })) + // Event G gets omitted because of max_breadth + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()}) + }) + t.Run("caps via max_depth", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_depth": 2, + "limit": 10, + })) + // Event H gets omitted because of max_depth + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) + t.Run("terminates when reaching the limit", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 4, + })) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()}) + }) + t.Run("returns all events with a high enough limit", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 400, + })) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) + }) +} + +// TODO: TestMSC2836TerminatesLoops (short and long) +// TODO: TestMSC2836UnknownEventsSkipped +// TODO: TestMSC2836SkipEventIfNotInRoom + +func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest { + t.Helper() + b, err := json.Marshal(jsonBody) + if err != nil { + t.Fatalf("Failed to marshal request: %s", err) + } + r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b)) + if err != nil { + t.Fatalf("Failed to NewEventRelationshipRequest: %s", err) + } + return r +} + +func runServer(t *testing.T, router *mux.Router) func() { + t.Helper() + externalServ := &http.Server{ + Addr: string(":8009"), + WriteTimeout: 60 * time.Second, + Handler: router, + } + go func() { + externalServ.ListenAndServe() + }() + // wait to listen on the port + time.Sleep(500 * time.Millisecond) + return func() { + externalServ.Shutdown(context.TODO()) + } +} + +func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse { + t.Helper() + var r msc2836.EventRelationshipRequest + r.Defaults() + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %s", err) + } + httpReq, err := http.NewRequest( + "POST", "http://localhost:8009/_matrix/client/unstable/event_relationships", + bytes.NewBuffer(data), + ) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + if err != nil { + t.Fatalf("failed to prepare request: %s", err) + } + res, err := client.Do(httpReq) + if err != nil { + t.Fatalf("failed to do request: %s", err) + } + if res.StatusCode != expectCode { + body, _ := ioutil.ReadAll(res.Body) + t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) + } + if res.StatusCode == 200 { + var result msc2836.EventRelationshipResponse + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err) + } + return &result + } + return nil +} + +func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wantEventIDs []string) { + t.Helper() + gotEventIDs := make([]string, len(result.Events)) + for i, ev := range result.Events { + gotEventIDs[i] = ev.EventID + } + if len(gotEventIDs) != len(wantEventIDs) { + t.Fatalf("length mismatch: got %v want %v", gotEventIDs, wantEventIDs) + } + for i := range gotEventIDs { + if gotEventIDs[i] != wantEventIDs[i] { + t.Errorf("wrong item in position %d - got %s want %s", i, gotEventIDs[i], wantEventIDs[i]) + } + } +} + +type testUserAPI struct { + accessTokens map[string]userapi.Device +} + +func (u *testUserAPI) InputAccountData(ctx context.Context, req *userapi.InputAccountDataRequest, res *userapi.InputAccountDataResponse) error { + return nil +} +func (u *testUserAPI) PerformAccountCreation(ctx context.Context, req *userapi.PerformAccountCreationRequest, res *userapi.PerformAccountCreationResponse) error { + return nil +} +func (u *testUserAPI) PerformPasswordUpdate(ctx context.Context, req *userapi.PerformPasswordUpdateRequest, res *userapi.PerformPasswordUpdateResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceCreation(ctx context.Context, req *userapi.PerformDeviceCreationRequest, res *userapi.PerformDeviceCreationResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.PerformDeviceDeletionRequest, res *userapi.PerformDeviceDeletionResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error { + return nil +} +func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { + return nil +} +func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { + return nil +} +func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { + dev, ok := u.accessTokens[req.AccessToken] + if !ok { + res.Err = fmt.Errorf("unknown token") + return nil + } + res.Device = &dev + return nil +} +func (u *testUserAPI) QueryDevices(ctx context.Context, req *userapi.QueryDevicesRequest, res *userapi.QueryDevicesResponse) error { + return nil +} +func (u *testUserAPI) QueryAccountData(ctx context.Context, req *userapi.QueryAccountDataRequest, res *userapi.QueryAccountDataResponse) error { + return nil +} +func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDeviceInfosRequest, res *userapi.QueryDeviceInfosResponse) error { + return nil +} +func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { + return nil +} + +type testRoomserverAPI struct { + // use a trace API as it implements method stubs so we don't need to have them here. + // We'll override the functions we care about. + roomserver.RoomserverInternalAPITrace + userToJoinedRooms map[string][]string + events map[string]*gomatrixserverlib.HeaderedEvent +} + +func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { + for _, eventID := range req.EventIDs { + ev := r.events[eventID] + if ev != nil { + res.Events = append(res.Events, ev) + } + } + return nil +} + +func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { + rooms := r.userToJoinedRooms[req.UserID] + for _, roomID := range rooms { + if roomID == req.RoomID { + res.IsInRoom = true + res.HasBeenInRoom = true + res.Membership = "join" + break + } + } + return nil +} + +func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { + t.Helper() + cfg := &config.Dendrite{} + cfg.Defaults() + cfg.Global.ServerName = "localhost" + cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" + cfg.MSCs.MSCs = []string{"msc2836"} + base := &setup.BaseDendrite{ + Cfg: cfg, + PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), + PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), + } + + err := msc2836.Enable(base, rsAPI, nil, userAPI, nil) + if err != nil { + t.Fatalf("failed to enable MSC2836: %s", err) + } + for _, ev := range events { + hooks.Run(hooks.KindNewEventPersisted, ev) + } + return base.PublicClientAPIMux +} + +type fledglingEvent struct { + Type string + StateKey *string + Content interface{} + Sender string + RoomID string +} + +func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { + t.Helper() + roomVer := gomatrixserverlib.RoomVersionV6 + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Depth: 999, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + } + err := eb.SetContent(ev.Content) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) + } + // make sure the origin_server_ts changes so we can test recency + time.Sleep(1 * time.Millisecond) + signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + h := signedEvent.Headered(roomVer) + return h +} diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go new file mode 100644 index 00000000..f524165f --- /dev/null +++ b/internal/mscs/msc2836/storage.go @@ -0,0 +1,226 @@ +package msc2836 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +type eventInfo struct { + EventID string + OriginServerTS gomatrixserverlib.Timestamp + RoomID string + Servers []string +} + +type Database interface { + // StoreRelation stores the parent->child and child->parent relationship for later querying. + // Also stores the event metadata e.g timestamp + StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the + // provided `relType`. The returned slice is sorted by origin_server_ts according to whether + // `recentFirst` is true or false. + ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) +} + +type DB struct { + db *sql.DB + writer sqlutil.Writer + insertEdgeStmt *sql.Stmt + insertNodeStmt *sql.Stmt + selectChildrenForParentOldestFirstStmt *sql.Stmt + selectChildrenForParentRecentFirstStmt *sql.Stmt +} + +// NewDatabase loads the database for msc2836 +func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + if dbOpts.ConnectionString.IsPostgres() { + return newPostgresDatabase(dbOpts) + } + return newSQLiteDatabase(dbOpts) +} + +func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewDummyWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_edges ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + rel_type TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + parent_servers TEXT NOT NULL, + CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type) + ); + + CREATE TABLE IF NOT EXISTS msc2836_nodes ( + event_id TEXT PRIMARY KEY NOT NULL, + origin_server_ts BIGINT NOT NULL, + room_id TEXT NOT NULL + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + if d.insertNodeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { + return nil, err + } + return &d, err +} + +func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_edges ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + rel_type TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + parent_servers TEXT NOT NULL, + UNIQUE (parent_event_id, child_event_id, rel_type) + ); + + CREATE TABLE IF NOT EXISTS msc2836_nodes ( + event_id TEXT PRIMARY KEY NOT NULL, + origin_server_ts BIGINT NOT NULL, + room_id TEXT NOT NULL + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING + `); err != nil { + return nil, err + } + if d.insertNodeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { + return nil, err + } + return &d, nil +} + +func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + parent, child, relType := parentChildEventIDs(ev) + if parent == "" || child == "" { + return nil + } + relationRoomID, relationServers := roomIDAndServers(ev) + relationServersJSON, err := json.Marshal(relationServers) + if err != nil { + return err + } + return p.writer.Do(p.db, nil, func(txn *sql.Tx) error { + _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON)) + if err != nil { + return err + } + _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID()) + return err + }) +} + +func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { + var rows *sql.Rows + var err error + if recentFirst { + rows, err = p.selectChildrenForParentRecentFirstStmt.QueryContext(ctx, eventID, relType) + } else { + rows, err = p.selectChildrenForParentOldestFirstStmt.QueryContext(ctx, eventID, relType) + } + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var children []eventInfo + for rows.Next() { + var evInfo eventInfo + if err := rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil { + return nil, err + } + children = append(children, evInfo) + } + return children, nil +} + +func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { + if ev == nil { + return + } + body := struct { + Relationship struct { + RelType string `json:"rel_type"` + EventID string `json:"event_id"` + } `json:"m.relationship"` + }{} + if err := json.Unmarshal(ev.Content(), &body); err != nil { + return + } + if body.Relationship.EventID == "" || body.Relationship.RelType == "" { + return + } + return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType +} + +func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) { + servers = []string{} + if ev == nil { + return + } + body := struct { + RoomID string `json:"relationship_room_id"` + Servers []string `json:"relationship_servers"` + }{} + if err := json.Unmarshal(ev.Unsigned(), &body); err != nil { + return + } + return body.RoomID, body.Servers +} diff --git a/internal/mscs/mscs.go b/internal/mscs/mscs.go new file mode 100644 index 00000000..0a896ab0 --- /dev/null +++ b/internal/mscs/mscs.go @@ -0,0 +1,42 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mscs implements Matrix Spec Changes from https://github.com/matrix-org/matrix-doc +package mscs + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/mscs/msc2836" + "github.com/matrix-org/dendrite/internal/setup" +) + +// Enable MSCs - returns an error on unknown MSCs +func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { + for _, msc := range base.Cfg.MSCs.MSCs { + if err := EnableMSC(base, monolith, msc); err != nil { + return err + } + } + return nil +} + +func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) error { + switch msc { + case "msc2836": + return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing) + default: + return fmt.Errorf("EnableMSC: unknown msc '%s'", msc) + } +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 29dbd25c..ec561f11 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -83,7 +83,8 @@ type PerformJoinRequest struct { type PerformJoinResponse struct { // The room ID, populated on success. - RoomID string `json:"room_id"` + RoomID string `json:"room_id"` + JoinedVia gomatrixserverlib.ServerName // If non-nil, the join request failed. Contains more information why it failed. Error *PerformError } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index d7257539..79dc2fe1 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -22,6 +22,7 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" @@ -61,7 +62,11 @@ func (w *inputWorker) start() { for { select { case task := <-w.input: + hooks.Run(hooks.KindNewEventReceived, &task.event.Event) _, task.err = w.r.processRoomEvent(task.ctx, task.event) + if task.err == nil { + hooks.Run(hooks.KindNewEventPersisted, &task.event.Event) + } task.wg.Done() case <-time.After(time.Second * 5): return diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 56ae6d0b..f3745a7f 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -47,7 +47,7 @@ func (r *Joiner) PerformJoin( req *api.PerformJoinRequest, res *api.PerformJoinResponse, ) { - roomID, err := r.performJoin(ctx, req) + roomID, joinedVia, err := r.performJoin(ctx, req) if err != nil { perr, ok := err.(*api.PerformError) if ok { @@ -59,21 +59,22 @@ func (r *Joiner) PerformJoin( } } res.RoomID = roomID + res.JoinedVia = joinedVia } func (r *Joiner) performJoin( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } if domain != r.Cfg.Matrix.ServerName { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), } @@ -84,7 +85,7 @@ func (r *Joiner) performJoin( if strings.HasPrefix(req.RoomIDOrAlias, "#") { return r.performJoinRoomByAlias(ctx, req) } - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), } @@ -93,11 +94,11 @@ func (r *Joiner) performJoin( func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) if err != nil { - return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) + return "", "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) } req.ServerNames = append(req.ServerNames, domain) @@ -115,7 +116,7 @@ func (r *Joiner) performJoinRoomByAlias( err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) if err != nil { logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) - return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) + return "", "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) } roomID = dirRes.RoomID req.ServerNames = append(req.ServerNames, dirRes.ServerNames...) @@ -123,13 +124,13 @@ func (r *Joiner) performJoinRoomByAlias( // Otherwise, look up if we know this room alias locally. roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) if err != nil { - return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) + return "", "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) } } // If the room ID is empty then we failed to look up the alias. if roomID == "" { - return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) + return "", "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) } // If we do, then pluck out the room ID and continue the join. @@ -142,11 +143,11 @@ func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { // Get the domain part of the room ID. _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) if err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err), } @@ -169,7 +170,7 @@ func (r *Joiner) performJoinRoomByID( Redacts: "", } if err = eb.SetUnsigned(struct{}{}); err != nil { - return "", fmt.Errorf("eb.SetUnsigned: %w", err) + return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) } // It is possible for the request to include some "content" for the @@ -180,7 +181,7 @@ func (r *Joiner) performJoinRoomByID( } req.Content["membership"] = gomatrixserverlib.Join if err = eb.SetContent(req.Content); err != nil { - return "", fmt.Errorf("eb.SetContent: %w", err) + return "", "", fmt.Errorf("eb.SetContent: %w", err) } // Force a federated join if we aren't in the room and we've been @@ -194,7 +195,7 @@ func (r *Joiner) performJoinRoomByID( if err == nil && isInvitePending { _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) if ierr != nil { - return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } // If we were invited by someone from another server then we can @@ -206,8 +207,10 @@ func (r *Joiner) performJoinRoomByID( } // If we should do a forced federated join then do that. + var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) + return req.RoomIDOrAlias, joinedVia, err } // Try to construct an actual join event from the template. @@ -249,7 +252,7 @@ func (r *Joiner) performJoinRoomByID( inputRes := api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorNotAllowed, Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), } @@ -265,7 +268,7 @@ func (r *Joiner) performJoinRoomByID( // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. if len(req.ServerNames) == 0 { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorNoRoom, Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias), } @@ -273,24 +276,25 @@ func (r *Joiner) performJoinRoomByID( } // Perform a federated room join. - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) + return req.RoomIDOrAlias, joinedVia, err default: // Something else went wrong. - return "", fmt.Errorf("Error joining local room: %q", err) + return "", "", fmt.Errorf("Error joining local room: %q", err) } // By this point, if req.RoomIDOrAlias contained an alias, then // it will have been overwritten with a room ID by performJoinRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return req.RoomIDOrAlias, nil + return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil } func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, -) error { +) (gomatrixserverlib.ServerName, error) { // Try joining by all of the supplied server names. fedReq := fsAPI.PerformJoinRequest{ RoomID: req.RoomIDOrAlias, // the room ID to try and join @@ -301,13 +305,13 @@ func (r *Joiner) performFederatedJoinRoomByID( fedRes := fsAPI.PerformJoinResponse{} r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { - return &api.PerformError{ + return "", &api.PerformError{ Code: api.PerformErrRemote, Msg: fedRes.LastError.Message, RemoteCode: fedRes.LastError.Code, } } - return nil + return fedRes.JoinedVia, nil } func buildEvent(