Implement key caching directly (#1038)

* Use gomatrixserverlib key caching

* Implement key caching wrapper

* Add caching wrapper in BaseComponent

* Review comments
main
Neil Alexander 2020-05-15 09:32:40 +01:00 committed by GitHub
parent 7ca230e931
commit 419ff150d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 27 deletions

View File

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/common/caching" "github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/keydb" "github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/keydb/cache"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/naffka" "github.com/matrix-org/naffka"
@ -186,7 +187,11 @@ func (b *BaseDendrite) CreateKeyDB() keydb.Database {
logrus.WithError(err).Panicf("failed to connect to keys db") logrus.WithError(err).Panicf("failed to connect to keys db")
} }
return db cachedDB, err := cache.NewKeyDatabase(db, b.ImmutableCache)
if err != nil {
logrus.WithError(err).Panicf("failed to create key cache wrapper")
}
return cachedDB
} }
// CreateFederationClient creates a new federation client. Should only be called // CreateFederationClient creates a new federation client. Should only be called

View File

@ -4,9 +4,12 @@ import "github.com/matrix-org/gomatrixserverlib"
const ( const (
RoomVersionMaxCacheEntries = 128 RoomVersionMaxCacheEntries = 128
ServerKeysMaxCacheEntries = 128
) )
type ImmutableCache interface { type ImmutableCache interface {
GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool) GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool)
StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion) StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion)
GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (gomatrixserverlib.PublicKeyLookupResult, bool)
StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult)
} }

View File

@ -9,6 +9,7 @@ import (
type ImmutableInMemoryLRUCache struct { type ImmutableInMemoryLRUCache struct {
roomVersions *lru.Cache roomVersions *lru.Cache
serverKeys *lru.Cache
} }
func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) { func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
@ -16,8 +17,13 @@ func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
if rvErr != nil { if rvErr != nil {
return nil, rvErr return nil, rvErr
} }
serverKeysCache, rvErr := lru.New(ServerKeysMaxCacheEntries)
if rvErr != nil {
return nil, rvErr
}
return &ImmutableInMemoryLRUCache{ return &ImmutableInMemoryLRUCache{
roomVersions: roomVersionCache, roomVersions: roomVersionCache,
serverKeys: serverKeysCache,
}, nil }, nil
} }
@ -41,3 +47,25 @@ func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion
checkForInvalidMutation(c.roomVersions, roomID, roomVersion) checkForInvalidMutation(c.roomVersions, roomID, roomVersion)
c.roomVersions.Add(roomID, roomVersion) c.roomVersions.Add(roomID, roomVersion)
} }
func (c *ImmutableInMemoryLRUCache) GetServerKey(
request gomatrixserverlib.PublicKeyLookupRequest,
) (gomatrixserverlib.PublicKeyLookupResult, bool) {
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
val, found := c.serverKeys.Get(key)
if found && val != nil {
if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok {
return keyLookupResult, true
}
}
return gomatrixserverlib.PublicKeyLookupResult{}, false
}
func (c *ImmutableInMemoryLRUCache) StoreServerKey(
request gomatrixserverlib.PublicKeyLookupRequest,
response gomatrixserverlib.PublicKeyLookupResult,
) {
key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
checkForInvalidMutation(c.roomVersions, key, response)
c.serverKeys.Add(request, response)
}

69
common/keydb/cache/keydb.go vendored Normal file
View File

@ -0,0 +1,69 @@
package cache
import (
"context"
"errors"
"github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/gomatrixserverlib"
)
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type KeyDatabase struct {
inner keydb.Database
cache caching.ImmutableCache
}
func NewKeyDatabase(inner keydb.Database, cache caching.ImmutableCache) (*KeyDatabase, error) {
if inner == nil {
return nil, errors.New("inner database can't be nil")
}
if cache == nil {
return nil, errors.New("cache can't be nil")
}
return &KeyDatabase{
inner: inner,
cache: cache,
}, nil
}
// FetcherName implements KeyFetcher
func (d KeyDatabase) FetcherName() string {
return "InMemoryKeyCache"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *KeyDatabase) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
for req := range requests {
if res, cached := d.cache.GetServerKey(req); cached {
results[req] = res
delete(requests, req)
}
}
fromDB, err := d.inner.FetchKeys(ctx, requests)
if err != nil {
return results, err
}
for req, res := range fromDB {
results[req] = res
d.cache.StoreServerKey(req, res)
}
return results, nil
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *KeyDatabase) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
for req, res := range keyMap {
d.cache.StoreServerKey(req, res)
}
return d.inner.StoreKeys(ctx, keyMap)
}

View File

@ -79,7 +79,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher // FetcherName implements KeyFetcher
func (d Database) FetcherName() string { func (d Database) FetcherName() string {
return "KeyDatabase" return "PostgresKeyDatabase"
} }
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase

View File

@ -80,7 +80,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher // FetcherName implements KeyFetcher
func (d Database) FetcherName() string { func (d Database) FetcherName() string {
return "KeyDatabase" return "SqliteKeyDatabase"
} }
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase

View File

@ -20,10 +20,8 @@ import (
"database/sql" "database/sql"
"strings" "strings"
lru "github.com/hashicorp/golang-lru"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
const serverKeysSchema = ` const serverKeysSchema = `
@ -66,16 +64,10 @@ type serverKeyStatements struct {
db *sql.DB db *sql.DB
bulkSelectServerKeysStmt *sql.Stmt bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt
cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult
} }
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db s.db = db
s.cache, err = lru.New(64)
if err != nil {
return
}
_, err = db.Exec(serverKeysSchema) _, err = db.Exec(serverKeysSchema)
if err != nil { if err != nil {
return return
@ -98,21 +90,6 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
// If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges.
cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for request := range requests {
r, ok := s.cache.Get(nameAndKeyID(request))
if !ok {
break
}
cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult)
cacheResults[request] = cacheResult
}
if len(cacheResults) == len(requests) {
util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults))
return cacheResults, nil
}
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1) query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
@ -158,7 +135,6 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
s.cache.Add(nameAndKeyID(request), key)
_, err := s.upsertServerKeysStmt.ExecContext( _, err := s.upsertServerKeysStmt.ExecContext(
ctx, ctx,
string(request.ServerName), string(request.ServerName),