Move currentstateserver API to roomserver (#1387)
* Move currentstateserver API to roomserver Stub out DB functions for now, nothing uses the roomserver version yet. * Allow it to startup * Implement some current-state-server storage interface functions * Add missing packagemain
parent
6150de6cb3
commit
b20386123e
|
@ -23,17 +23,25 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/storage"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ServerACLDatabase interface {
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ServerACLs struct {
|
type ServerACLs struct {
|
||||||
acls map[string]*serverACL // room ID -> ACL
|
acls map[string]*serverACL // room ID -> ACL
|
||||||
aclsMutex sync.RWMutex // protects the above
|
aclsMutex sync.RWMutex // protects the above
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServerACLs(db storage.Database) *ServerACLs {
|
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
acls := &ServerACLs{
|
acls := &ServerACLs{
|
||||||
acls: make(map[string]*serverACL),
|
acls: make(map[string]*serverACL),
|
||||||
|
|
|
@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias(
|
||||||
return fmt.Errorf("not implemented")
|
return fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testStateAPI struct {
|
type testStateAPI struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,164 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package acls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerACLDatabase interface {
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerACLs struct {
|
||||||
|
acls map[string]*serverACL // room ID -> ACL
|
||||||
|
aclsMutex sync.RWMutex // protects the above
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
|
ctx := context.TODO()
|
||||||
|
acls := &ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
// Look up all of the rooms that the current state server knows about.
|
||||||
|
rooms, err := db.GetKnownRooms(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to get known rooms")
|
||||||
|
}
|
||||||
|
// For each room, let's see if we have a server ACL state event. If we
|
||||||
|
// do then we'll process it into memory so that we have the regexes to
|
||||||
|
// hand.
|
||||||
|
for _, room := range rooms {
|
||||||
|
state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "")
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if state != nil {
|
||||||
|
acls.OnServerACLUpdate(&state.Event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return acls
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerACL struct {
|
||||||
|
Allowed []string `json:"allow"`
|
||||||
|
Denied []string `json:"deny"`
|
||||||
|
AllowIPLiterals bool `json:"allow_ip_literals"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverACL struct {
|
||||||
|
ServerACL
|
||||||
|
allowedRegexes []*regexp.Regexp
|
||||||
|
deniedRegexes []*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileACLRegex(orig string) (*regexp.Regexp, error) {
|
||||||
|
escaped := regexp.QuoteMeta(orig)
|
||||||
|
escaped = strings.Replace(escaped, "\\?", ".", -1)
|
||||||
|
escaped = strings.Replace(escaped, "\\*", ".*", -1)
|
||||||
|
return regexp.Compile(escaped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) {
|
||||||
|
acls := &serverACL{}
|
||||||
|
if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// The spec calls only for * (zero or more chars) and ? (exactly one char)
|
||||||
|
// to be supported as wildcard components, so we will escape all of the regex
|
||||||
|
// special characters and then replace * and ? with their regex counterparts.
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
||||||
|
for _, orig := range acls.Allowed {
|
||||||
|
if expr, err := compileACLRegex(orig); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
||||||
|
} else {
|
||||||
|
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, orig := range acls.Denied {
|
||||||
|
if expr, err := compileACLRegex(orig); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
||||||
|
} else {
|
||||||
|
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"allow_ip_literals": acls.AllowIPLiterals,
|
||||||
|
"num_allowed": len(acls.allowedRegexes),
|
||||||
|
"num_denied": len(acls.deniedRegexes),
|
||||||
|
}).Debugf("Updating server ACLs for %q", state.RoomID())
|
||||||
|
s.aclsMutex.Lock()
|
||||||
|
defer s.aclsMutex.Unlock()
|
||||||
|
s.acls[state.RoomID()] = acls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool {
|
||||||
|
s.aclsMutex.RLock()
|
||||||
|
// First of all check if we have an ACL for this room. If we don't then
|
||||||
|
// no servers are banned from the room.
|
||||||
|
acls, ok := s.acls[roomID]
|
||||||
|
if !ok {
|
||||||
|
s.aclsMutex.RUnlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.aclsMutex.RUnlock()
|
||||||
|
// Split the host and port apart. This is because the spec calls on us to
|
||||||
|
// validate the hostname only in cases where the port is also present.
|
||||||
|
if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil {
|
||||||
|
serverName = gomatrixserverlib.ServerName(serverNameOnly)
|
||||||
|
}
|
||||||
|
// Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding
|
||||||
|
// a /0 prefix length just to trick ParseCIDR into working. If we find that
|
||||||
|
// the server is an IP literal and we don't allow those then stop straight
|
||||||
|
// away.
|
||||||
|
if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil {
|
||||||
|
if !acls.AllowIPLiterals {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the hostname matches one of the denied regexes. If it does then
|
||||||
|
// the server is banned from the room.
|
||||||
|
for _, expr := range acls.deniedRegexes {
|
||||||
|
if expr.MatchString(string(serverName)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the hostname matches one of the allowed regexes. If it does then
|
||||||
|
// the server is NOT banned from the room.
|
||||||
|
for _, expr := range acls.allowedRegexes {
|
||||||
|
if expr.MatchString(string(serverName)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we've got to this point then we haven't matched any regexes or an IP
|
||||||
|
// hostname if disallowed. The spec calls for default-deny here.
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package acls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenACLsWithBlacklist(t *testing.T) {
|
||||||
|
roomID := "!test:test.com"
|
||||||
|
allowRegex, err := compileACLRegex("*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
denyRegex, err := compileACLRegex("foo.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
acls := ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
|
||||||
|
acls.acls[roomID] = &serverACL{
|
||||||
|
ServerACL: ServerACL{
|
||||||
|
AllowIPLiterals: true,
|
||||||
|
},
|
||||||
|
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||||
|
deniedRegexes: []*regexp.Regexp{denyRegex},
|
||||||
|
}
|
||||||
|
|
||||||
|
if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||||
|
t.Fatal("Expected foo.com to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||||
|
t.Fatal("Expected foo.com:3456 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||||
|
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("bar.com:4567", roomID) {
|
||||||
|
t.Fatal("Expected bar.com:4567 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultACLsWithWhitelist(t *testing.T) {
|
||||||
|
roomID := "!test:test.com"
|
||||||
|
allowRegex, err := compileACLRegex("foo.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
acls := ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
|
||||||
|
acls.acls[roomID] = &serverACL{
|
||||||
|
ServerACL: ServerACL{
|
||||||
|
AllowIPLiterals: false,
|
||||||
|
},
|
||||||
|
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||||
|
deniedRegexes: []*regexp.Regexp{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||||
|
t.Fatal("Expected foo.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||||
|
t.Fatal("Expected foo.com:3456 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||||
|
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("baz.com", roomID) {
|
||||||
|
t.Fatal("Expected baz.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) {
|
||||||
|
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
}
|
|
@ -106,6 +106,20 @@ type RoomserverInternalAPI interface {
|
||||||
response *QueryStateAndAuthChainResponse,
|
response *QueryStateAndAuthChainResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
// QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from
|
||||||
|
// the response.
|
||||||
|
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
|
||||||
|
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||||
|
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||||
|
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||||
|
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
|
||||||
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
|
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||||
|
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||||
|
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
|
||||||
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
|
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||||
|
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
PerformBackfill(
|
PerformBackfill(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -236,6 +236,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error {
|
||||||
|
err := t.Impl.QueryCurrentState(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error {
|
||||||
|
err := t.Impl.QueryRoomsForUser(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error {
|
||||||
|
err := t.Impl.QueryBulkStateContent(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
|
func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error {
|
||||||
|
err := t.Impl.QuerySharedUsers(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error {
|
||||||
|
err := t.Impl.QueryKnownUsers(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error {
|
||||||
|
err := t.Impl.QueryServerBannedFromRoom(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -17,6 +17,11 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct {
|
||||||
// The list of published rooms.
|
// The list of published rooms.
|
||||||
RoomIDs []string
|
RoomIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QuerySharedUsersRequest struct {
|
||||||
|
UserID string
|
||||||
|
ExcludeRoomIDs []string
|
||||||
|
IncludeRoomIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QuerySharedUsersResponse struct {
|
||||||
|
UserIDsToCount map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryRoomsForUserRequest struct {
|
||||||
|
UserID string
|
||||||
|
// The desired membership of the user. If this is the empty string then no rooms are returned.
|
||||||
|
WantMembership string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryRoomsForUserResponse struct {
|
||||||
|
RoomIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryBulkStateContentRequest struct {
|
||||||
|
// Returns state events in these rooms
|
||||||
|
RoomIDs []string
|
||||||
|
// If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*'
|
||||||
|
AllowWildcards bool
|
||||||
|
// The state events to return. Only a small subset of tuples are allowed in this request as only certain events
|
||||||
|
// have their content fields extracted. Specifically, the tuple Type must be one of:
|
||||||
|
// m.room.avatar
|
||||||
|
// m.room.create
|
||||||
|
// m.room.canonical_alias
|
||||||
|
// m.room.guest_access
|
||||||
|
// m.room.history_visibility
|
||||||
|
// m.room.join_rules
|
||||||
|
// m.room.member
|
||||||
|
// m.room.name
|
||||||
|
// m.room.topic
|
||||||
|
// Any other tuple type will result in the query failing.
|
||||||
|
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||||
|
}
|
||||||
|
type QueryBulkStateContentResponse struct {
|
||||||
|
// map of room ID -> tuple -> content_value
|
||||||
|
Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryCurrentStateRequest struct {
|
||||||
|
RoomID string
|
||||||
|
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryCurrentStateResponse struct {
|
||||||
|
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryKnownUsersRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
SearchString string `json:"search_string"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryKnownUsersResponse struct {
|
||||||
|
Users []authtypes.FullyQualifiedProfile `json:"profiles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryServerBannedFromRoomRequest struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryServerBannedFromRoomResponse struct {
|
||||||
|
Banned bool `json:"banned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
|
||||||
|
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))
|
||||||
|
for k, v := range r.StateEvents {
|
||||||
|
// use 0x1F (unit separator) as the delimiter between type/state key,
|
||||||
|
se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v
|
||||||
|
}
|
||||||
|
return json.Marshal(se)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
res := make(map[string]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
err := json.Unmarshal(data, &res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res))
|
||||||
|
for k, v := range res {
|
||||||
|
fields := strings.Split(k, "\x1F")
|
||||||
|
r.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: fields[0],
|
||||||
|
StateKey: fields[1],
|
||||||
|
}] = v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -133,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string)
|
||||||
}
|
}
|
||||||
return &res.Events[0]
|
return &res.Events[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateEvent returns the current state event in the room or nil.
|
||||||
|
func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent {
|
||||||
|
var res QueryCurrentStateResponse
|
||||||
|
err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
|
||||||
|
}, &res)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ev, ok := res.StateEvents[tuple]
|
||||||
|
if ok {
|
||||||
|
return ev
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs.
|
||||||
|
func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool {
|
||||||
|
req := &QueryServerBannedFromRoomRequest{
|
||||||
|
ServerName: serverName,
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
res := &QueryServerBannedFromRoomResponse{}
|
||||||
|
if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return res.Banned
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
|
||||||
|
// published room directory.
|
||||||
|
// due to lots of switches
|
||||||
|
// nolint:gocyclo
|
||||||
|
func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
|
||||||
|
avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
|
||||||
|
nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}
|
||||||
|
canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""}
|
||||||
|
topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""}
|
||||||
|
guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""}
|
||||||
|
visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}
|
||||||
|
joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}
|
||||||
|
|
||||||
|
var stateRes QueryBulkStateContentResponse
|
||||||
|
err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{
|
||||||
|
RoomIDs: roomIDs,
|
||||||
|
AllowWildcards: true,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple,
|
||||||
|
{EventType: gomatrixserverlib.MRoomMember, StateKey: "*"},
|
||||||
|
},
|
||||||
|
}, &stateRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs))
|
||||||
|
i := 0
|
||||||
|
for roomID, data := range stateRes.Rooms {
|
||||||
|
pub := gomatrixserverlib.PublicRoom{
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
joinCount := 0
|
||||||
|
var joinRule, guestAccess string
|
||||||
|
for tuple, contentVal := range data {
|
||||||
|
if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" {
|
||||||
|
joinCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch tuple {
|
||||||
|
case avatarTuple:
|
||||||
|
pub.AvatarURL = contentVal
|
||||||
|
case nameTuple:
|
||||||
|
pub.Name = contentVal
|
||||||
|
case topicTuple:
|
||||||
|
pub.Topic = contentVal
|
||||||
|
case canonicalTuple:
|
||||||
|
pub.CanonicalAlias = contentVal
|
||||||
|
case visibilityTuple:
|
||||||
|
pub.WorldReadable = contentVal == "world_readable"
|
||||||
|
// need both of these to determine whether guests can join
|
||||||
|
case joinRuleTuple:
|
||||||
|
joinRule = contentVal
|
||||||
|
case guestTuple:
|
||||||
|
guestAccess = contentVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" {
|
||||||
|
pub.GuestCanJoin = true
|
||||||
|
}
|
||||||
|
pub.JoinedMembersCount = joinCount
|
||||||
|
chunk[i] = pub
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return chunk, nil
|
||||||
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/perform"
|
"github.com/matrix-org/dendrite/roomserver/internal/perform"
|
||||||
|
@ -46,8 +47,9 @@ func NewRoomserverAPI(
|
||||||
ServerName: cfg.Matrix.ServerName,
|
ServerName: cfg.Matrix.ServerName,
|
||||||
KeyRing: keyRing,
|
KeyRing: keyRing,
|
||||||
Queryer: &query.Queryer{
|
Queryer: &query.Queryer{
|
||||||
DB: roomserverDB,
|
DB: roomserverDB,
|
||||||
Cache: caches,
|
Cache: caches,
|
||||||
|
ServerACLs: acls.NewServerACLs(roomserverDB),
|
||||||
},
|
},
|
||||||
Inputer: &input.Inputer{
|
Inputer: &input.Inputer{
|
||||||
DB: roomserverDB,
|
DB: roomserverDB,
|
||||||
|
|
|
@ -16,9 +16,12 @@ package query
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
@ -31,8 +34,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Queryer struct {
|
type Queryer struct {
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
Cache caching.RoomServerCaches
|
Cache caching.RoomServerCaches
|
||||||
|
ServerACLs *acls.ServerACLs
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
|
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
|
||||||
|
@ -502,3 +506,97 @@ func (r *Queryer) QueryPublishedRooms(
|
||||||
res.RoomIDs = rooms
|
res.RoomIDs = rooms
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||||
|
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
for _, tuple := range req.StateTuples {
|
||||||
|
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ev != nil {
|
||||||
|
res.StateEvents[tuple] = ev
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||||
|
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.RoomIDs = roomIDs
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||||
|
users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
|
||||||
|
UserID: user,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||||
|
events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
||||||
|
for _, ev := range events {
|
||||||
|
if res.Rooms[ev.RoomID] == nil {
|
||||||
|
res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string)
|
||||||
|
}
|
||||||
|
room := res.Rooms[ev.RoomID]
|
||||||
|
room[gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: ev.EventType,
|
||||||
|
StateKey: ev.StateKey,
|
||||||
|
}] = ev.ContentValue
|
||||||
|
res.Rooms[ev.RoomID] = room
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||||
|
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, req.IncludeRoomIDs...)
|
||||||
|
excludeMap := make(map[string]bool)
|
||||||
|
for _, roomID := range req.ExcludeRoomIDs {
|
||||||
|
excludeMap[roomID] = true
|
||||||
|
}
|
||||||
|
// filter out excluded rooms
|
||||||
|
j := 0
|
||||||
|
for i := range roomIDs {
|
||||||
|
// move elements to include to the beginning of the slice
|
||||||
|
// then trim elements on the right
|
||||||
|
if !excludeMap[roomIDs[i]] {
|
||||||
|
roomIDs[j] = roomIDs[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
roomIDs = roomIDs[:j]
|
||||||
|
|
||||||
|
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.UserIDsToCount = users
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||||
|
if r.ServerACLs == nil {
|
||||||
|
return errors.New("no server ACL tracking")
|
||||||
|
}
|
||||||
|
res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -43,6 +43,12 @@ const (
|
||||||
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
|
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
|
||||||
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
|
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
|
||||||
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
|
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
|
||||||
|
RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState"
|
||||||
|
RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser"
|
||||||
|
RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent"
|
||||||
|
RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers"
|
||||||
|
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
|
||||||
|
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
|
||||||
)
|
)
|
||||||
|
|
||||||
type httpRoomserverInternalAPI struct {
|
type httpRoomserverInternalAPI struct {
|
||||||
|
@ -371,3 +377,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryCurrentState(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryCurrentStateRequest,
|
||||||
|
response *api.QueryCurrentStateResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryRoomsForUserRequest,
|
||||||
|
response *api.QueryRoomsForUserResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryBulkStateContentRequest,
|
||||||
|
response *api.QueryBulkStateContentResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
|
||||||
|
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
|
||||||
|
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
|
||||||
|
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
|
||||||
|
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryCurrentStateRequest{}
|
||||||
|
response := api.QueryCurrentStateResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
|
||||||
|
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryRoomsForUserRequest{}
|
||||||
|
response := api.QueryRoomsForUserResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
|
||||||
|
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryBulkStateContentRequest{}
|
||||||
|
response := api.QueryBulkStateContentResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||||
|
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QuerySharedUsersRequest{}
|
||||||
|
response := api.QuerySharedUsersResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||||
|
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryKnownUsersRequest{}
|
||||||
|
response := api.QueryKnownUsersResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
|
||||||
|
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryServerBannedFromRoomRequest{}
|
||||||
|
response := api.QueryServerBannedFromRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
@ -138,4 +139,22 @@ type Database interface {
|
||||||
PublishRoom(ctx context.Context, roomID string, publish bool) error
|
PublishRoom(ctx context.Context, roomID string, publish bool) error
|
||||||
// Returns a list of room IDs for rooms which are published.
|
// Returns a list of room IDs for rooms which are published.
|
||||||
GetPublishedRooms(ctx context.Context) ([]string, error)
|
GetPublishedRooms(ctx context.Context) ([]string, error)
|
||||||
|
|
||||||
|
// TODO: factor out - from currentstateserver
|
||||||
|
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
|
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||||
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
|
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||||
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
|
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||||
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
|
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,6 +99,9 @@ const updateMembershipSQL = "" +
|
||||||
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
|
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
|
||||||
" WHERE room_nid = $1 AND target_nid = $2"
|
" WHERE room_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
selectMembershipForUpdateStmt *sql.Stmt
|
selectMembershipForUpdateStmt *sql.Stmt
|
||||||
|
@ -108,6 +111,7 @@ type membershipStatements struct {
|
||||||
selectMembershipsFromRoomStmt *sql.Stmt
|
selectMembershipsFromRoomStmt *sql.Stmt
|
||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
|
@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
|
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
const selectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms"
|
||||||
|
|
||||||
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
|
@ -82,6 +89,8 @@ type roomStatements struct {
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
|
selectRoomIDsStmt *sql.Stmt
|
||||||
|
bulkSelectRoomIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
|
@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
|
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||||
|
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||||
|
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
func (s *roomStatements) InsertRoomNID(
|
func (s *roomStatements) InsertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
}
|
}
|
||||||
return roomVersion, err
|
return roomVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
|
var array pq.Int64Array
|
||||||
|
for _, nid := range roomNIDs {
|
||||||
|
array = append(array, int64(nid))
|
||||||
|
}
|
||||||
|
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -711,3 +712,82 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
|
||||||
}
|
}
|
||||||
return &evs[0]
|
return &evs[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateEvent returns the current state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
/*
|
||||||
|
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
|
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||||
|
var membershipState tables.MembershipState
|
||||||
|
switch membership {
|
||||||
|
case "join":
|
||||||
|
membershipState = tables.MembershipStateJoin
|
||||||
|
case "invite":
|
||||||
|
membershipState = tables.MembershipStateInvite
|
||||||
|
case "leave":
|
||||||
|
membershipState = tables.MembershipStateLeaveOrBan
|
||||||
|
case "ban":
|
||||||
|
membershipState = tables.MembershipStateLeaveOrBan
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||||
|
}
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||||
|
}
|
||||||
|
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(roomIDs) != len(roomNIDs) {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs))
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
|
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) {
|
||||||
|
return nil, fmt.Errorf("not implemented yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
|
return nil, fmt.Errorf("not implemented yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
|
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
|
||||||
|
return nil, fmt.Errorf("not implemented yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
|
return d.RoomsTable.SelectRoomIDs(ctx)
|
||||||
|
}
|
||||||
|
|
|
@ -75,6 +75,9 @@ const updateMembershipSQL = "" +
|
||||||
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
|
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
|
||||||
" WHERE room_nid = $4 AND target_nid = $5"
|
" WHERE room_nid = $4 AND target_nid = $5"
|
||||||
|
|
||||||
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
|
@ -84,6 +87,7 @@ type membershipStatements struct {
|
||||||
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||||
selectMembershipsFromRoomStmt *sql.Stmt
|
selectMembershipsFromRoomStmt *sql.Stmt
|
||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
|
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -21,7 +21,9 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
const selectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms"
|
||||||
|
|
||||||
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
|
@ -73,6 +81,7 @@ type roomStatements struct {
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
|
selectRoomIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
|
@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
|
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||||
|
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
var info types.RoomInfo
|
var info types.RoomInfo
|
||||||
var latestNIDsJSON string
|
var latestNIDsJSON string
|
||||||
|
@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
}
|
}
|
||||||
return roomVersion, err
|
return roomVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
|
for i, v := range roomNIDs {
|
||||||
|
iRoomNIDs[i] = v
|
||||||
|
}
|
||||||
|
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -65,6 +65,8 @@ type Rooms interface {
|
||||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||||
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
|
SelectRoomIDs(ctx context.Context) ([]string, error)
|
||||||
|
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Transactions interface {
|
type Transactions interface {
|
||||||
|
@ -120,6 +122,7 @@ type Membership interface {
|
||||||
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
||||||
|
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Published interface {
|
type Published interface {
|
||||||
|
|
Loading…
Reference in New Issue