diff --git a/src/github.com/matrix-org/dendrite/roomserver/api/query.go b/src/github.com/matrix-org/dendrite/roomserver/api/query.go new file mode 100644 index 00000000..a3a647e6 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/api/query.go @@ -0,0 +1,102 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/matrix-org/gomatrixserverlib" + "net/http" +) + +// StateKeyTuple is a pair of an event type and state_key. +// This is used when requesting parts of the state of a room. +type StateKeyTuple struct { + // The "type" key + EventType string + // The "state_key" of a matrix event. + // The empty string is a legitimate value for the "state_key" in matrix + // so take care to initialise this field lest you accidentally request a + // "state_key" with the go default of the empty string. + EventStateKey string +} + +// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState +type QueryLatestEventsAndStateRequest struct { + // The roomID to query the latest events for. + RoomID string + // The state key tuples to fetch from the room current state. + // If this list is empty or nil then no state events are returned. + StateToFetch []StateKeyTuple +} + +// QueryLatestEventsAndStateResponse is a response to QueryLatestEventsAndState +type QueryLatestEventsAndStateResponse struct { + // Copy of the request for debugging. + QueryLatestEventsAndStateRequest + // Does the room exist? + // If the room doesn't exist this will be false and LatestEvents will be empty. + RoomExists bool + // The latest events in the room. + LatestEvents []gomatrixserverlib.EventReference + // The state events requested. + StateEvents []gomatrixserverlib.Event +} + +// RoomserverQueryAPI is used to query information from the room server. +type RoomserverQueryAPI interface { + // Query the latest events and state for a room from the room server. + QueryLatestEventsAndState( + request *QueryLatestEventsAndStateRequest, + response *QueryLatestEventsAndStateResponse, + ) error +} + +// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. +const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/QueryLatestEventsAndState" + +// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. +// If httpClient is nil then it uses the http.DefaultClient +func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpRoomserverQueryAPI{roomserverURL, *httpClient} +} + +type httpRoomserverQueryAPI struct { + roomserverURL string + httpClient http.Client +} + +// QueryLatestEventsAndState implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState( + request *QueryLatestEventsAndStateRequest, + response *QueryLatestEventsAndStateResponse, +) error { + apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath + return postJSON(h.httpClient, apiURL, request, response) +} + +func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error { + jsonBytes, err := json.Marshal(request) + if err != nil { + return err + } + res, err := httpClient.Post(apiURL, "application/json", bytes.NewReader(jsonBytes)) + if res != nil { + defer res.Body.Close() + } + if err != nil { + return err + } + if res.StatusCode != 200 { + var errorBody struct { + Message string `json:"message"` + } + if err = json.NewDecoder(res.Body).Decode(&errorBody); err != nil { + return err + } + return fmt.Errorf("api: %d: %s", res.StatusCode, errorBody.Message) + } + return json.NewDecoder(res.Body).Decode(response) +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/query/query.go b/src/github.com/matrix-org/dendrite/roomserver/query/query.go new file mode 100644 index 00000000..0be51886 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/query/query.go @@ -0,0 +1,68 @@ +package query + +import ( + "encoding/json" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "net/http" +) + +// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. +type RoomserverQueryAPIDatabase interface { + // Lookup the numeric ID for the room. + // Returns 0 if the room doesn't exists. + // Returns an error if there was a problem talking to the database. + RoomNID(roomID string) (types.RoomNID, error) + // Lookup event references for the latest events in the room. + // Returns an error if there was a problem talking to the database. + LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) +} + +// RoomserverQueryAPI is an implementation of RoomserverQueryAPI +type RoomserverQueryAPI struct { + DB RoomserverQueryAPIDatabase +} + +// QueryLatestEventsAndState implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryLatestEventsAndState( + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) (err error) { + response.QueryLatestEventsAndStateRequest = *request + roomNID, err := r.DB.RoomNID(request.RoomID) + if err != nil { + return err + } + if roomNID == 0 { + return nil + } + response.RoomExists = true + response.LatestEvents, err = r.DB.LatestEventIDs(roomNID) + // TODO: look up the current state. + return err +} + +// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. +func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { + servMux.Handle( + api.RoomserverQueryLatestEventsAndStatePath, + makeAPI("query_latest_events_and_state", func(req *http.Request) util.JSONResponse { + var request api.QueryLatestEventsAndStateRequest + var response api.QueryLatestEventsAndStateResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryLatestEventsAndState(&request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: 200, JSON: &response} + }), + ) +} + +func makeAPI(metric string, apiFunc func(req *http.Request) util.JSONResponse) http.Handler { + return prometheus.InstrumentHandler(metric, util.MakeJSONAPI(util.NewJSONRequestHandler(apiFunc))) +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go index 06f22a2a..b8766ca8 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "os" "os/exec" @@ -17,8 +18,12 @@ var ( zookeeperURI = defaulting(os.Getenv("ZOOKEEPER_URI"), "localhost:2181") // The URI the kafka server is listening on. kafkaURI = defaulting(os.Getenv("KAFKA_URIS"), "localhost:9092") + // The address the roomserver should listen on. + roomserverAddr = defaulting(os.Getenv("ROOMSERVER_URI"), "localhost:9876") // How long to wait for the roomserver to write the expected output messages. - timeoutString = defaulting(os.Getenv("TIMEOUT"), "10s") + // This needs to be high enough to account for the time it takes to create + // the postgres database tables which can take a while on travis. + timeoutString = defaulting(os.Getenv("TIMEOUT"), "60s") // The name of maintenance database to connect to in order to create the test database. postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres") // The name of the test database to create. @@ -91,7 +96,7 @@ func writeToTopic(topic string, data []string) error { // messages is reached or after a timeout. It kills the command before it returns. // It returns a list of the messages read from the command on success or an error // on failure. -func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int) ([]string, error) { +func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int, checkQueryAPI func()) ([]string, error) { type result struct { // data holds all of stdout on success. data []byte @@ -111,7 +116,17 @@ func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int) ([]string, e // Run the command, read the messages and wait for a timeout in parallel. go func() { // Read all of stdout. + defer func() { + if err := recover(); err != nil { + if errv, ok := err.(error); ok { + done <- result{nil, errv} + } else { + panic(err) + } + } + }() data, err := readCmd.Output() + checkQueryAPI() done <- result{data, err} }() go func() { @@ -157,7 +172,16 @@ func deleteTopic(topic string) error { return cmd.Run() } -func testRoomServer(input []string, wantOutput []string) { +// testRoomserver is used to run integration tests against a single roomserver. +// It creates new kafka topics for the input and output of the roomserver. +// It writes the input messages to the input kafka topic, formatting each message +// as canonical JSON so that it fits on a single line. +// It then runs the roomserver and waits for a number of messages to be written +// to the output topic. +// Once those messages have been written it runs the checkQueries function passing +// a api.RoomserverQueryAPI client. The caller can use this function to check the +// behaviour of the query API. +func testRoomserver(input []string, wantOutput []string, checkQueries func(api.RoomserverQueryAPI)) { const ( inputTopic = "roomserverInput" outputTopic = "roomserverOutput" @@ -191,10 +215,14 @@ func testRoomServer(input []string, wantOutput []string) { fmt.Sprintf("KAFKA_URIS=%s", kafkaURI), fmt.Sprintf("TOPIC_INPUT_ROOM_EVENT=%s", inputTopic), fmt.Sprintf("TOPIC_OUTPUT_ROOM_EVENT=%s", outputTopic), + fmt.Sprintf("BIND_ADDRESS=%s", roomserverAddr), ) cmd.Stderr = os.Stderr - gotOutput, err := runAndReadFromTopic(cmd, outputTopic, 1) + gotOutput, err := runAndReadFromTopic(cmd, outputTopic, len(wantOutput), func() { + queryAPI := api.NewRoomserverQueryAPIHTTP("http://"+roomserverAddr, nil) + checkQueries(queryAPI) + }) if err != nil { panic(err) } @@ -334,7 +362,21 @@ func main() { }`, } - testRoomServer(input, want) + testRoomserver(input, want, func(q api.RoomserverQueryAPI) { + var response api.QueryLatestEventsAndStateResponse + if err := q.QueryLatestEventsAndState( + &api.QueryLatestEventsAndStateRequest{RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org"}, + &response, + ); err != nil { + panic(err) + } + if !response.RoomExists { + panic(fmt.Errorf(`Wanted room "!HCXfdvrfksxuYnIFiJ:matrix.org" to exist`)) + } + if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" { + panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents)) + } + }) fmt.Println("==PASSED==", os.Args[0]) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go index d2f126bf..689fb48d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go @@ -3,8 +3,10 @@ package main import ( "fmt" "github.com/matrix-org/dendrite/roomserver/input" + "github.com/matrix-org/dendrite/roomserver/query" "github.com/matrix-org/dendrite/roomserver/storage" sarama "gopkg.in/Shopify/sarama.v1" + "net/http" "os" "strings" ) @@ -14,6 +16,7 @@ var ( kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",") inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT") outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT") + bindAddr = os.Getenv("BIND_ADDRESS") ) func main() { @@ -44,9 +47,14 @@ func main() { panic(err) } + queryAPI := query.RoomserverQueryAPI{ + DB: db, + } + + queryAPI.SetupHTTP(http.DefaultServeMux) + fmt.Println("Started roomserver") - // Wait forever. // TODO: Implement clean shutdown. - select {} + http.ListenAndServe(bindAddr, nil) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go index 6296d706..f7c8052c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_json_table.go @@ -2,7 +2,6 @@ package storage import ( "database/sql" - "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -65,11 +64,7 @@ type eventJSONPair struct { } func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) { - nids := make([]int64, len(eventNIDs)) - for i := range eventNIDs { - nids[i] = int64(eventNIDs[i]) - } - rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids)) + rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index 2a2f7e20..471bdcfb 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) const eventsSchema = ` @@ -83,6 +84,9 @@ const bulkSelectStateAtEventAndReferenceSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + " FROM events WHERE event_nid = ANY($1)" +const bulkSelectEventReferenceSQL = "" + + "SELECT event_id, reference_sha256 FROM events WHERE event_nid = ANY($1)" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -93,6 +97,7 @@ type eventStatements struct { updateEventSentToOutputStmt *sql.Stmt selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -127,6 +132,9 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { if s.bulkSelectStateAtEventAndReferenceStmt, err = db.Prepare(bulkSelectStateAtEventAndReferenceSQL); err != nil { return } + if s.bulkSelectEventReferenceStmt, err = db.Prepare(bulkSelectEventReferenceSQL); err != nil { + return + } return } @@ -136,15 +144,11 @@ func (s *eventStatements) insertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, ) (types.EventNID, types.StateSnapshotNID, error) { - nids := make([]int64, len(authEventNIDs)) - for i := range authEventNIDs { - nids[i] = int64(authEventNIDs[i]) - } var eventNID int64 var stateNID int64 err := s.insertEventStmt.QueryRow( int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, - pq.Int64Array(nids), + eventNIDsAsArray(authEventNIDs), ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -238,11 +242,7 @@ func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (e } func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { - nids := make([]int64, len(eventNIDs)) - for i := range eventNIDs { - nids[i] = int64(eventNIDs[i]) - } - rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(pq.Int64Array(nids)) + rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -276,3 +276,31 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN } return results, nil } + +func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) { + rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + return nids +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index ff932344..a6be8fcb 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -32,16 +32,20 @@ const selectRoomNIDSQL = "" + "SELECT room_nid FROM rooms WHERE room_id = $1" const selectLatestEventNIDsSQL = "" + + "SELECT latest_event_nids FROM rooms WHERE room_nid = $1" + +const selectLatestEventNIDsForUpdateSQL = "" + "SELECT latest_event_nids, last_event_sent_nid FROM rooms WHERE room_nid = $1 FOR UPDATE" const updateLatestEventNIDsSQL = "" + "UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3 WHERE room_nid = $1" type roomStatements struct { - insertRoomNIDStmt *sql.Stmt - selectRoomNIDStmt *sql.Stmt - selectLatestEventNIDsStmt *sql.Stmt - updateLatestEventNIDsStmt *sql.Stmt + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + selectLatestEventNIDsStmt *sql.Stmt + selectLatestEventNIDsForUpdateStmt *sql.Stmt + updateLatestEventNIDsStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -58,6 +62,9 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { if s.selectLatestEventNIDsStmt, err = db.Prepare(selectLatestEventNIDsSQL); err != nil { return } + if s.selectLatestEventNIDsForUpdateStmt, err = db.Prepare(selectLatestEventNIDsForUpdateSQL); err != nil { + return + } if s.updateLatestEventNIDsStmt, err = db.Prepare(updateLatestEventNIDsSQL); err != nil { return } @@ -76,10 +83,23 @@ func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) { return types.RoomNID(roomNID), err } +func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, error) { + var nids pq.Int64Array + err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids) + if err != nil { + return nil, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, nil +} + func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, error) { var nids pq.Int64Array var lastEventSentNID int64 - err := txn.Stmt(s.selectLatestEventNIDsStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID) + err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID) if err != nil { return nil, 0, err } @@ -91,10 +111,6 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty } func (s *roomStatements) updateLatestEventNIDs(txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID) error { - nids := make([]int64, len(eventNIDs)) - for i := range eventNIDs { - nids[i] = int64(eventNIDs[i]) - } - _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, pq.Int64Array(nids), int64(lastEventSentNID)) + _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID)) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 9db84085..bede936a 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -280,3 +280,21 @@ func (u *roomRecentEventsUpdater) Commit() error { func (u *roomRecentEventsUpdater) Rollback() error { return u.txn.Rollback() } + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { + roomNID, err := d.statements.selectRoomNID(roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +// LatestEventIDs implements query.RoomserverQueryAPIDB +func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) { + eventNIDs, err := d.statements.selectLatestEventNIDs(roomNID) + if err != nil { + return nil, err + } + return d.statements.bulkSelectEventReference(eventNIDs) +} diff --git a/vendor/manifest b/vendor/manifest index 79bac494..35065c88 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -98,7 +98,7 @@ { "importpath": "github.com/matrix-org/util", "repository": "https://github.com/matrix-org/util", - "revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828", + "revision": "28bd7491c8aafbf346ca23821664f0f9911ef52b", "branch": "master" }, { @@ -206,4 +206,4 @@ "branch": "master" } ] -} +} \ No newline at end of file diff --git a/vendor/src/github.com/matrix-org/util/context.go b/vendor/src/github.com/matrix-org/util/context.go index d8def4f9..f2477a56 100644 --- a/vendor/src/github.com/matrix-org/util/context.go +++ b/vendor/src/github.com/matrix-org/util/context.go @@ -25,11 +25,13 @@ func GetRequestID(ctx context.Context) string { // ctxValueLogger is the key to extract the logrus Logger. const ctxValueLogger = contextKeys("logger") -// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger. +// GetLogger retrieves the logrus logger from the supplied context. Always returns a logger, +// even if there wasn't one originally supplied. func GetLogger(ctx context.Context) *log.Entry { l := ctx.Value(ctxValueLogger) if l == nil { - return nil + // Always return a logger so callers don't need to constantly nil check. + return log.WithField("context", "missing") } return l.(*log.Entry) } diff --git a/vendor/src/github.com/matrix-org/util/json.go b/vendor/src/github.com/matrix-org/util/json.go index b0834eac..46c5396f 100644 --- a/vendor/src/github.com/matrix-org/util/json.go +++ b/vendor/src/github.com/matrix-org/util/json.go @@ -58,6 +58,21 @@ type JSONRequestHandler interface { OnIncomingRequest(req *http.Request) JSONResponse } +// jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler +type jsonRequestHandlerWrapper struct { + function func(req *http.Request) JSONResponse +} + +// OnIncomingRequest implements util.JSONRequestHandler +func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) JSONResponse { + return r.function(req) +} + +// NewJSONRequestHandler converts the given OnIncomingRequest function into a JSONRequestHandler +func NewJSONRequestHandler(f func(req *http.Request) JSONResponse) JSONRequestHandler { + return &jsonRequestHandlerWrapper{f} +} + // Protect panicking HTTP requests from taking down the entire process, and log them using // the correct logger, returning a 500 with a JSON response rather than abruptly closing the // connection. The http.Request MUST have a ctxValueLogger. diff --git a/vendor/src/github.com/matrix-org/util/json_test.go b/vendor/src/github.com/matrix-org/util/json_test.go index 687db277..3ce03a88 100644 --- a/vendor/src/github.com/matrix-org/util/json_test.go +++ b/vendor/src/github.com/matrix-org/util/json_test.go @@ -164,8 +164,8 @@ func TestGetLogger(t *testing.T) { noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) ctxLogger = GetLogger(noLoggerInReq.Context()) - if ctxLogger != nil { - t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger) + if ctxLogger == nil { + t.Errorf("TestGetLogger wanted logger, got nil") } }