Implement /sync `limited` and read timeline limit from stored filters (#1168)

* Move filter table to syncapi where it is used

* Implement /sync `limited` and read timeline limit from stored filters

We now fully handle `room.timeline.limit` filters (in-line + stored) and
return the right value for `limited` syncs.

* Update whitelist

* Default to the default timeline limit if it's unset, also strip the extra event correctly

* Update whitelist
main
Kegsay 2020-06-26 15:34:41 +01:00 committed by GitHub
parent 164057a3be
commit 1ad7219e4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 194 additions and 135 deletions

View File

@ -376,26 +376,6 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutFilter(req, device, accountDB, vars["userId"])
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
// Riot user settings // Riot user settings
r0mux.Handle("/profile/{userID}", r0mux.Handle("/profile/{userID}",

View File

@ -15,19 +15,22 @@
package routing package routing
import ( import (
"encoding/json"
"io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/tidwall/gjson"
) )
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter( func GetFilter(
req *http.Request, device *api.Device, accountDB accounts.Database, userID string, filterID string, req *http.Request, device *api.Device, syncDB storage.Database, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -41,7 +44,7 @@ func GetFilter(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := syncDB.GetFilter(req.Context(), localpart, filterID)
if err != nil { if err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
@ -64,7 +67,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter( func PutFilter(
req *http.Request, device *api.Device, accountDB accounts.Database, userID string, req *http.Request, device *api.Device, syncDB storage.Database, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -81,8 +84,27 @@ func PutFilter(
var filter gomatrixserverlib.Filter var filter gomatrixserverlib.Filter
if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil { defer req.Body.Close() // nolint:errcheck
return *reqErr body, err := ioutil.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be read. " + err.Error()),
}
}
if err = json.Unmarshal(body, &filter); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// the filter `limit` is `int` which defaults to 0 if not set which is not what we want. We want to use the default
// limit if it is unset, which is what this does.
limitRes := gjson.GetBytes(body, "room.timeline.limit")
if !limitRes.Exists() {
util.GetLogger(req.Context()).Infof("missing timeline limit, using default")
filter.Room.Timeline.Limit = sync.DefaultTimelineLimit
} }
// Validate generates a user-friendly error // Validate generates a user-friendly error
@ -93,9 +115,9 @@ func PutFilter(
} }
} }
filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.PutFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View File

@ -55,4 +55,24 @@ func Setup(
} }
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutFilter(req, device, syncDB, vars["userId"])
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetFilter(req, device, syncDB, vars["userId"], vars["filterId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
} }

View File

@ -128,4 +128,12 @@ type Database interface {
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
} }

View File

@ -19,12 +19,13 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const filterSchema = ` const filterSchema = `
-- Stores data about filters -- Stores data about filters
CREATE TABLE IF NOT EXISTS account_filter ( CREATE TABLE IF NOT EXISTS syncapi_filter (
-- The filter -- The filter
filter TEXT NOT NULL, filter TEXT NOT NULL,
-- The ID -- The ID
@ -35,17 +36,17 @@ CREATE TABLE IF NOT EXISTS account_filter (
PRIMARY KEY(id, localpart) PRIMARY KEY(id, localpart)
); );
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
` `
const selectFilterSQL = "" + const selectFilterSQL = "" +
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" + const selectFilterIDByContentSQL = "" +
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" + const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" "INSERT INTO syncapi_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
type filterStatements struct { type filterStatements struct {
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
@ -53,24 +54,25 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
_, err = db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return return nil, err
} }
s := &filterStatements{}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return return nil, err
} }
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return return nil, err
} }
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *filterStatements) selectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
@ -88,7 +90,7 @@ func (s *filterStatements) selectFilter(
return &filter, nil return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string

View File

@ -301,21 +301,21 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, limit int,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if onlySyncEvents { if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
} else { } else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows) events, err := rowsToStreamEvents(rows)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
if chronologicalOrder { if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't // The events need to be returned from oldest to latest, which isn't
@ -325,7 +325,19 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return events[i].StreamPosition < events[j].StreamPosition return events[i].StreamPosition < events[j].StreamPosition
}) })
} }
return events, nil // we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > limit {
limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
events = events[1:]
} else {
events = events[:len(events)-1]
}
}
return events, limited, nil
} }
// selectEarlyEvents returns the earliest events in the given room, starting // selectEarlyEvents returns the earliest events in the given room, starting

