Implement key caching directly (#1038)
* Use gomatrixserverlib key caching * Implement key caching wrapper * Add caching wrapper in BaseComponent * Review commentsmain
parent
7ca230e931
commit
419ff150d4
|
@ -25,6 +25,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common/caching"
|
"github.com/matrix-org/dendrite/common/caching"
|
||||||
"github.com/matrix-org/dendrite/common/keydb"
|
"github.com/matrix-org/dendrite/common/keydb"
|
||||||
|
"github.com/matrix-org/dendrite/common/keydb/cache"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/naffka"
|
"github.com/matrix-org/naffka"
|
||||||
|
@ -186,7 +187,11 @@ func (b *BaseDendrite) CreateKeyDB() keydb.Database {
|
||||||
logrus.WithError(err).Panicf("failed to connect to keys db")
|
logrus.WithError(err).Panicf("failed to connect to keys db")
|
||||||
}
|
}
|
||||||
|
|
||||||
return db
|
cachedDB, err := cache.NewKeyDatabase(db, b.ImmutableCache)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Panicf("failed to create key cache wrapper")
|
||||||
|
}
|
||||||
|
return cachedDB
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateFederationClient creates a new federation client. Should only be called
|
// CreateFederationClient creates a new federation client. Should only be called
|
||||||
|
|
|
@ -4,9 +4,12 @@ import "github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoomVersionMaxCacheEntries = 128
|
RoomVersionMaxCacheEntries = 128
|
||||||
|
ServerKeysMaxCacheEntries = 128
|
||||||
)
|
)
|
||||||
|
|
||||||
type ImmutableCache interface {
|
type ImmutableCache interface {
|
||||||
GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool)
|
GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool)
|
||||||
StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion)
|
StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion)
|
||||||
|
GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (gomatrixserverlib.PublicKeyLookupResult, bool)
|
||||||
|
StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
type ImmutableInMemoryLRUCache struct {
|
type ImmutableInMemoryLRUCache struct {
|
||||||
roomVersions *lru.Cache
|
roomVersions *lru.Cache
|
||||||
|
serverKeys *lru.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
|
func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
|
||||||
|
@ -16,8 +17,13 @@ func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
|
||||||
if rvErr != nil {
|
if rvErr != nil {
|
||||||
return nil, rvErr
|
return nil, rvErr
|
||||||
}
|
}
|
||||||
|
serverKeysCache, rvErr := lru.New(ServerKeysMaxCacheEntries)
|
||||||
|
if rvErr != nil {
|
||||||
|
return nil, rvErr
|
||||||
|
}
|
||||||
return &ImmutableInMemoryLRUCache{
|
return &ImmutableInMemoryLRUCache{
|
||||||
roomVersions: roomVersionCache,
|
roomVersions: roomVersionCache,
|
||||||
|
serverKeys: serverKeysCache,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,3 +47,25 @@ func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion
|
||||||
checkForInvalidMutation(c.roomVersions, roomID, roomVersion)
|
checkForInvalidMutation(c.roomVersions, roomID, roomVersion)
|
||||||
c.roomVersions.Add(roomID, roomVersion)
|
c.roomVersions.Add(roomID, roomVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ImmutableInMemoryLRUCache) GetServerKey(
|
||||||
|
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||||
|
) (gomatrixserverlib.PublicKeyLookupResult, bool) {
|
||||||
|
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
|
||||||
|
val, found := c.serverKeys.Get(key)
|
||||||
|
if found && val != nil {
|
||||||
|
if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok {
|
||||||
|
return keyLookupResult, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gomatrixserverlib.PublicKeyLookupResult{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ImmutableInMemoryLRUCache) StoreServerKey(
|
||||||
|
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||||
|
response gomatrixserverlib.PublicKeyLookupResult,
|
||||||
|
) {
|
||||||
|
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
|
||||||
|
checkForInvalidMutation(c.roomVersions, key, response)
|
||||||
|
c.serverKeys.Add(request, response)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common/caching"
|
||||||
|
"github.com/matrix-org/dendrite/common/keydb"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
||||||
|
// the public keys for other matrix servers.
|
||||||
|
type KeyDatabase struct {
|
||||||
|
inner keydb.Database
|
||||||
|
cache caching.ImmutableCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeyDatabase(inner keydb.Database, cache caching.ImmutableCache) (*KeyDatabase, error) {
|
||||||
|
if inner == nil {
|
||||||
|
return nil, errors.New("inner database can't be nil")
|
||||||
|
}
|
||||||
|
if cache == nil {
|
||||||
|
return nil, errors.New("cache can't be nil")
|
||||||
|
}
|
||||||
|
return &KeyDatabase{
|
||||||
|
inner: inner,
|
||||||
|
cache: cache,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetcherName implements KeyFetcher
|
||||||
|
func (d KeyDatabase) FetcherName() string {
|
||||||
|
return "InMemoryKeyCache"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
func (d *KeyDatabase) FetchKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||||
|
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||||
|
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
|
||||||
|
for req := range requests {
|
||||||
|
if res, cached := d.cache.GetServerKey(req); cached {
|
||||||
|
results[req] = res
|
||||||
|
delete(requests, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fromDB, err := d.inner.FetchKeys(ctx, requests)
|
||||||
|
if err != nil {
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
for req, res := range fromDB {
|
||||||
|
results[req] = res
|
||||||
|
d.cache.StoreServerKey(req, res)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
func (d *KeyDatabase) StoreKeys(
|
||||||
|
ctx context.Context,
|
||||||
|
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||||
|
) error {
|
||||||
|
for req, res := range keyMap {
|
||||||
|
d.cache.StoreServerKey(req, res)
|
||||||
|
}
|
||||||
|
return d.inner.StoreKeys(ctx, keyMap)
|
||||||
|
}
|
|
@ -79,7 +79,7 @@ func NewDatabase(
|
||||||
|
|
||||||
// FetcherName implements KeyFetcher
|
// FetcherName implements KeyFetcher
|
||||||
func (d Database) FetcherName() string {
|
func (d Database) FetcherName() string {
|
||||||
return "KeyDatabase"
|
return "PostgresKeyDatabase"
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
|
|
@ -80,7 +80,7 @@ func NewDatabase(
|
||||||
|
|
||||||
// FetcherName implements KeyFetcher
|
// FetcherName implements KeyFetcher
|
||||||
func (d Database) FetcherName() string {
|
func (d Database) FetcherName() string {
|
||||||
return "KeyDatabase"
|
return "SqliteKeyDatabase"
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
|
|
@ -20,10 +20,8 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
lru "github.com/hashicorp/golang-lru"
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const serverKeysSchema = `
|
const serverKeysSchema = `
|
||||||
|
@ -66,16 +64,10 @@ type serverKeyStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
bulkSelectServerKeysStmt *sql.Stmt
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
upsertServerKeysStmt *sql.Stmt
|
upsertServerKeysStmt *sql.Stmt
|
||||||
|
|
||||||
cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.cache, err = lru.New(64)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = db.Exec(serverKeysSchema)
|
_, err = db.Exec(serverKeysSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -98,21 +90,6 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges.
|
|
||||||
cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
|
||||||
for request := range requests {
|
|
||||||
r, ok := s.cache.Get(nameAndKeyID(request))
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult)
|
|
||||||
cacheResults[request] = cacheResult
|
|
||||||
}
|
|
||||||
if len(cacheResults) == len(requests) {
|
|
||||||
util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults))
|
|
||||||
return cacheResults, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1)
|
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1)
|
||||||
|
|
||||||
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
|
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
|
||||||
|
@ -158,7 +135,6 @@ func (s *serverKeyStatements) upsertServerKeys(
|
||||||
request gomatrixserverlib.PublicKeyLookupRequest,
|
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||||
key gomatrixserverlib.PublicKeyLookupResult,
|
key gomatrixserverlib.PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
s.cache.Add(nameAndKeyID(request), key)
|
|
||||||
_, err := s.upsertServerKeysStmt.ExecContext(
|
_, err := s.upsertServerKeysStmt.ExecContext(
|
||||||
ctx,
|
ctx,
|
||||||
string(request.ServerName),
|
string(request.ServerName),
|
||||||
|
|
Loading…
Reference in New Issue