185 lines
6.2 KiB
Go
185 lines
6.2 KiB
Go
package shared
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
|
"github.com/matrix-org/dendrite/common"
|
|
"github.com/matrix-org/dendrite/eduserver/cache"
|
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
|
"github.com/matrix-org/dendrite/syncapi/types"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
|
|
// For now this contains the shared functions
|
|
type Database struct {
|
|
DB *sql.DB
|
|
Invites tables.Invites
|
|
AccountData tables.AccountData
|
|
OutputEvents tables.Events
|
|
EDUCache *cache.EDUCache
|
|
}
|
|
|
|
// Events lookups a list of event by their event ID.
|
|
// Returns a list of events matching the requested IDs found in the database.
|
|
// If an event is not found in the database then it will be omitted from the list.
|
|
// Returns an error if there was a problem talking with the database.
|
|
// Does not include any transaction IDs in the returned events.
|
|
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) {
|
|
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// We don't include a device here as we only include transaction IDs in
|
|
// incremental syncs.
|
|
return d.StreamEventsToEvents(nil, streamEvents), nil
|
|
}
|
|
|
|
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the
|
|
// given extremities and limit.
|
|
func (d *Database) GetEventsInStreamingRange(
|
|
ctx context.Context,
|
|
from, to *types.StreamingToken,
|
|
roomID string, limit int,
|
|
backwardOrdering bool,
|
|
) (events []types.StreamEvent, err error) {
|
|
if backwardOrdering {
|
|
// When using backward ordering, we want the most recent events first.
|
|
if events, err = d.OutputEvents.SelectRecentEvents(
|
|
ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false,
|
|
); err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
// When using forward ordering, we want the least recent events first.
|
|
if events, err = d.OutputEvents.SelectEarlyEvents(
|
|
ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit,
|
|
); err != nil {
|
|
return
|
|
}
|
|
}
|
|
return events, err
|
|
}
|
|
|
|
func (d *Database) AddTypingUser(
|
|
userID, roomID string, expireTime *time.Time,
|
|
) types.StreamPosition {
|
|
return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime))
|
|
}
|
|
|
|
func (d *Database) RemoveTypingUser(
|
|
userID, roomID string,
|
|
) types.StreamPosition {
|
|
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
|
|
}
|
|
|
|
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
|
|
d.EDUCache.SetTimeoutCallback(fn)
|
|
}
|
|
|
|
func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
|
|
var maxID int64
|
|
var err error
|
|
err = common.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
|
maxID, err = d.OutputEvents.SelectMaxEventID(ctx, txn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var maxAccountDataID int64
|
|
maxAccountDataID, err = d.AccountData.SelectMaxAccountDataID(ctx, txn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if maxAccountDataID > maxID {
|
|
maxID = maxAccountDataID
|
|
}
|
|
var maxInviteID int64
|
|
maxInviteID, err = d.Invites.SelectMaxInviteID(ctx, txn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if maxInviteID > maxID {
|
|
maxID = maxInviteID
|
|
}
|
|
return nil
|
|
})
|
|
return types.StreamPosition(maxID), err
|
|
}
|
|
|
|
// AddInviteEvent stores a new invite event for a user.
|
|
// If the invite was successfully stored this returns the stream ID it was stored at.
|
|
// Returns an error if there was a problem communicating with the database.
|
|
func (d *Database) AddInviteEvent(
|
|
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
|
|
) (sp types.StreamPosition, err error) {
|
|
err = common.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
|
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// RetireInviteEvent removes an old invite event from the database.
|
|
// Returns an error if there was a problem communicating with the database.
|
|
func (d *Database) RetireInviteEvent(
|
|
ctx context.Context, inviteEventID string,
|
|
) error {
|
|
// TODO: Record that invite has been retired in a stream so that we can
|
|
// notify the user in an incremental sync.
|
|
err := d.Invites.DeleteInviteEvent(ctx, inviteEventID)
|
|
return err
|
|
}
|
|
|
|
// GetAccountDataInRange returns all account data for a given user inserted or
|
|
// updated between two given positions
|
|
// Returns a map following the format data[roomID] = []dataTypes
|
|
// If no data is retrieved, returns an empty map
|
|
// If there was an issue with the retrieval, returns an error
|
|
func (d *Database) GetAccountDataInRange(
|
|
ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
|
|
accountDataFilterPart *gomatrixserverlib.EventFilter,
|
|
) (map[string][]string, error) {
|
|
return d.AccountData.SelectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
|
|
}
|
|
|
|
// UpsertAccountData keeps track of new or updated account data, by saving the type
|
|
// of the new/updated data, and the user ID and room ID the data is related to (empty)
|
|
// room ID means the data isn't specific to any room)
|
|
// If no data with the given type, user ID and room ID exists in the database,
|
|
// creates a new row, else update the existing one
|
|
// Returns an error if there was an issue with the upsert
|
|
func (d *Database) UpsertAccountData(
|
|
ctx context.Context, userID, roomID, dataType string,
|
|
) (sp types.StreamPosition, err error) {
|
|
err = common.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
|
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
func (d *Database) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent {
|
|
out := make([]gomatrixserverlib.HeaderedEvent, len(in))
|
|
for i := 0; i < len(in); i++ {
|
|
out[i] = in[i].HeaderedEvent
|
|
if device != nil && in[i].TransactionID != nil {
|
|
if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
|
|
err := out[i].SetUnsignedField(
|
|
"transaction_id", in[i].TransactionID.TransactionID,
|
|
)
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
"event_id": out[i].EventID(),
|
|
}).WithError(err).Warnf("Failed to add transaction ID to event")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return out
|
|
}
|