View File

@ -71,6 +71,10 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
if err != nil { if err != nil {
return nil, err return nil, err
} }
filter, err := NewPostgresFilterTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -79,6 +83,7 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
Topology: topology, Topology: topology,
CurrentRoomState: currState, CurrentRoomState: currState,
BackwardExtremities: backwardExtremities, BackwardExtremities: backwardExtremities,
Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(), SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),

View File

@ -43,6 +43,7 @@ type Database struct {
CurrentRoomState tables.CurrentRoomState CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice SendToDevice tables.SendToDevice
Filter tables.Filter
SendToDeviceWriter *sqlutil.TransactionWriter SendToDeviceWriter *sqlutil.TransactionWriter
EDUCache *cache.EDUCache EDUCache *cache.EDUCache
} }
@ -78,7 +79,7 @@ func (d *Database) GetEventsInStreamingRange(
} }
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, err = d.OutputEvents.SelectRecentEvents( if events, _, err = d.OutputEvents.SelectRecentEvents(
ctx, nil, roomID, r, limit, false, false, ctx, nil, roomID, r, limit, false, false,
); err != nil { ); err != nil {
return return
@ -545,6 +546,18 @@ func (d *Database) addEDUDeltaToResponse(
return return
} }
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.Filter.SelectFilter(ctx, localpart, filterID)
}
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.Filter.InsertFilter(ctx, filter, localpart)
}
func (d *Database) IncrementalSync( func (d *Database) IncrementalSync(
ctx context.Context, res *types.Response, ctx context.Context, res *types.Response,
device userapi.Device, device userapi.Device,
@ -642,7 +655,8 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.OutputEvents.SelectRecentEvents( var limited bool
recentStreamEvents, limited, err = d.OutputEvents.SelectRecentEvents(
ctx, txn, roomID, r, numRecentEventsPerRoom, true, true, ctx, txn, roomID, r, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
@ -670,7 +684,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr jr.Timeline.PrevBatch = prevBatchStr
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
@ -776,7 +790,7 @@ func (d *Database) addRoomDeltaToResponse(
// This is all "okay" assuming history_visibility == "shared" which it is by default. // This is all "okay" assuming history_visibility == "shared" which it is by default.
r.To = delta.membershipPos r.To = delta.membershipPos
} }
recentStreamEvents, err := d.OutputEvents.SelectRecentEvents( recentStreamEvents, limited, err := d.OutputEvents.SelectRecentEvents(
ctx, txn, delta.roomID, r, ctx, txn, delta.roomID, r,
numRecentEventsPerRoom, true, true, numRecentEventsPerRoom, true, true,
) )
@ -796,7 +810,7 @@ func (d *Database) addRoomDeltaToResponse(
jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.PrevBatch = prevBatch.String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.roomID] = *jr res.Rooms.Join[delta.roomID] = *jr
case gomatrixserverlib.Leave: case gomatrixserverlib.Leave:

View File

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const filterSchema = ` const filterSchema = `
-- Stores data about filters -- Stores data about filters
CREATE TABLE IF NOT EXISTS account_filter ( CREATE TABLE IF NOT EXISTS syncapi_filter (
-- The filter -- The filter
filter TEXT NOT NULL, filter TEXT NOT NULL,
-- The ID -- The ID
@ -36,17 +37,17 @@ CREATE TABLE IF NOT EXISTS account_filter (
UNIQUE (id, localpart) UNIQUE (id, localpart)
); );
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
` `
const selectFilterSQL = "" + const selectFilterSQL = "" +
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
const selectFilterIDByContentSQL = "" + const selectFilterIDByContentSQL = "" +
"SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
const insertFilterSQL = "" + const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct { type filterStatements struct {
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
@ -54,24 +55,25 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
_, err = db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return return nil, err
} }
s := &filterStatements{}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return return nil, err
} }
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return return nil, err
} }
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *filterStatements) selectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
@ -89,7 +91,7 @@ func (s *filterStatements) selectFilter(
return &filter, nil return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string

View File

@ -311,7 +311,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int, roomID string, r types.Range, limit int,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if onlySyncEvents { if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
@ -319,14 +319,14 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
} }
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
events, err := rowsToStreamEvents(rows) events, err := rowsToStreamEvents(rows)
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
if chronologicalOrder { if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't // The events need to be returned from oldest to latest, which isn't
@ -336,7 +336,18 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
return events[i].StreamPosition < events[j].StreamPosition return events[i].StreamPosition < events[j].StreamPosition
}) })
} }
return events, nil // we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > limit {
limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
events = events[1:]
} else {
events = events[:len(events)-1]
}
}
return events, limited, nil
} }
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(

View File

@ -87,6 +87,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
filter, err := NewSqliteFilterTable(d.db)
if err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
BackwardExtremities: bwExtrem, BackwardExtremities: bwExtrem,
CurrentRoomState: roomState, CurrentRoomState: roomState,
Topology: topology, Topology: topology,
Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(), SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),

