From 419ff150d41a3d0de25f0e8e66baf36948bcfbc1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 15 May 2020 09:32:40 +0100 Subject: [PATCH] Implement key caching directly (#1038) * Use gomatrixserverlib key caching * Implement key caching wrapper * Add caching wrapper in BaseComponent * Review comments --- common/basecomponent/base.go | 7 ++- common/caching/immutablecache.go | 3 ++ common/caching/immutableinmemorylru.go | 28 ++++++++++ common/keydb/cache/keydb.go | 69 ++++++++++++++++++++++++ common/keydb/postgres/keydb.go | 2 +- common/keydb/sqlite3/keydb.go | 2 +- common/keydb/sqlite3/server_key_table.go | 24 --------- 7 files changed, 108 insertions(+), 27 deletions(-) create mode 100644 common/keydb/cache/keydb.go diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go index cb04a308..4342e25a 100644 --- a/common/basecomponent/base.go +++ b/common/basecomponent/base.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/common/keydb" + "github.com/matrix-org/dendrite/common/keydb/cache" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/naffka" @@ -186,7 +187,11 @@ func (b *BaseDendrite) CreateKeyDB() keydb.Database { logrus.WithError(err).Panicf("failed to connect to keys db") } - return db + cachedDB, err := cache.NewKeyDatabase(db, b.ImmutableCache) + if err != nil { + logrus.WithError(err).Panicf("failed to create key cache wrapper") + } + return cachedDB } // CreateFederationClient creates a new federation client. Should only be called diff --git a/common/caching/immutablecache.go b/common/caching/immutablecache.go index 9620667a..362e4349 100644 --- a/common/caching/immutablecache.go +++ b/common/caching/immutablecache.go @@ -4,9 +4,12 @@ import "github.com/matrix-org/gomatrixserverlib" const ( RoomVersionMaxCacheEntries = 128 + ServerKeysMaxCacheEntries = 128 ) type ImmutableCache interface { GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool) StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion) + GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (gomatrixserverlib.PublicKeyLookupResult, bool) + StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult) } diff --git a/common/caching/immutableinmemorylru.go b/common/caching/immutableinmemorylru.go index 3e8f4aad..6d2a785f 100644 --- a/common/caching/immutableinmemorylru.go +++ b/common/caching/immutableinmemorylru.go @@ -9,6 +9,7 @@ import ( type ImmutableInMemoryLRUCache struct { roomVersions *lru.Cache + serverKeys *lru.Cache } func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { @@ -16,8 +17,13 @@ func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { if rvErr != nil { return nil, rvErr } + serverKeysCache, rvErr := lru.New(ServerKeysMaxCacheEntries) + if rvErr != nil { + return nil, rvErr + } return &ImmutableInMemoryLRUCache{ roomVersions: roomVersionCache, + serverKeys: serverKeysCache, }, nil } @@ -41,3 +47,25 @@ func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion checkForInvalidMutation(c.roomVersions, roomID, roomVersion) c.roomVersions.Add(roomID, roomVersion) } + +func (c *ImmutableInMemoryLRUCache) GetServerKey( + request gomatrixserverlib.PublicKeyLookupRequest, +) (gomatrixserverlib.PublicKeyLookupResult, bool) { + key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) + val, found := c.serverKeys.Get(key) + if found && val != nil { + if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { + return keyLookupResult, true + } + } + return gomatrixserverlib.PublicKeyLookupResult{}, false +} + +func (c *ImmutableInMemoryLRUCache) StoreServerKey( + request gomatrixserverlib.PublicKeyLookupRequest, + response gomatrixserverlib.PublicKeyLookupResult, +) { + key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) + checkForInvalidMutation(c.roomVersions, key, response) + c.serverKeys.Add(request, response) +} diff --git a/common/keydb/cache/keydb.go b/common/keydb/cache/keydb.go new file mode 100644 index 00000000..ae929fa4 --- /dev/null +++ b/common/keydb/cache/keydb.go @@ -0,0 +1,69 @@ +package cache + +import ( + "context" + "errors" + + "github.com/matrix-org/dendrite/common/caching" + "github.com/matrix-org/dendrite/common/keydb" + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database implements gomatrixserverlib.KeyDatabase and is used to store +// the public keys for other matrix servers. +type KeyDatabase struct { + inner keydb.Database + cache caching.ImmutableCache +} + +func NewKeyDatabase(inner keydb.Database, cache caching.ImmutableCache) (*KeyDatabase, error) { + if inner == nil { + return nil, errors.New("inner database can't be nil") + } + if cache == nil { + return nil, errors.New("cache can't be nil") + } + return &KeyDatabase{ + inner: inner, + cache: cache, + }, nil +} + +// FetcherName implements KeyFetcher +func (d KeyDatabase) FetcherName() string { + return "InMemoryKeyCache" +} + +// FetchKeys implements gomatrixserverlib.KeyDatabase +func (d *KeyDatabase) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) + for req := range requests { + if res, cached := d.cache.GetServerKey(req); cached { + results[req] = res + delete(requests, req) + } + } + fromDB, err := d.inner.FetchKeys(ctx, requests) + if err != nil { + return results, err + } + for req, res := range fromDB { + results[req] = res + d.cache.StoreServerKey(req, res) + } + return results, nil +} + +// StoreKeys implements gomatrixserverlib.KeyDatabase +func (d *KeyDatabase) StoreKeys( + ctx context.Context, + keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + for req, res := range keyMap { + d.cache.StoreServerKey(req, res) + } + return d.inner.StoreKeys(ctx, keyMap) +} diff --git a/common/keydb/postgres/keydb.go b/common/keydb/postgres/keydb.go index 6149d877..706ca005 100644 --- a/common/keydb/postgres/keydb.go +++ b/common/keydb/postgres/keydb.go @@ -79,7 +79,7 @@ func NewDatabase( // FetcherName implements KeyFetcher func (d Database) FetcherName() string { - return "KeyDatabase" + return "PostgresKeyDatabase" } // FetchKeys implements gomatrixserverlib.KeyDatabase diff --git a/common/keydb/sqlite3/keydb.go b/common/keydb/sqlite3/keydb.go index 1405836a..94a32e29 100644 --- a/common/keydb/sqlite3/keydb.go +++ b/common/keydb/sqlite3/keydb.go @@ -80,7 +80,7 @@ func NewDatabase( // FetcherName implements KeyFetcher func (d Database) FetcherName() string { - return "KeyDatabase" + return "SqliteKeyDatabase" } // FetchKeys implements gomatrixserverlib.KeyDatabase diff --git a/common/keydb/sqlite3/server_key_table.go b/common/keydb/sqlite3/server_key_table.go index ba1cc060..883d3cd0 100644 --- a/common/keydb/sqlite3/server_key_table.go +++ b/common/keydb/sqlite3/server_key_table.go @@ -20,10 +20,8 @@ import ( "database/sql" "strings" - lru "github.com/hashicorp/golang-lru" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) const serverKeysSchema = ` @@ -66,16 +64,10 @@ type serverKeyStatements struct { db *sql.DB bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt - - cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult } func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { s.db = db - s.cache, err = lru.New(64) - if err != nil { - return - } _, err = db.Exec(serverKeysSchema) if err != nil { return @@ -98,21 +90,6 @@ func (s *serverKeyStatements) bulkSelectServerKeys( nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - // If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges. - cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for request := range requests { - r, ok := s.cache.Get(nameAndKeyID(request)) - if !ok { - break - } - cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult) - cacheResults[request] = cacheResult - } - if len(cacheResults) == len(requests) { - util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults)) - return cacheResults, nil - } - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) @@ -158,7 +135,6 @@ func (s *serverKeyStatements) upsertServerKeys( request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult, ) error { - s.cache.Add(nameAndKeyID(request), key) _, err := s.upsertServerKeysStmt.ExecContext( ctx, string(request.ServerName),