Implement more CSS storage functions in roomserver (#1388)

main
Kegsay 2020-09-03 18:27:02 +01:00 committed by GitHub
parent b20386123e
commit 33b8143a95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 303 additions and 20 deletions

View File

@ -18,7 +18,9 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -102,6 +108,16 @@ const updateMembershipSQL = "" +
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
var selectKnownUsersSQL = "" +
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE room_nid = ANY(" +
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@ -112,6 +128,8 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
}
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
}.Prepare(db)
}
@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership(
}
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
result := make(map[types.EventStateKeyNID]int)
for rows.Next() {
var userID types.EventStateKeyNID
var count int
if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
result[userID] = count
}
return result, rows.Err()
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -81,6 +81,9 @@ const selectRoomIDsSQL = "" +
const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
@ -91,6 +94,7 @@ type roomStatements struct {
selectRoomInfoStmt *sql.Stmt
selectRoomIDsStmt *sql.Stmt
bulkSelectRoomIDsStmt *sql.Stmt
bulkSelectRoomNIDsStmt *sql.Stmt
}
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -109,6 +113,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
}.Prepare(db)
}
@ -245,3 +250,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
}
return roomIDs, nil
}
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray
for _, roomID := range roomIDs {
array = append(array, roomID)
}
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
}
return roomNIDs, nil
}

View File

@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"sort"
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
"github.com/matrix-org/dendrite/internal/caching"
@ -13,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
)
@ -717,25 +719,42 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
/*
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return nil, err
roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil {
return nil, err
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
if err != nil {
return nil, err
}
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
if err != nil {
return nil, err
}
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return nil, err
}
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID)
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion)
if err != nil {
return nil, err
}
h := ev.Headered(roomInfo.RoomVersion)
return &h, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
if err != nil {
return nil, err
}
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
if err != nil {
return nil, err
}
blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
if err != nil {
return nil, err
}
*/
return nil, nil
}
return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID)
}
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
@ -779,15 +798,106 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
return nil, fmt.Errorf("not implemented yet")
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
if err != nil {
return nil, err
}
stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount))
i := 0
for nid := range userNIDToCount {
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) {
return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID))
}
result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count
}
return result, nil
}
// GetKnownUsers searches all users that userID knows about.
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
return nil, fmt.Errorf("not implemented yet")
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
if err != nil {
return nil, err
}
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
}
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx)
}
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package!
func (d *Database) loadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
}
fullState = append(fullState, entries...)
}
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
return fullState, nil
}
type stateEntryListMap []types.StateEntryList
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type stateEntryByStateKeySorter []types.StateEntry
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
}
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View File

@ -18,6 +18,8 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -38,6 +40,10 @@ const membershipSchema = `
);
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@ -78,6 +84,16 @@ const updateMembershipSQL = "" +
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
var selectKnownUsersSQL = "" +
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE room_nid IN (" +
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@ -89,6 +105,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
}
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
@ -110,6 +127,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
}.Prepare(db)
}
@ -227,3 +245,43 @@ func (s *membershipStatements) SelectRoomsWithMembership(
}
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
result := make(map[types.EventStateKeyNID]int)
for rows.Next() {
var userID types.EventStateKeyNID
var count int
if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
result[userID] = count
}
return result, rows.Err()
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -72,6 +72,9 @@ const selectRoomIDsSQL = "" +
const bulkSelectRoomIDsSQL = "" +
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
const bulkSelectRoomNIDsSQL = "" +
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
@ -252,3 +255,25 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
}
return roomIDs, nil
}
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
}
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
roomNIDs = append(roomNIDs, roomNID)
}
return roomNIDs, nil
}

View File

@ -67,6 +67,7 @@ type Rooms interface {
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
}
type Transactions interface {
@ -123,6 +124,10 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
}
type Published interface {