Add context to the server key database (#248)

main
Mark Haines 2017-09-21 16:16:02 +01:00 committed by GitHub
parent 7596c19f3a
commit fef290c47e
2 changed files with 16 additions and 7 deletions

View File

@ -49,7 +49,7 @@ func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return d.statements.bulkSelectServerKeys(requests)
return d.statements.bulkSelectServerKeys(ctx, requests)
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
@ -62,7 +62,7 @@ func (d *Database) StoreKeys(
// high for a single insert statement.
var lastErr error
for request, keys := range keyMap {
if err := d.statements.upsertServerKeys(request, keys); err != nil {
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the
// remaining keys.
// Since we are inserting the keys outside of a transaction it is

View File

@ -15,6 +15,7 @@
package keydb
import (
"context"
"database/sql"
"encoding/json"
@ -73,13 +74,15 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
}
func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
var nameAndKeyIDs []string
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs))
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil {
return nil, err
}
@ -106,15 +109,21 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
}
func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys,
ctx context.Context,
request gomatrixserverlib.PublicKeyRequest,
keys gomatrixserverlib.ServerKeys,
) error {
keyJSON, err := json.Marshal(keys)
if err != nil {
return err
}
_, err = s.upsertServerKeysStmt.Exec(
string(request.ServerName), string(request.KeyID), nameAndKeyID(request),
int64(keys.ValidUntilTS), keyJSON,
_, err = s.upsertServerKeysStmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
int64(keys.ValidUntilTS),
keyJSON,
)
return err
}