Add context.Context to the federation client (#225)

* Add context.Context to the federation client

* gb vendor update github.com/matrix-org/gomatrixserverlib
main
Mark Haines 2017-09-13 11:03:41 +01:00 committed by GitHub
parent 086683459f
commit 029e71828a
17 changed files with 139 additions and 72 deletions

View File

@ -66,7 +66,7 @@ func DirectoryRoom(
} }
} }
} else { } else {
resp, err = federation.LookupRoomAlias(domain, roomAlias) resp, err = federation.LookupRoomAlias(req.Context(), domain, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:

View File

@ -136,7 +136,7 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
func (r joinRoomReq) joinRoomByRemoteAlias( func (r joinRoomReq) joinRoomByRemoteAlias(
domain gomatrixserverlib.ServerName, roomAlias string, domain gomatrixserverlib.ServerName, roomAlias string,
) util.JSONResponse { ) util.JSONResponse {
resp, err := r.federation.LookupRoomAlias(domain, roomAlias) resp, err := r.federation.LookupRoomAlias(r.req.Context(), domain, roomAlias)
if err != nil { if err != nil {
switch x := err.(type) { switch x := err.(type) {
case gomatrix.HTTPError: case gomatrix.HTTPError:
@ -226,7 +226,7 @@ func (r joinRoomReq) joinRoomUsingServers(
// server was invalid this returns an error. // server was invalid this returns an error.
// Otherwise this returns a JSONResponse. // Otherwise this returns a JSONResponse.
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) { func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
respMakeJoin, err := r.federation.MakeJoin(server, roomID, r.userID) respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID)
if err != nil { if err != nil {
// TODO: Check if the user was not allowed to join the room. // TODO: Check if the user was not allowed to join the room.
return nil, err return nil, err
@ -246,12 +246,12 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
return &res, nil return &res, nil
} }
respSendJoin, err := r.federation.SendJoin(server, event) respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = respSendJoin.Check(r.keyRing, event); err != nil { if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil {
return nil, err return nil, err
} }

View File

@ -15,7 +15,9 @@
package keydb package keydb
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -41,6 +43,7 @@ func NewDatabase(dataSourceName string) (*Database, error) {
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys( func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return d.statements.bulkSelectServerKeys(requests) return d.statements.bulkSelectServerKeys(requests)
@ -48,6 +51,7 @@ func (d *Database) FetchKeys(
// StoreKeys implements gomatrixserverlib.KeyDatabase // StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys( func (d *Database) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
) error { ) error {
// TODO: Inserting all the keys within a single transaction may // TODO: Inserting all the keys within a single transaction may

View File

@ -76,7 +76,7 @@ func Invite(
Message: event.Redact().JSON(), Message: event.Redact().JSON(),
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
}} }}
verifyResults, err := keys.VerifyJSONs(verifyRequests) verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }

View File

@ -15,6 +15,7 @@
package writers package writers
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -41,6 +42,7 @@ func Send(
) util.JSONResponse { ) util.JSONResponse {
t := txnReq{ t := txnReq{
context: httpReq.Context(),
query: query, query: query,
producer: producer, producer: producer,
keys: keys, keys: keys,
@ -70,6 +72,7 @@ func Send(
type txnReq struct { type txnReq struct {
gomatrixserverlib.Transaction gomatrixserverlib.Transaction
context context.Context
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
producer *producers.RoomserverProducer producer *producers.RoomserverProducer
keys gomatrixserverlib.KeyRing keys gomatrixserverlib.KeyRing
@ -78,7 +81,7 @@ type txnReq struct {
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
// Check the event signatures // Check the event signatures
if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, t.keys); err != nil { if err := gomatrixserverlib.VerifyEventSignatures(t.context, t.PDUs, t.keys); err != nil {
return nil, err return nil, err
} }
@ -110,7 +113,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
// our server so we should bail processing the transaction entirely. // our server so we should bail processing the transaction entirely.
return nil, err return nil, err
} }
results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()} results[e.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
} else { } else {
results[e.EventID()] = gomatrixserverlib.PDUResult{} results[e.EventID()] = gomatrixserverlib.PDUResult{}
} }
@ -197,12 +202,12 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event) error {
// need to fallback to /state. // need to fallback to /state.
// TODO: Attempt to fill in the gap using /get_missing_events // TODO: Attempt to fill in the gap using /get_missing_events
// TODO: Attempt to fetch the state using /state_ids and /events // TODO: Attempt to fetch the state using /state_ids and /events
state, err := t.federation.LookupState(t.Origin, e.RoomID(), e.EventID()) state, err := t.federation.LookupState(t.context, t.Origin, e.RoomID(), e.EventID())
if err != nil { if err != nil {
return err return err
} }
// Check that the returned state is valid. // Check that the returned state is valid.
if err := state.Check(t.keys); err != nil { if err := state.Check(t.context, t.keys); err != nil {
return err return err
} }
// Check that the event is allowed by the state. // Check that the event is allowed by the state.

View File

@ -15,6 +15,7 @@
package writers package writers
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -63,7 +64,9 @@ func CreateInvitesFrom3PIDInvites(
evs := []gomatrixserverlib.Event{} evs := []gomatrixserverlib.Event{}
for _, inv := range body.Invites { for _, inv := range body.Invites {
event, err := createInviteFrom3PIDInvite(queryAPI, cfg, inv, federation) event, err := createInviteFrom3PIDInvite(
req.Context(), queryAPI, cfg, inv, federation,
)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -139,7 +142,7 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it // Ask the requesting server to sign the newly created event so we know it
// acknowledged it // acknowledged it
signedEvent, err := federation.SendInvite(request.Origin(), *event) signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), *event)
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }
@ -160,8 +163,8 @@ func ExchangeThirdPartyInvite(
// Returns an error if there was a problem building the event or fetching the // Returns an error if there was a problem building the event or fetching the
// necessary data to do so. // necessary data to do so.
func createInviteFrom3PIDInvite( func createInviteFrom3PIDInvite(
queryAPI api.RoomserverQueryAPI, cfg config.Dendrite, inv invite, ctx context.Context, queryAPI api.RoomserverQueryAPI, cfg config.Dendrite,
federation *gomatrixserverlib.FederationClient, inv invite, federation *gomatrixserverlib.FederationClient,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
// Build the event // Build the event
builder := &gomatrixserverlib.EventBuilder{ builder := &gomatrixserverlib.EventBuilder{
@ -185,7 +188,10 @@ func createInviteFrom3PIDInvite(
event, err := buildMembershipEvent(builder, queryAPI, cfg) event, err := buildMembershipEvent(builder, queryAPI, cfg)
if err == errNotInRoom { if err == errNotInRoom {
return nil, sendToRemoteServer(inv, federation, cfg, *builder) return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder)
}
if err != nil {
return nil, err
} }
return event, nil return event, nil
@ -253,7 +259,8 @@ func buildMembershipEvent(
// Returns an error if it couldn't get the server names to reach or if all of // Returns an error if it couldn't get the server names to reach or if all of
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
inv invite, federation *gomatrixserverlib.FederationClient, cfg config.Dendrite, ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)
@ -269,7 +276,7 @@ func sendToRemoteServer(
} }
for _, server := range remoteServers { for _, server := range remoteServers {
err = federation.ExchangeThirdPartyInvite(server, builder) err = federation.ExchangeThirdPartyInvite(ctx, server, builder)
if err == nil { if err == nil {
return return
} }

View File

@ -15,6 +15,7 @@
package queue package queue
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -65,7 +66,7 @@ func (oq *destinationQueue) backgroundSend() {
// TODO: handle retries. // TODO: handle retries.
// TODO: blacklist uncooperative servers. // TODO: blacklist uncooperative servers.
_, err := oq.client.SendTransaction(*t) _, err := oq.client.SendTransaction(context.TODO(), *t)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"destination": oq.destination, "destination": oq.destination,

2
vendor/manifest vendored
View File

@ -116,7 +116,7 @@
{ {
"importpath": "github.com/matrix-org/gomatrixserverlib", "importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "790f02e8f465552dab4317ffe7ca047ccb594cbf", "revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
"branch": "master" "branch": "master"
}, },
{ {

View File

@ -17,6 +17,7 @@ package gomatrixserverlib
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -103,7 +104,9 @@ func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// LookupUserInfo gets information about a user from a given matrix homeserver // LookupUserInfo gets information about a user from a given matrix homeserver
// using a bearer access token. // using a bearer access token.
func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserInfo, err error) { func (fc *Client) LookupUserInfo(
ctx context.Context, matrixServer ServerName, token string,
) (u UserInfo, err error) {
url := url.URL{ url := url.URL{
Scheme: "matrix", Scheme: "matrix",
Host: string(matrixServer), Host: string(matrixServer),
@ -111,8 +114,13 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
RawQuery: url.Values{"access_token": []string{token}}.Encode(), RawQuery: url.Values{"access_token": []string{token}}.Encode(),
} }
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return
}
var response *http.Response var response *http.Response
response, err = fc.client.Get(url.String()) response, err = fc.client.Do(req.WithContext(ctx))
if response != nil { if response != nil {
defer response.Body.Close() // nolint: errcheck defer response.Body.Close() // nolint: errcheck
} }
@ -153,7 +161,7 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
// copy of the keys. // copy of the keys.
// Returns the keys or an error if there was a problem talking to the server. // Returns the keys or an error if there was a problem talking to the server.
func (fc *Client) LookupServerKeys( // nolint: gocyclo func (fc *Client) LookupServerKeys( // nolint: gocyclo
matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp, ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) { ) (map[PublicKeyRequest]ServerKeys, error) {
url := url.URL{ url := url.URL{
Scheme: "matrix", Scheme: "matrix",
@ -183,7 +191,13 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
return nil, err return nil, err
} }
response, err := fc.client.Post(url.String(), "application/json", bytes.NewBuffer(requestBytes)) req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
response, err := fc.client.Do(req.WithContext(ctx))
if response != nil { if response != nil {
defer response.Body.Close() // nolint: errcheck defer response.Body.Close() // nolint: errcheck
} }

View File

@ -17,6 +17,7 @@ package gomatrixserverlib
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -188,7 +189,7 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
// VerifyEventSignatures checks that each event in a list of events has valid // VerifyEventSignatures checks that each event in a list of events has valid
// signatures from the server that sent it. // signatures from the server that sent it.
func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: gocyclo func VerifyEventSignatures(ctx context.Context, events []Event, keyRing KeyRing) error { // nolint: gocyclo
var toVerify []VerifyJSONRequest var toVerify []VerifyJSONRequest
for _, event := range events { for _, event := range events {
redactedJSON, err := redactEvent(event.eventJSON) redactedJSON, err := redactEvent(event.eventJSON)
@ -222,7 +223,7 @@ func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: g
} }
} }
results, err := keyRing.VerifyJSONs(toVerify) results, err := keyRing.VerifyJSONs(ctx, toVerify)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package gomatrixserverlib package gomatrixserverlib
import ( import (
"context"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -31,7 +32,7 @@ func NewFederationClient(
} }
} }
func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{}) error { func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil { if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
return err return err
} }
@ -41,7 +42,7 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
return err return err
} }
res, err := ac.client.Do(req) res, err := ac.client.Do(req.WithContext(ctx))
if res != nil { if res != nil {
defer res.Body.Close() // nolint: errcheck defer res.Body.Close() // nolint: errcheck
} }
@ -87,13 +88,15 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
var federationPathPrefix = "/_matrix/federation/v1" var federationPathPrefix = "/_matrix/federation/v1"
// SendTransaction sends a transaction // SendTransaction sends a transaction
func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err error) { func (ac *FederationClient) SendTransaction(
ctx context.Context, t Transaction,
) (res RespSend, err error) {
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/" path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
req := NewFederationRequest("PUT", t.Destination, path) req := NewFederationRequest("PUT", t.Destination, path)
if err = req.SetContent(t); err != nil { if err = req.SetContent(t); err != nil {
return return
} }
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
@ -106,12 +109,14 @@ func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err er
// If this successfully returns an acceptable event we will sign it with our // If this successfully returns an acceptable event we will sign it with our
// server's key and pass it to SendJoin. // server's key and pass it to SendJoin.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms // See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res RespMakeJoin, err error) { func (ac *FederationClient) MakeJoin(
ctx context.Context, s ServerName, roomID, userID string,
) (res RespMakeJoin, err error) {
path := federationPathPrefix + "/make_join/" + path := federationPathPrefix + "/make_join/" +
url.PathEscape(roomID) + "/" + url.PathEscape(roomID) + "/" +
url.PathEscape(userID) url.PathEscape(userID)
req := NewFederationRequest("GET", s, path) req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
@ -119,7 +124,9 @@ func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res R
// remote matrix server. // remote matrix server.
// This is used to join a room the local server isn't a member of. // This is used to join a room the local server isn't a member of.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms // See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoin, err error) { func (ac *FederationClient) SendJoin(
ctx context.Context, s ServerName, event Event,
) (res RespSendJoin, err error) {
path := federationPathPrefix + "/send_join/" + path := federationPathPrefix + "/send_join/" +
url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID()) url.PathEscape(event.EventID())
@ -127,13 +134,15 @@ func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoi
if err = req.SetContent(event); err != nil { if err = req.SetContent(event); err != nil {
return return
} }
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
// SendInvite sends an invite m.room.member event to an invited server to be // SendInvite sends an invite m.room.member event to an invited server to be
// signed by it. This is used to invite a user that is not on the local server. // signed by it. This is used to invite a user that is not on the local server.
func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvite, err error) { func (ac *FederationClient) SendInvite(
ctx context.Context, s ServerName, event Event,
) (res RespInvite, err error) {
path := federationPathPrefix + "/invite/" + path := federationPathPrefix + "/invite/" +
url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID()) url.PathEscape(event.EventID())
@ -141,7 +150,7 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
if err = req.SetContent(event); err != nil { if err = req.SetContent(event); err != nil {
return return
} }
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
@ -150,38 +159,44 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
// server. // server.
// This is used to exchange a m.room.third_party_invite event for a m.room.member // This is used to exchange a m.room.third_party_invite event for a m.room.member
// one in a room the local server isn't a member of. // one in a room the local server isn't a member of.
func (ac *FederationClient) ExchangeThirdPartyInvite(s ServerName, builder EventBuilder) (err error) { func (ac *FederationClient) ExchangeThirdPartyInvite(
ctx context.Context, s ServerName, builder EventBuilder,
) (err error) {
path := federationPathPrefix + "/exchange_third_party_invite/" + path := federationPathPrefix + "/exchange_third_party_invite/" +
url.PathEscape(builder.RoomID) url.PathEscape(builder.RoomID)
req := NewFederationRequest("PUT", s, path) req := NewFederationRequest("PUT", s, path)
if err = req.SetContent(builder); err != nil { if err = req.SetContent(builder); err != nil {
return return
} }
err = ac.doRequest(req, nil) err = ac.doRequest(ctx, req, nil)
return return
} }
// LookupState retrieves the room state for a room at an event from a // LookupState retrieves the room state for a room at an event from a
// remote matrix server as full matrix events. // remote matrix server as full matrix events.
func (ac *FederationClient) LookupState(s ServerName, roomID, eventID string) (res RespState, err error) { func (ac *FederationClient) LookupState(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespState, err error) {
path := federationPathPrefix + "/state/" + path := federationPathPrefix + "/state/" +
url.PathEscape(roomID) + url.PathEscape(roomID) +
"/?event_id=" + "/?event_id=" +
url.QueryEscape(eventID) url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path) req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
// LookupStateIDs retrieves the room state for a room at an event from a // LookupStateIDs retrieves the room state for a room at an event from a
// remote matrix server as lists of matrix event IDs. // remote matrix server as lists of matrix event IDs.
func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string) (res RespStateIDs, err error) { func (ac *FederationClient) LookupStateIDs(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespStateIDs, err error) {
path := federationPathPrefix + "/state_ids/" + path := federationPathPrefix + "/state_ids/" +
url.PathEscape(roomID) + url.PathEscape(roomID) +
"/?event_id=" + "/?event_id=" +
url.QueryEscape(eventID) url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path) req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }
@ -190,10 +205,12 @@ func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string)
// being looked up on. // being looked up on.
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError // If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
// is returned. // is returned.
func (ac *FederationClient) LookupRoomAlias(s ServerName, roomAlias string) (res RespDirectory, err error) { func (ac *FederationClient) LookupRoomAlias(
ctx context.Context, s ServerName, roomAlias string,
) (res RespDirectory, err error) {
path := federationPathPrefix + "/query/directory?room_alias=" + path := federationPathPrefix + "/query/directory?room_alias=" +
url.QueryEscape(roomAlias) url.QueryEscape(roomAlias)
req := NewFederationRequest("GET", s, path) req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res) err = ac.doRequest(ctx, req, &res)
return return
} }

View File

@ -1,6 +1,7 @@
package gomatrixserverlib package gomatrixserverlib
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
) )
@ -107,7 +108,7 @@ func (r RespState) Events() ([]Event, error) {
} }
// Check that a response to /state is valid. // Check that a response to /state is valid.
func (r RespState) Check(keyRing KeyRing) error { func (r RespState) Check(ctx context.Context, keyRing KeyRing) error {
var allEvents []Event var allEvents []Event
for _, event := range r.AuthEvents { for _, event := range r.AuthEvents {
if event.StateKey() == nil { if event.StateKey() == nil {
@ -133,7 +134,7 @@ func (r RespState) Check(keyRing KeyRing) error {
} }
// Check if the events pass signature checks. // Check if the events pass signature checks.
if err := VerifyEventSignatures(allEvents, keyRing); err != nil { if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
return nil return nil
} }
@ -213,11 +214,11 @@ type respSendJoinFields struct {
// Check that a response to /send_join is valid. // Check that a response to /send_join is valid.
// This checks that it would be valid as a response to /state // This checks that it would be valid as a response to /state
// This also checks that the join event is allowed by the state. // This also checks that the join event is allowed by the state.
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error { func (r RespSendJoin) Check(ctx context.Context, keyRing KeyRing, joinEvent Event) error {
// First check that the state is valid. // First check that the state is valid.
// The response to /send_join has the same data as a response to /state // The response to /send_join has the same data as a response to /state
// and the checks for a response to /state also apply. // and the checks for a response to /state also apply.
if err := RespState(r).Check(keyRing); err != nil { if err := RespState(r).Check(ctx, keyRing); err != nil {
return err return err
} }

View File

@ -6,13 +6,13 @@ echo "Installing lint search engine..."
go get github.com/alecthomas/gometalinter/ go get github.com/alecthomas/gometalinter/
gometalinter --config=linter.json --install --update gometalinter --config=linter.json --install --update
echo "Testing..."
go test
echo "Looking for lint..." echo "Looking for lint..."
gometalinter --config=linter.json gometalinter --config=linter.json
echo "Double checking spelling..." echo "Double checking spelling..."
misspell -error src *.md misspell -error src *.md
echo "Testing..."
go test
echo "Done!" echo "Done!"

View File

@ -1,6 +1,7 @@
package gomatrixserverlib package gomatrixserverlib
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -26,7 +27,7 @@ type KeyFetcher interface {
// The result may have fewer (server name, key ID) pairs than were in the request. // The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request. // The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys. // Returns an error if there was a problem fetching the keys.
FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
} }
// A KeyDatabase is a store for caching public keys. // A KeyDatabase is a store for caching public keys.
@ -39,7 +40,7 @@ type KeyDatabase interface {
// to a concurrent FetchKeys(). This is acceptable since the database is // to a concurrent FetchKeys(). This is acceptable since the database is
// only used as a cache for the keys, so if a FetchKeys() races with a // only used as a cache for the keys, so if a FetchKeys() races with a
// StoreKeys() and some of the keys are missing they will be just be refetched. // StoreKeys() and some of the keys are missing they will be just be refetched.
StoreKeys(map[PublicKeyRequest]ServerKeys) error StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
} }
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages. // A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
@ -73,7 +74,7 @@ type VerifyJSONResult struct {
// The caller should check the Result field for each entry to see if it was valid. // The caller should check the Result field for each entry to see if it was valid.
// Returns an error if there was a problem talking to the database or one of the other methods // Returns an error if there was a problem talking to the database or one of the other methods
// of fetching the public keys. // of fetching the public keys.
func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo func (k *KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
results := make([]VerifyJSONResult, len(requests)) results := make([]VerifyJSONResult, len(requests))
keyIDs := make([][]KeyID, len(requests)) keyIDs := make([][]KeyID, len(requests))
@ -109,7 +110,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
// This will happen if all the objects are missing supported signatures. // This will happen if all the objects are missing supported signatures.
return results, nil return results, nil
} }
keysFromDatabase, err := k.KeyDatabase.FetchKeys(keyRequests) keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,14 +125,14 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
} }
// TODO: Coalesce in-flight requests for the same keys. // TODO: Coalesce in-flight requests for the same keys.
// Otherwise we risk spamming the servers we query the keys from. // Otherwise we risk spamming the servers we query the keys from.
keysFetched, err := k.KeyFetchers[i].FetchKeys(keyRequests) keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
if err != nil { if err != nil {
return nil, err return nil, err
} }
k.checkUsingKeys(requests, results, keyIDs, keysFetched) k.checkUsingKeys(requests, results, keyIDs, keysFetched)
// Add the keys to the database so that we won't need to fetch them again. // Add the keys to the database so that we won't need to fetch them again.
if err := k.KeyDatabase.StoreKeys(keysFetched); err != nil { if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
return nil, err return nil, err
} }
} }
@ -143,7 +144,9 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
return strings.HasPrefix(string(keyID), "ed25519:") return strings.HasPrefix(string(keyID), "ed25519:")
} }
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp { func (k *KeyRing) publicKeyRequests(
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
) map[PublicKeyRequest]Timestamp {
keyRequests := map[PublicKeyRequest]Timestamp{} keyRequests := map[PublicKeyRequest]Timestamp{}
for i := range requests { for i := range requests {
if results[i].Error == nil { if results[i].Error == nil {
@ -218,8 +221,10 @@ type PerspectiveKeyFetcher struct {
} }
// FetchKeys implements KeyFetcher // FetchKeys implements KeyFetcher
func (p *PerspectiveKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) { func (p *PerspectiveKeyFetcher) FetchKeys(
results, err := p.Client.LookupServerKeys(p.PerspectiveServerName, requests) ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -269,7 +274,9 @@ type DirectKeyFetcher struct {
} }
// FetchKeys implements KeyFetcher // FetchKeys implements KeyFetcher
func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) { func (d *DirectKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{} byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
for req, ts := range requests { for req, ts := range requests {
server := byServer[req.ServerName] server := byServer[req.ServerName]
@ -283,7 +290,7 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
results := map[PublicKeyRequest]ServerKeys{} results := map[PublicKeyRequest]ServerKeys{}
for server, reqs := range byServer { for server, reqs := range byServer {
// TODO: make these requests in parallel // TODO: make these requests in parallel
serverResults, err := d.fetchKeysForServer(server, reqs) serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
if err != nil { if err != nil {
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map? // TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
return nil, err return nil, err
@ -296,9 +303,9 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
} }
func (d *DirectKeyFetcher) fetchKeysForServer( func (d *DirectKeyFetcher) fetchKeysForServer(
serverName ServerName, requests map[PublicKeyRequest]Timestamp, ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) { ) (map[PublicKeyRequest]ServerKeys, error) {
results, err := d.Client.LookupServerKeys(serverName, requests) results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package gomatrixserverlib package gomatrixserverlib
import ( import (
"context"
"encoding/json" "encoding/json"
"testing" "testing"
) )
@ -36,7 +37,9 @@ var testKeys = `{
type testKeyDatabase struct{} type testKeyDatabase struct{}
func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) { func (db *testKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results := map[PublicKeyRequest]ServerKeys{} results := map[PublicKeyRequest]ServerKeys{}
var keys ServerKeys var keys ServerKeys
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil { if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
@ -54,14 +57,16 @@ func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
return results, nil return results, nil
} }
func (db *testKeyDatabase) StoreKeys(requests map[PublicKeyRequest]ServerKeys) error { func (db *testKeyDatabase) StoreKeys(
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
) error {
return nil return nil
} }
func TestVerifyJSONsSuccess(t *testing.T) { func TestVerifyJSONsSuccess(t *testing.T) {
// Check that trying to verify the server key JSON works. // Check that trying to verify the server key JSON works.
k := KeyRing{nil, &testKeyDatabase{}} k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{ results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800", ServerName: "localhost:8800",
Message: []byte(testKeys), Message: []byte(testKeys),
AtTS: 1493142432964, AtTS: 1493142432964,
@ -77,7 +82,7 @@ func TestVerifyJSONsSuccess(t *testing.T) {
func TestVerifyJSONsUnknownServerFails(t *testing.T) { func TestVerifyJSONsUnknownServerFails(t *testing.T) {
// Check that trying to verify JSON for an unknown server fails. // Check that trying to verify JSON for an unknown server fails.
k := KeyRing{nil, &testKeyDatabase{}} k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{ results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800", ServerName: "unknown:8800",
Message: []byte(testKeys), Message: []byte(testKeys),
AtTS: 1493142432964, AtTS: 1493142432964,
@ -94,7 +99,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
// Check that trying to verify JSON from the distant future fails. // Check that trying to verify JSON from the distant future fails.
distantFuture := Timestamp(2000000000000) distantFuture := Timestamp(2000000000000)
k := KeyRing{nil, &testKeyDatabase{}} k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{ results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800", ServerName: "unknown:8800",
Message: []byte(testKeys), Message: []byte(testKeys),
AtTS: distantFuture, AtTS: distantFuture,
@ -110,7 +115,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
func TestVerifyJSONsFetcherError(t *testing.T) { func TestVerifyJSONsFetcherError(t *testing.T) {
// Check that if the database errors then the attempt to verify JSON fails. // Check that if the database errors then the attempt to verify JSON fails.
k := KeyRing{nil, &erroringKeyDatabase{}} k := KeyRing{nil, &erroringKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{ results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800", ServerName: "localhost:8800",
Message: []byte(testKeys), Message: []byte(testKeys),
AtTS: 1493142432964, AtTS: 1493142432964,
@ -129,10 +134,14 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
var testErrorFetch = erroringKeyDatabaseError(1) var testErrorFetch = erroringKeyDatabaseError(1)
var testErrorStore = erroringKeyDatabaseError(2) var testErrorStore = erroringKeyDatabaseError(2)
func (e *erroringKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) { func (e *erroringKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
return nil, &testErrorFetch return nil, &testErrorFetch
} }
func (e *erroringKeyDatabase) StoreKeys(keys map[PublicKeyRequest]ServerKeys) error { func (e *erroringKeyDatabase) StoreKeys(
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
) error {
return &testErrorStore return &testErrorStore
} }

View File

@ -1,4 +1,5 @@
{ {
"Deadline": "5m",
"Enable": [ "Enable": [
"vet", "vet",
"vetshadow", "vetshadow",

View File

@ -215,7 +215,7 @@ func VerifyHTTPRequest(
return nil, util.MessageResponse(401, message) return nil, util.MessageResponse(401, message)
} }
results, err := keys.VerifyJSONs([]VerifyJSONRequest{{ results, err := keys.VerifyJSONs(req.Context(), []VerifyJSONRequest{{
ServerName: request.Origin(), ServerName: request.Origin(),
AtTS: AsTimestamp(now), AtTS: AsTimestamp(now),
Message: toVerify, Message: toVerify,