View File

@ -44,8 +44,8 @@ type Events interface {
InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error)
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
@ -133,3 +133,8 @@ type SendToDevice interface {
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface {
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
}

View File

@ -363,7 +363,7 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syn
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: &since,
wantFullState: false, wantFullState: false,
limit: defaultTimelineLimit, limit: DefaultTimelineLimit,
log: util.GetLogger(context.TODO()), log: util.GetLogger(context.TODO()),
ctx: context.TODO(), ctx: context.TODO(),
} }

View File

@ -21,14 +21,16 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const defaultSyncTimeout = time.Duration(0) const defaultSyncTimeout = time.Duration(0)
const defaultTimelineLimit = 20 const DefaultTimelineLimit = 20
type filter struct { type filter struct {
Room struct { Room struct {
@ -49,7 +51,7 @@ type syncRequest struct {
log *log.Entry log *log.Entry
} }
func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, error) { func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Database) (*syncRequest, error) {
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
@ -66,16 +68,29 @@ func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, err
tok := types.NewStreamToken(0, 0) tok := types.NewStreamToken(0, 0)
since = &tok since = &tok
} }
timelineLimit := defaultTimelineLimit timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too // TODO: read from stored filters too
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" && filterQuery[0] == '{' { if filterQuery != "" {
if filterQuery[0] == '{' {
// attempt to parse the timeline limit at least // attempt to parse the timeline limit at least
var f filter var f filter
err := json.Unmarshal([]byte(filterQuery), &f) err := json.Unmarshal([]byte(filterQuery), &f)
if err == nil && f.Room.Timeline.Limit != nil { if err == nil && f.Room.Timeline.Limit != nil {
timelineLimit = *f.Room.Timeline.Limit timelineLimit = *f.Room.Timeline.Limit
} }
} else {
// attempt to load the filter ID
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, err
}
f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery)
if err == nil {
timelineLimit = f.Room.Timeline.Limit
}
}
} }
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{

View File

@ -49,7 +49,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
var syncData *types.Response var syncData *types.Response
// Extract values from request // Extract values from request
syncReq, err := newSyncRequest(req, *device) syncReq, err := newSyncRequest(req, *device, rp.db)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View File

@ -387,3 +387,9 @@ Can reject invites over federation for rooms with version 4
Can reject invites over federation for rooms with version 5 Can reject invites over federation for rooms with version 5
Can reject invites over federation for rooms with version 6 Can reject invites over federation for rooms with version 6
Event size limits Event size limits
Can sync a room with a single message
Can sync a room with a message with a transaction id
A full_state incremental update returns only recent timeline
A prev_batch token can be used in the v1 messages API
We don't send redundant membership state across incremental syncs by default
Typing notifications don't leak

View File

@ -52,8 +52,6 @@ type Database interface {
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
} }

View File

@ -40,7 +40,6 @@ type Database struct {
memberships membershipStatements memberships membershipStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -75,11 +74,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve
if err = t.prepare(db); err != nil { if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
f := filterStatements{} return &Database{db, partitions, a, p, m, ac, t, serverName}, nil
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -396,24 +391,6 @@ func (d *Database) GetThreePIDsForLocalpart(
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
} }
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.

View File

@ -39,7 +39,6 @@ type Database struct {
memberships membershipStatements memberships membershipStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
createAccountMu sync.Mutex createAccountMu sync.Mutex
@ -80,11 +79,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = t.prepare(db); err != nil { if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
f := filterStatements{} return &Database{db, partitions, a, p, m, ac, t, serverName, sync.Mutex{}}, nil
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -410,24 +405,6 @@ func (d *Database) GetThreePIDsForLocalpart(
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
} }
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.