Add context.Context to the federation client (#225)
* Add context.Context to the federation client * gb vendor update github.com/matrix-org/gomatrixserverlibmain
parent
086683459f
commit
029e71828a
|
@ -66,7 +66,7 @@ func DirectoryRoom(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
resp, err = federation.LookupRoomAlias(domain, roomAlias)
|
resp, err = federation.LookupRoomAlias(req.Context(), domain, roomAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch x := err.(type) {
|
switch x := err.(type) {
|
||||||
case gomatrix.HTTPError:
|
case gomatrix.HTTPError:
|
||||||
|
|
|
@ -136,7 +136,7 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
|
||||||
func (r joinRoomReq) joinRoomByRemoteAlias(
|
func (r joinRoomReq) joinRoomByRemoteAlias(
|
||||||
domain gomatrixserverlib.ServerName, roomAlias string,
|
domain gomatrixserverlib.ServerName, roomAlias string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
resp, err := r.federation.LookupRoomAlias(domain, roomAlias)
|
resp, err := r.federation.LookupRoomAlias(r.req.Context(), domain, roomAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch x := err.(type) {
|
switch x := err.(type) {
|
||||||
case gomatrix.HTTPError:
|
case gomatrix.HTTPError:
|
||||||
|
@ -226,7 +226,7 @@ func (r joinRoomReq) joinRoomUsingServers(
|
||||||
// server was invalid this returns an error.
|
// server was invalid this returns an error.
|
||||||
// Otherwise this returns a JSONResponse.
|
// Otherwise this returns a JSONResponse.
|
||||||
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
|
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
|
||||||
respMakeJoin, err := r.federation.MakeJoin(server, roomID, r.userID)
|
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: Check if the user was not allowed to join the room.
|
// TODO: Check if the user was not allowed to join the room.
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -246,12 +246,12 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
respSendJoin, err := r.federation.SendJoin(server, event)
|
respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = respSendJoin.Check(r.keyRing, event); err != nil {
|
if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
package keydb
|
package keydb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,6 +43,7 @@ func NewDatabase(dataSourceName string) (*Database, error) {
|
||||||
|
|
||||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||||
func (d *Database) FetchKeys(
|
func (d *Database) FetchKeys(
|
||||||
|
ctx context.Context,
|
||||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
||||||
return d.statements.bulkSelectServerKeys(requests)
|
return d.statements.bulkSelectServerKeys(requests)
|
||||||
|
@ -48,6 +51,7 @@ func (d *Database) FetchKeys(
|
||||||
|
|
||||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||||
func (d *Database) StoreKeys(
|
func (d *Database) StoreKeys(
|
||||||
|
ctx context.Context,
|
||||||
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
|
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
|
||||||
) error {
|
) error {
|
||||||
// TODO: Inserting all the keys within a single transaction may
|
// TODO: Inserting all the keys within a single transaction may
|
||||||
|
|
|
@ -76,7 +76,7 @@ func Invite(
|
||||||
Message: event.Redact().JSON(),
|
Message: event.Redact().JSON(),
|
||||||
AtTS: event.OriginServerTS(),
|
AtTS: event.OriginServerTS(),
|
||||||
}}
|
}}
|
||||||
verifyResults, err := keys.VerifyJSONs(verifyRequests)
|
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(httpReq, err)
|
return httputil.LogThenError(httpReq, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package writers
|
package writers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -41,6 +42,7 @@ func Send(
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
||||||
t := txnReq{
|
t := txnReq{
|
||||||
|
context: httpReq.Context(),
|
||||||
query: query,
|
query: query,
|
||||||
producer: producer,
|
producer: producer,
|
||||||
keys: keys,
|
keys: keys,
|
||||||
|
@ -70,6 +72,7 @@ func Send(
|
||||||
|
|
||||||
type txnReq struct {
|
type txnReq struct {
|
||||||
gomatrixserverlib.Transaction
|
gomatrixserverlib.Transaction
|
||||||
|
context context.Context
|
||||||
query api.RoomserverQueryAPI
|
query api.RoomserverQueryAPI
|
||||||
producer *producers.RoomserverProducer
|
producer *producers.RoomserverProducer
|
||||||
keys gomatrixserverlib.KeyRing
|
keys gomatrixserverlib.KeyRing
|
||||||
|
@ -78,7 +81,7 @@ type txnReq struct {
|
||||||
|
|
||||||
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
|
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
|
||||||
// Check the event signatures
|
// Check the event signatures
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, t.keys); err != nil {
|
if err := gomatrixserverlib.VerifyEventSignatures(t.context, t.PDUs, t.keys); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +113,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
|
||||||
// our server so we should bail processing the transaction entirely.
|
// our server so we should bail processing the transaction entirely.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()}
|
results[e.EventID()] = gomatrixserverlib.PDUResult{
|
||||||
|
Error: err.Error(),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
results[e.EventID()] = gomatrixserverlib.PDUResult{}
|
results[e.EventID()] = gomatrixserverlib.PDUResult{}
|
||||||
}
|
}
|
||||||
|
@ -197,12 +202,12 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event) error {
|
||||||
// need to fallback to /state.
|
// need to fallback to /state.
|
||||||
// TODO: Attempt to fill in the gap using /get_missing_events
|
// TODO: Attempt to fill in the gap using /get_missing_events
|
||||||
// TODO: Attempt to fetch the state using /state_ids and /events
|
// TODO: Attempt to fetch the state using /state_ids and /events
|
||||||
state, err := t.federation.LookupState(t.Origin, e.RoomID(), e.EventID())
|
state, err := t.federation.LookupState(t.context, t.Origin, e.RoomID(), e.EventID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Check that the returned state is valid.
|
// Check that the returned state is valid.
|
||||||
if err := state.Check(t.keys); err != nil {
|
if err := state.Check(t.context, t.keys); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Check that the event is allowed by the state.
|
// Check that the event is allowed by the state.
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package writers
|
package writers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -63,7 +64,9 @@ func CreateInvitesFrom3PIDInvites(
|
||||||
|
|
||||||
evs := []gomatrixserverlib.Event{}
|
evs := []gomatrixserverlib.Event{}
|
||||||
for _, inv := range body.Invites {
|
for _, inv := range body.Invites {
|
||||||
event, err := createInviteFrom3PIDInvite(queryAPI, cfg, inv, federation)
|
event, err := createInviteFrom3PIDInvite(
|
||||||
|
req.Context(), queryAPI, cfg, inv, federation,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -139,7 +142,7 @@ func ExchangeThirdPartyInvite(
|
||||||
|
|
||||||
// Ask the requesting server to sign the newly created event so we know it
|
// Ask the requesting server to sign the newly created event so we know it
|
||||||
// acknowledged it
|
// acknowledged it
|
||||||
signedEvent, err := federation.SendInvite(request.Origin(), *event)
|
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), *event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(httpReq, err)
|
return httputil.LogThenError(httpReq, err)
|
||||||
}
|
}
|
||||||
|
@ -160,8 +163,8 @@ func ExchangeThirdPartyInvite(
|
||||||
// Returns an error if there was a problem building the event or fetching the
|
// Returns an error if there was a problem building the event or fetching the
|
||||||
// necessary data to do so.
|
// necessary data to do so.
|
||||||
func createInviteFrom3PIDInvite(
|
func createInviteFrom3PIDInvite(
|
||||||
queryAPI api.RoomserverQueryAPI, cfg config.Dendrite, inv invite,
|
ctx context.Context, queryAPI api.RoomserverQueryAPI, cfg config.Dendrite,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
inv invite, federation *gomatrixserverlib.FederationClient,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
// Build the event
|
// Build the event
|
||||||
builder := &gomatrixserverlib.EventBuilder{
|
builder := &gomatrixserverlib.EventBuilder{
|
||||||
|
@ -185,7 +188,10 @@ func createInviteFrom3PIDInvite(
|
||||||
|
|
||||||
event, err := buildMembershipEvent(builder, queryAPI, cfg)
|
event, err := buildMembershipEvent(builder, queryAPI, cfg)
|
||||||
if err == errNotInRoom {
|
if err == errNotInRoom {
|
||||||
return nil, sendToRemoteServer(inv, federation, cfg, *builder)
|
return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return event, nil
|
return event, nil
|
||||||
|
@ -253,7 +259,8 @@ func buildMembershipEvent(
|
||||||
// Returns an error if it couldn't get the server names to reach or if all of
|
// Returns an error if it couldn't get the server names to reach or if all of
|
||||||
// them responded with an error.
|
// them responded with an error.
|
||||||
func sendToRemoteServer(
|
func sendToRemoteServer(
|
||||||
inv invite, federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
|
ctx context.Context, inv invite,
|
||||||
|
federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
|
||||||
builder gomatrixserverlib.EventBuilder,
|
builder gomatrixserverlib.EventBuilder,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
||||||
|
@ -269,7 +276,7 @@ func sendToRemoteServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, server := range remoteServers {
|
for _, server := range remoteServers {
|
||||||
err = federation.ExchangeThirdPartyInvite(server, builder)
|
err = federation.ExchangeThirdPartyInvite(ctx, server, builder)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package queue
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -65,7 +66,7 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// TODO: handle retries.
|
// TODO: handle retries.
|
||||||
// TODO: blacklist uncooperative servers.
|
// TODO: blacklist uncooperative servers.
|
||||||
|
|
||||||
_, err := oq.client.SendTransaction(*t)
|
_, err := oq.client.SendTransaction(context.TODO(), *t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"destination": oq.destination,
|
"destination": oq.destination,
|
||||||
|
|
|
@ -116,7 +116,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||||
"revision": "790f02e8f465552dab4317ffe7ca047ccb594cbf",
|
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -17,6 +17,7 @@ package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -103,7 +104,9 @@ func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
|
||||||
// LookupUserInfo gets information about a user from a given matrix homeserver
|
// LookupUserInfo gets information about a user from a given matrix homeserver
|
||||||
// using a bearer access token.
|
// using a bearer access token.
|
||||||
func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserInfo, err error) {
|
func (fc *Client) LookupUserInfo(
|
||||||
|
ctx context.Context, matrixServer ServerName, token string,
|
||||||
|
) (u UserInfo, err error) {
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "matrix",
|
Scheme: "matrix",
|
||||||
Host: string(matrixServer),
|
Host: string(matrixServer),
|
||||||
|
@ -111,8 +114,13 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
|
||||||
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
|
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", url.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
response, err = fc.client.Get(url.String())
|
response, err = fc.client.Do(req.WithContext(ctx))
|
||||||
if response != nil {
|
if response != nil {
|
||||||
defer response.Body.Close() // nolint: errcheck
|
defer response.Body.Close() // nolint: errcheck
|
||||||
}
|
}
|
||||||
|
@ -153,7 +161,7 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
|
||||||
// copy of the keys.
|
// copy of the keys.
|
||||||
// Returns the keys or an error if there was a problem talking to the server.
|
// Returns the keys or an error if there was a problem talking to the server.
|
||||||
func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
||||||
matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "matrix",
|
Scheme: "matrix",
|
||||||
|
@ -183,7 +191,13 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := fc.client.Post(url.String(), "application/json", bytes.NewBuffer(requestBytes))
|
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
|
response, err := fc.client.Do(req.WithContext(ctx))
|
||||||
if response != nil {
|
if response != nil {
|
||||||
defer response.Body.Close() // nolint: errcheck
|
defer response.Body.Close() // nolint: errcheck
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -188,7 +189,7 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
|
||||||
|
|
||||||
// VerifyEventSignatures checks that each event in a list of events has valid
|
// VerifyEventSignatures checks that each event in a list of events has valid
|
||||||
// signatures from the server that sent it.
|
// signatures from the server that sent it.
|
||||||
func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: gocyclo
|
func VerifyEventSignatures(ctx context.Context, events []Event, keyRing KeyRing) error { // nolint: gocyclo
|
||||||
var toVerify []VerifyJSONRequest
|
var toVerify []VerifyJSONRequest
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
redactedJSON, err := redactEvent(event.eventJSON)
|
redactedJSON, err := redactEvent(event.eventJSON)
|
||||||
|
@ -222,7 +223,7 @@ func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: g
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := keyRing.VerifyJSONs(toVerify)
|
results, err := keyRing.VerifyJSONs(ctx, toVerify)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gomatrixserverlib
|
package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -31,7 +32,7 @@ func NewFederationClient(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{}) error {
|
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
|
||||||
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
|
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -41,7 +42,7 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := ac.client.Do(req)
|
res, err := ac.client.Do(req.WithContext(ctx))
|
||||||
if res != nil {
|
if res != nil {
|
||||||
defer res.Body.Close() // nolint: errcheck
|
defer res.Body.Close() // nolint: errcheck
|
||||||
}
|
}
|
||||||
|
@ -87,13 +88,15 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
|
||||||
var federationPathPrefix = "/_matrix/federation/v1"
|
var federationPathPrefix = "/_matrix/federation/v1"
|
||||||
|
|
||||||
// SendTransaction sends a transaction
|
// SendTransaction sends a transaction
|
||||||
func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err error) {
|
func (ac *FederationClient) SendTransaction(
|
||||||
|
ctx context.Context, t Transaction,
|
||||||
|
) (res RespSend, err error) {
|
||||||
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
|
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
|
||||||
req := NewFederationRequest("PUT", t.Destination, path)
|
req := NewFederationRequest("PUT", t.Destination, path)
|
||||||
if err = req.SetContent(t); err != nil {
|
if err = req.SetContent(t); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,12 +109,14 @@ func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err er
|
||||||
// If this successfully returns an acceptable event we will sign it with our
|
// If this successfully returns an acceptable event we will sign it with our
|
||||||
// server's key and pass it to SendJoin.
|
// server's key and pass it to SendJoin.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
||||||
func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res RespMakeJoin, err error) {
|
func (ac *FederationClient) MakeJoin(
|
||||||
|
ctx context.Context, s ServerName, roomID, userID string,
|
||||||
|
) (res RespMakeJoin, err error) {
|
||||||
path := federationPathPrefix + "/make_join/" +
|
path := federationPathPrefix + "/make_join/" +
|
||||||
url.PathEscape(roomID) + "/" +
|
url.PathEscape(roomID) + "/" +
|
||||||
url.PathEscape(userID)
|
url.PathEscape(userID)
|
||||||
req := NewFederationRequest("GET", s, path)
|
req := NewFederationRequest("GET", s, path)
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +124,9 @@ func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res R
|
||||||
// remote matrix server.
|
// remote matrix server.
|
||||||
// This is used to join a room the local server isn't a member of.
|
// This is used to join a room the local server isn't a member of.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
||||||
func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoin, err error) {
|
func (ac *FederationClient) SendJoin(
|
||||||
|
ctx context.Context, s ServerName, event Event,
|
||||||
|
) (res RespSendJoin, err error) {
|
||||||
path := federationPathPrefix + "/send_join/" +
|
path := federationPathPrefix + "/send_join/" +
|
||||||
url.PathEscape(event.RoomID()) + "/" +
|
url.PathEscape(event.RoomID()) + "/" +
|
||||||
url.PathEscape(event.EventID())
|
url.PathEscape(event.EventID())
|
||||||
|
@ -127,13 +134,15 @@ func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoi
|
||||||
if err = req.SetContent(event); err != nil {
|
if err = req.SetContent(event); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendInvite sends an invite m.room.member event to an invited server to be
|
// SendInvite sends an invite m.room.member event to an invited server to be
|
||||||
// signed by it. This is used to invite a user that is not on the local server.
|
// signed by it. This is used to invite a user that is not on the local server.
|
||||||
func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvite, err error) {
|
func (ac *FederationClient) SendInvite(
|
||||||
|
ctx context.Context, s ServerName, event Event,
|
||||||
|
) (res RespInvite, err error) {
|
||||||
path := federationPathPrefix + "/invite/" +
|
path := federationPathPrefix + "/invite/" +
|
||||||
url.PathEscape(event.RoomID()) + "/" +
|
url.PathEscape(event.RoomID()) + "/" +
|
||||||
url.PathEscape(event.EventID())
|
url.PathEscape(event.EventID())
|
||||||
|
@ -141,7 +150,7 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
|
||||||
if err = req.SetContent(event); err != nil {
|
if err = req.SetContent(event); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,38 +159,44 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
|
||||||
// server.
|
// server.
|
||||||
// This is used to exchange a m.room.third_party_invite event for a m.room.member
|
// This is used to exchange a m.room.third_party_invite event for a m.room.member
|
||||||
// one in a room the local server isn't a member of.
|
// one in a room the local server isn't a member of.
|
||||||
func (ac *FederationClient) ExchangeThirdPartyInvite(s ServerName, builder EventBuilder) (err error) {
|
func (ac *FederationClient) ExchangeThirdPartyInvite(
|
||||||
|
ctx context.Context, s ServerName, builder EventBuilder,
|
||||||
|
) (err error) {
|
||||||
path := federationPathPrefix + "/exchange_third_party_invite/" +
|
path := federationPathPrefix + "/exchange_third_party_invite/" +
|
||||||
url.PathEscape(builder.RoomID)
|
url.PathEscape(builder.RoomID)
|
||||||
req := NewFederationRequest("PUT", s, path)
|
req := NewFederationRequest("PUT", s, path)
|
||||||
if err = req.SetContent(builder); err != nil {
|
if err = req.SetContent(builder); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ac.doRequest(req, nil)
|
err = ac.doRequest(ctx, req, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupState retrieves the room state for a room at an event from a
|
// LookupState retrieves the room state for a room at an event from a
|
||||||
// remote matrix server as full matrix events.
|
// remote matrix server as full matrix events.
|
||||||
func (ac *FederationClient) LookupState(s ServerName, roomID, eventID string) (res RespState, err error) {
|
func (ac *FederationClient) LookupState(
|
||||||
|
ctx context.Context, s ServerName, roomID, eventID string,
|
||||||
|
) (res RespState, err error) {
|
||||||
path := federationPathPrefix + "/state/" +
|
path := federationPathPrefix + "/state/" +
|
||||||
url.PathEscape(roomID) +
|
url.PathEscape(roomID) +
|
||||||
"/?event_id=" +
|
"/?event_id=" +
|
||||||
url.QueryEscape(eventID)
|
url.QueryEscape(eventID)
|
||||||
req := NewFederationRequest("GET", s, path)
|
req := NewFederationRequest("GET", s, path)
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupStateIDs retrieves the room state for a room at an event from a
|
// LookupStateIDs retrieves the room state for a room at an event from a
|
||||||
// remote matrix server as lists of matrix event IDs.
|
// remote matrix server as lists of matrix event IDs.
|
||||||
func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string) (res RespStateIDs, err error) {
|
func (ac *FederationClient) LookupStateIDs(
|
||||||
|
ctx context.Context, s ServerName, roomID, eventID string,
|
||||||
|
) (res RespStateIDs, err error) {
|
||||||
path := federationPathPrefix + "/state_ids/" +
|
path := federationPathPrefix + "/state_ids/" +
|
||||||
url.PathEscape(roomID) +
|
url.PathEscape(roomID) +
|
||||||
"/?event_id=" +
|
"/?event_id=" +
|
||||||
url.QueryEscape(eventID)
|
url.QueryEscape(eventID)
|
||||||
req := NewFederationRequest("GET", s, path)
|
req := NewFederationRequest("GET", s, path)
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,10 +205,12 @@ func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string)
|
||||||
// being looked up on.
|
// being looked up on.
|
||||||
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
|
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
|
||||||
// is returned.
|
// is returned.
|
||||||
func (ac *FederationClient) LookupRoomAlias(s ServerName, roomAlias string) (res RespDirectory, err error) {
|
func (ac *FederationClient) LookupRoomAlias(
|
||||||
|
ctx context.Context, s ServerName, roomAlias string,
|
||||||
|
) (res RespDirectory, err error) {
|
||||||
path := federationPathPrefix + "/query/directory?room_alias=" +
|
path := federationPathPrefix + "/query/directory?room_alias=" +
|
||||||
url.QueryEscape(roomAlias)
|
url.QueryEscape(roomAlias)
|
||||||
req := NewFederationRequest("GET", s, path)
|
req := NewFederationRequest("GET", s, path)
|
||||||
err = ac.doRequest(req, &res)
|
err = ac.doRequest(ctx, req, &res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gomatrixserverlib
|
package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
@ -107,7 +108,7 @@ func (r RespState) Events() ([]Event, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that a response to /state is valid.
|
// Check that a response to /state is valid.
|
||||||
func (r RespState) Check(keyRing KeyRing) error {
|
func (r RespState) Check(ctx context.Context, keyRing KeyRing) error {
|
||||||
var allEvents []Event
|
var allEvents []Event
|
||||||
for _, event := range r.AuthEvents {
|
for _, event := range r.AuthEvents {
|
||||||
if event.StateKey() == nil {
|
if event.StateKey() == nil {
|
||||||
|
@ -133,7 +134,7 @@ func (r RespState) Check(keyRing KeyRing) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the events pass signature checks.
|
// Check if the events pass signature checks.
|
||||||
if err := VerifyEventSignatures(allEvents, keyRing); err != nil {
|
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,11 +214,11 @@ type respSendJoinFields struct {
|
||||||
// Check that a response to /send_join is valid.
|
// Check that a response to /send_join is valid.
|
||||||
// This checks that it would be valid as a response to /state
|
// This checks that it would be valid as a response to /state
|
||||||
// This also checks that the join event is allowed by the state.
|
// This also checks that the join event is allowed by the state.
|
||||||
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error {
|
func (r RespSendJoin) Check(ctx context.Context, keyRing KeyRing, joinEvent Event) error {
|
||||||
// First check that the state is valid.
|
// First check that the state is valid.
|
||||||
// The response to /send_join has the same data as a response to /state
|
// The response to /send_join has the same data as a response to /state
|
||||||
// and the checks for a response to /state also apply.
|
// and the checks for a response to /state also apply.
|
||||||
if err := RespState(r).Check(keyRing); err != nil {
|
if err := RespState(r).Check(ctx, keyRing); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,13 +6,13 @@ echo "Installing lint search engine..."
|
||||||
go get github.com/alecthomas/gometalinter/
|
go get github.com/alecthomas/gometalinter/
|
||||||
gometalinter --config=linter.json --install --update
|
gometalinter --config=linter.json --install --update
|
||||||
|
|
||||||
|
echo "Testing..."
|
||||||
|
go test
|
||||||
|
|
||||||
echo "Looking for lint..."
|
echo "Looking for lint..."
|
||||||
gometalinter --config=linter.json
|
gometalinter --config=linter.json
|
||||||
|
|
||||||
echo "Double checking spelling..."
|
echo "Double checking spelling..."
|
||||||
misspell -error src *.md
|
misspell -error src *.md
|
||||||
|
|
||||||
echo "Testing..."
|
|
||||||
go test
|
|
||||||
|
|
||||||
echo "Done!"
|
echo "Done!"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gomatrixserverlib
|
package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -26,7 +27,7 @@ type KeyFetcher interface {
|
||||||
// The result may have fewer (server name, key ID) pairs than were in the request.
|
// The result may have fewer (server name, key ID) pairs than were in the request.
|
||||||
// The result may have more (server name, key ID) pairs than were in the request.
|
// The result may have more (server name, key ID) pairs than were in the request.
|
||||||
// Returns an error if there was a problem fetching the keys.
|
// Returns an error if there was a problem fetching the keys.
|
||||||
FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
|
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A KeyDatabase is a store for caching public keys.
|
// A KeyDatabase is a store for caching public keys.
|
||||||
|
@ -39,7 +40,7 @@ type KeyDatabase interface {
|
||||||
// to a concurrent FetchKeys(). This is acceptable since the database is
|
// to a concurrent FetchKeys(). This is acceptable since the database is
|
||||||
// only used as a cache for the keys, so if a FetchKeys() races with a
|
// only used as a cache for the keys, so if a FetchKeys() races with a
|
||||||
// StoreKeys() and some of the keys are missing they will be just be refetched.
|
// StoreKeys() and some of the keys are missing they will be just be refetched.
|
||||||
StoreKeys(map[PublicKeyRequest]ServerKeys) error
|
StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
|
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
|
||||||
|
@ -73,7 +74,7 @@ type VerifyJSONResult struct {
|
||||||
// The caller should check the Result field for each entry to see if it was valid.
|
// The caller should check the Result field for each entry to see if it was valid.
|
||||||
// Returns an error if there was a problem talking to the database or one of the other methods
|
// Returns an error if there was a problem talking to the database or one of the other methods
|
||||||
// of fetching the public keys.
|
// of fetching the public keys.
|
||||||
func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
|
func (k *KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
|
||||||
results := make([]VerifyJSONResult, len(requests))
|
results := make([]VerifyJSONResult, len(requests))
|
||||||
keyIDs := make([][]KeyID, len(requests))
|
keyIDs := make([][]KeyID, len(requests))
|
||||||
|
|
||||||
|
@ -109,7 +110,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
||||||
// This will happen if all the objects are missing supported signatures.
|
// This will happen if all the objects are missing supported signatures.
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
keysFromDatabase, err := k.KeyDatabase.FetchKeys(keyRequests)
|
keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -124,14 +125,14 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
||||||
}
|
}
|
||||||
// TODO: Coalesce in-flight requests for the same keys.
|
// TODO: Coalesce in-flight requests for the same keys.
|
||||||
// Otherwise we risk spamming the servers we query the keys from.
|
// Otherwise we risk spamming the servers we query the keys from.
|
||||||
keysFetched, err := k.KeyFetchers[i].FetchKeys(keyRequests)
|
keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
|
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
|
||||||
|
|
||||||
// Add the keys to the database so that we won't need to fetch them again.
|
// Add the keys to the database so that we won't need to fetch them again.
|
||||||
if err := k.KeyDatabase.StoreKeys(keysFetched); err != nil {
|
if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,7 +144,9 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
|
||||||
return strings.HasPrefix(string(keyID), "ed25519:")
|
return strings.HasPrefix(string(keyID), "ed25519:")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
|
func (k *KeyRing) publicKeyRequests(
|
||||||
|
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
|
||||||
|
) map[PublicKeyRequest]Timestamp {
|
||||||
keyRequests := map[PublicKeyRequest]Timestamp{}
|
keyRequests := map[PublicKeyRequest]Timestamp{}
|
||||||
for i := range requests {
|
for i := range requests {
|
||||||
if results[i].Error == nil {
|
if results[i].Error == nil {
|
||||||
|
@ -218,8 +221,10 @@ type PerspectiveKeyFetcher struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchKeys implements KeyFetcher
|
// FetchKeys implements KeyFetcher
|
||||||
func (p *PerspectiveKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
func (p *PerspectiveKeyFetcher) FetchKeys(
|
||||||
results, err := p.Client.LookupServerKeys(p.PerspectiveServerName, requests)
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
|
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -269,7 +274,9 @@ type DirectKeyFetcher struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchKeys implements KeyFetcher
|
// FetchKeys implements KeyFetcher
|
||||||
func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
func (d *DirectKeyFetcher) FetchKeys(
|
||||||
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
|
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
|
||||||
for req, ts := range requests {
|
for req, ts := range requests {
|
||||||
server := byServer[req.ServerName]
|
server := byServer[req.ServerName]
|
||||||
|
@ -283,7 +290,7 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||||
results := map[PublicKeyRequest]ServerKeys{}
|
results := map[PublicKeyRequest]ServerKeys{}
|
||||||
for server, reqs := range byServer {
|
for server, reqs := range byServer {
|
||||||
// TODO: make these requests in parallel
|
// TODO: make these requests in parallel
|
||||||
serverResults, err := d.fetchKeysForServer(server, reqs)
|
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
|
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -296,9 +303,9 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DirectKeyFetcher) fetchKeysForServer(
|
func (d *DirectKeyFetcher) fetchKeysForServer(
|
||||||
serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
results, err := d.Client.LookupServerKeys(serverName, requests)
|
results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gomatrixserverlib
|
package gomatrixserverlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -36,7 +37,9 @@ var testKeys = `{
|
||||||
|
|
||||||
type testKeyDatabase struct{}
|
type testKeyDatabase struct{}
|
||||||
|
|
||||||
func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
func (db *testKeyDatabase) FetchKeys(
|
||||||
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
results := map[PublicKeyRequest]ServerKeys{}
|
results := map[PublicKeyRequest]ServerKeys{}
|
||||||
var keys ServerKeys
|
var keys ServerKeys
|
||||||
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
|
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
|
||||||
|
@ -54,14 +57,16 @@ func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *testKeyDatabase) StoreKeys(requests map[PublicKeyRequest]ServerKeys) error {
|
func (db *testKeyDatabase) StoreKeys(
|
||||||
|
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
|
||||||
|
) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyJSONsSuccess(t *testing.T) {
|
func TestVerifyJSONsSuccess(t *testing.T) {
|
||||||
// Check that trying to verify the server key JSON works.
|
// Check that trying to verify the server key JSON works.
|
||||||
k := KeyRing{nil, &testKeyDatabase{}}
|
k := KeyRing{nil, &testKeyDatabase{}}
|
||||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||||
ServerName: "localhost:8800",
|
ServerName: "localhost:8800",
|
||||||
Message: []byte(testKeys),
|
Message: []byte(testKeys),
|
||||||
AtTS: 1493142432964,
|
AtTS: 1493142432964,
|
||||||
|
@ -77,7 +82,7 @@ func TestVerifyJSONsSuccess(t *testing.T) {
|
||||||
func TestVerifyJSONsUnknownServerFails(t *testing.T) {
|
func TestVerifyJSONsUnknownServerFails(t *testing.T) {
|
||||||
// Check that trying to verify JSON for an unknown server fails.
|
// Check that trying to verify JSON for an unknown server fails.
|
||||||
k := KeyRing{nil, &testKeyDatabase{}}
|
k := KeyRing{nil, &testKeyDatabase{}}
|
||||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||||
ServerName: "unknown:8800",
|
ServerName: "unknown:8800",
|
||||||
Message: []byte(testKeys),
|
Message: []byte(testKeys),
|
||||||
AtTS: 1493142432964,
|
AtTS: 1493142432964,
|
||||||
|
@ -94,7 +99,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
|
||||||
// Check that trying to verify JSON from the distant future fails.
|
// Check that trying to verify JSON from the distant future fails.
|
||||||
distantFuture := Timestamp(2000000000000)
|
distantFuture := Timestamp(2000000000000)
|
||||||
k := KeyRing{nil, &testKeyDatabase{}}
|
k := KeyRing{nil, &testKeyDatabase{}}
|
||||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||||
ServerName: "unknown:8800",
|
ServerName: "unknown:8800",
|
||||||
Message: []byte(testKeys),
|
Message: []byte(testKeys),
|
||||||
AtTS: distantFuture,
|
AtTS: distantFuture,
|
||||||
|
@ -110,7 +115,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
|
||||||
func TestVerifyJSONsFetcherError(t *testing.T) {
|
func TestVerifyJSONsFetcherError(t *testing.T) {
|
||||||
// Check that if the database errors then the attempt to verify JSON fails.
|
// Check that if the database errors then the attempt to verify JSON fails.
|
||||||
k := KeyRing{nil, &erroringKeyDatabase{}}
|
k := KeyRing{nil, &erroringKeyDatabase{}}
|
||||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||||
ServerName: "localhost:8800",
|
ServerName: "localhost:8800",
|
||||||
Message: []byte(testKeys),
|
Message: []byte(testKeys),
|
||||||
AtTS: 1493142432964,
|
AtTS: 1493142432964,
|
||||||
|
@ -129,10 +134,14 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
|
||||||
var testErrorFetch = erroringKeyDatabaseError(1)
|
var testErrorFetch = erroringKeyDatabaseError(1)
|
||||||
var testErrorStore = erroringKeyDatabaseError(2)
|
var testErrorStore = erroringKeyDatabaseError(2)
|
||||||
|
|
||||||
func (e *erroringKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
func (e *erroringKeyDatabase) FetchKeys(
|
||||||
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
|
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||||
return nil, &testErrorFetch
|
return nil, &testErrorFetch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *erroringKeyDatabase) StoreKeys(keys map[PublicKeyRequest]ServerKeys) error {
|
func (e *erroringKeyDatabase) StoreKeys(
|
||||||
|
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
|
||||||
|
) error {
|
||||||
return &testErrorStore
|
return &testErrorStore
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
{
|
{
|
||||||
|
"Deadline": "5m",
|
||||||
"Enable": [
|
"Enable": [
|
||||||
"vet",
|
"vet",
|
||||||
"vetshadow",
|
"vetshadow",
|
||||||
|
|
|
@ -215,7 +215,7 @@ func VerifyHTTPRequest(
|
||||||
return nil, util.MessageResponse(401, message)
|
return nil, util.MessageResponse(401, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := keys.VerifyJSONs([]VerifyJSONRequest{{
|
results, err := keys.VerifyJSONs(req.Context(), []VerifyJSONRequest{{
|
||||||
ServerName: request.Origin(),
|
ServerName: request.Origin(),
|
||||||
AtTS: AsTimestamp(now),
|
AtTS: AsTimestamp(now),
|
||||||
Message: toVerify,
|
Message: toVerify,
|
||||||
|
|
Loading…
Reference in New Issue