Add a query API to the roomserver for getting the latest events in a room. (#23)
* Start implementing a query API for go using net/rpc * Use a conventional JSON POST API rather than go net/rpc net/rpc doesn't automatically handle reconnecting and we have better logging and metrics infrastructure for monitoring HTTP apis. * Implement the query API and add it to the integration tests * Increase the timeout, travis seems to be a bit slow * Clarify that state events are the things that are not returned if they are not requested * Add utility function for converting arrays of numeric event IDs to pq Int64Arrays * Warn people against requesting empty state keys by accidentmain
parent
37e0b6c4c6
commit
9a8a8aedcb
|
@ -0,0 +1,102 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StateKeyTuple is a pair of an event type and state_key.
|
||||||
|
// This is used when requesting parts of the state of a room.
|
||||||
|
type StateKeyTuple struct {
|
||||||
|
// The "type" key
|
||||||
|
EventType string
|
||||||
|
// The "state_key" of a matrix event.
|
||||||
|
// The empty string is a legitimate value for the "state_key" in matrix
|
||||||
|
// so take care to initialise this field lest you accidentally request a
|
||||||
|
// "state_key" with the go default of the empty string.
|
||||||
|
EventStateKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState
|
||||||
|
type QueryLatestEventsAndStateRequest struct {
|
||||||
|
// The roomID to query the latest events for.
|
||||||
|
RoomID string
|
||||||
|
// The state key tuples to fetch from the room current state.
|
||||||
|
// If this list is empty or nil then no state events are returned.
|
||||||
|
StateToFetch []StateKeyTuple
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndStateResponse is a response to QueryLatestEventsAndState
|
||||||
|
type QueryLatestEventsAndStateResponse struct {
|
||||||
|
// Copy of the request for debugging.
|
||||||
|
QueryLatestEventsAndStateRequest
|
||||||
|
// Does the room exist?
|
||||||
|
// If the room doesn't exist this will be false and LatestEvents will be empty.
|
||||||
|
RoomExists bool
|
||||||
|
// The latest events in the room.
|
||||||
|
LatestEvents []gomatrixserverlib.EventReference
|
||||||
|
// The state events requested.
|
||||||
|
StateEvents []gomatrixserverlib.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryAPI is used to query information from the room server.
|
||||||
|
type RoomserverQueryAPI interface {
|
||||||
|
// Query the latest events and state for a room from the room server.
|
||||||
|
QueryLatestEventsAndState(
|
||||||
|
request *QueryLatestEventsAndStateRequest,
|
||||||
|
response *QueryLatestEventsAndStateResponse,
|
||||||
|
) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
|
||||||
|
const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/QueryLatestEventsAndState"
|
||||||
|
|
||||||
|
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
|
||||||
|
// If httpClient is nil then it uses the http.DefaultClient
|
||||||
|
func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &httpRoomserverQueryAPI{roomserverURL, *httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpRoomserverQueryAPI struct {
|
||||||
|
roomserverURL string
|
||||||
|
httpClient http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndState implements RoomserverQueryAPI
|
||||||
|
func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
|
request *QueryLatestEventsAndStateRequest,
|
||||||
|
response *QueryLatestEventsAndStateResponse,
|
||||||
|
) error {
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
|
||||||
|
return postJSON(h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error {
|
||||||
|
jsonBytes, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res, err := httpClient.Post(apiURL, "application/json", bytes.NewReader(jsonBytes))
|
||||||
|
if res != nil {
|
||||||
|
defer res.Body.Close()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
var errorBody struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err = json.NewDecoder(res.Body).Decode(&errorBody); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("api: %d: %s", res.StatusCode, errorBody.Message)
|
||||||
|
}
|
||||||
|
return json.NewDecoder(res.Body).Decode(response)
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
|
||||||
|
type RoomserverQueryAPIDatabase interface {
|
||||||
|
// Lookup the numeric ID for the room.
|
||||||
|
// Returns 0 if the room doesn't exists.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
RoomNID(roomID string) (types.RoomNID, error)
|
||||||
|
// Lookup event references for the latest events in the room.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomserverQueryAPI is an implementation of RoomserverQueryAPI
|
||||||
|
type RoomserverQueryAPI struct {
|
||||||
|
DB RoomserverQueryAPIDatabase
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndState implements api.RoomserverQueryAPI
|
||||||
|
func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) (err error) {
|
||||||
|
response.QueryLatestEventsAndStateRequest = *request
|
||||||
|
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if roomNID == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
response.RoomExists = true
|
||||||
|
response.LatestEvents, err = r.DB.LatestEventIDs(roomNID)
|
||||||
|
// TODO: look up the current state.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
|
||||||
|
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
|
||||||
|
servMux.Handle(
|
||||||
|
api.RoomserverQueryLatestEventsAndStatePath,
|
||||||
|
makeAPI("query_latest_events_and_state", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request api.QueryLatestEventsAndStateRequest
|
||||||
|
var response api.QueryLatestEventsAndStateResponse
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
if err := r.QueryLatestEventsAndState(&request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: 200, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeAPI(metric string, apiFunc func(req *http.Request) util.JSONResponse) http.Handler {
|
||||||
|
return prometheus.InstrumentHandler(metric, util.MakeJSONAPI(util.NewJSONRequestHandler(apiFunc)))
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
@ -17,8 +18,12 @@ var (
|
||||||
zookeeperURI = defaulting(os.Getenv("ZOOKEEPER_URI"), "localhost:2181")
|
zookeeperURI = defaulting(os.Getenv("ZOOKEEPER_URI"), "localhost:2181")
|
||||||
// The URI the kafka server is listening on.
|
// The URI the kafka server is listening on.
|
||||||
kafkaURI = defaulting(os.Getenv("KAFKA_URIS"), "localhost:9092")
|
kafkaURI = defaulting(os.Getenv("KAFKA_URIS"), "localhost:9092")
|
||||||
|
// The address the roomserver should listen on.
|
||||||
|
roomserverAddr = defaulting(os.Getenv("ROOMSERVER_URI"), "localhost:9876")
|
||||||
// How long to wait for the roomserver to write the expected output messages.
|
// How long to wait for the roomserver to write the expected output messages.
|
||||||
timeoutString = defaulting(os.Getenv("TIMEOUT"), "10s")
|
// This needs to be high enough to account for the time it takes to create
|
||||||
|
// the postgres database tables which can take a while on travis.
|
||||||
|
timeoutString = defaulting(os.Getenv("TIMEOUT"), "60s")
|
||||||
// The name of maintenance database to connect to in order to create the test database.
|
// The name of maintenance database to connect to in order to create the test database.
|
||||||
postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres")
|
postgresDatabase = defaulting(os.Getenv("POSTGRES_DATABASE"), "postgres")
|
||||||
// The name of the test database to create.
|
// The name of the test database to create.
|
||||||
|
@ -91,7 +96,7 @@ func writeToTopic(topic string, data []string) error {
|
||||||
// messages is reached or after a timeout. It kills the command before it returns.
|
// messages is reached or after a timeout. It kills the command before it returns.
|
||||||
// It returns a list of the messages read from the command on success or an error
|
// It returns a list of the messages read from the command on success or an error
|
||||||
// on failure.
|
// on failure.
|
||||||
func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int) ([]string, error) {
|
func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int, checkQueryAPI func()) ([]string, error) {
|
||||||
type result struct {
|
type result struct {
|
||||||
// data holds all of stdout on success.
|
// data holds all of stdout on success.
|
||||||
data []byte
|
data []byte
|
||||||
|
@ -111,7 +116,17 @@ func runAndReadFromTopic(runCmd *exec.Cmd, topic string, count int) ([]string, e
|
||||||
// Run the command, read the messages and wait for a timeout in parallel.
|
// Run the command, read the messages and wait for a timeout in parallel.
|
||||||
go func() {
|
go func() {
|
||||||
// Read all of stdout.
|
// Read all of stdout.
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
if errv, ok := err.(error); ok {
|
||||||
|
done <- result{nil, errv}
|
||||||
|
} else {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
data, err := readCmd.Output()
|
data, err := readCmd.Output()
|
||||||
|
checkQueryAPI()
|
||||||
done <- result{data, err}
|
done <- result{data, err}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -157,7 +172,16 @@ func deleteTopic(topic string) error {
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testRoomServer(input []string, wantOutput []string) {
|
// testRoomserver is used to run integration tests against a single roomserver.
|
||||||
|
// It creates new kafka topics for the input and output of the roomserver.
|
||||||
|
// It writes the input messages to the input kafka topic, formatting each message
|
||||||
|
// as canonical JSON so that it fits on a single line.
|
||||||
|
// It then runs the roomserver and waits for a number of messages to be written
|
||||||
|
// to the output topic.
|
||||||
|
// Once those messages have been written it runs the checkQueries function passing
|
||||||
|
// a api.RoomserverQueryAPI client. The caller can use this function to check the
|
||||||
|
// behaviour of the query API.
|
||||||
|
func testRoomserver(input []string, wantOutput []string, checkQueries func(api.RoomserverQueryAPI)) {
|
||||||
const (
|
const (
|
||||||
inputTopic = "roomserverInput"
|
inputTopic = "roomserverInput"
|
||||||
outputTopic = "roomserverOutput"
|
outputTopic = "roomserverOutput"
|
||||||
|
@ -191,10 +215,14 @@ func testRoomServer(input []string, wantOutput []string) {
|
||||||
fmt.Sprintf("KAFKA_URIS=%s", kafkaURI),
|
fmt.Sprintf("KAFKA_URIS=%s", kafkaURI),
|
||||||
fmt.Sprintf("TOPIC_INPUT_ROOM_EVENT=%s", inputTopic),
|
fmt.Sprintf("TOPIC_INPUT_ROOM_EVENT=%s", inputTopic),
|
||||||
fmt.Sprintf("TOPIC_OUTPUT_ROOM_EVENT=%s", outputTopic),
|
fmt.Sprintf("TOPIC_OUTPUT_ROOM_EVENT=%s", outputTopic),
|
||||||
|
fmt.Sprintf("BIND_ADDRESS=%s", roomserverAddr),
|
||||||
)
|
)
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
gotOutput, err := runAndReadFromTopic(cmd, outputTopic, 1)
|
gotOutput, err := runAndReadFromTopic(cmd, outputTopic, len(wantOutput), func() {
|
||||||
|
queryAPI := api.NewRoomserverQueryAPIHTTP("http://"+roomserverAddr, nil)
|
||||||
|
checkQueries(queryAPI)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -334,7 +362,21 @@ func main() {
|
||||||
}`,
|
}`,
|
||||||
}
|
}
|
||||||
|
|
||||||
testRoomServer(input, want)
|
testRoomserver(input, want, func(q api.RoomserverQueryAPI) {
|
||||||
|
var response api.QueryLatestEventsAndStateResponse
|
||||||
|
if err := q.QueryLatestEventsAndState(
|
||||||
|
&api.QueryLatestEventsAndStateRequest{RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org"},
|
||||||
|
&response,
|
||||||
|
); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if !response.RoomExists {
|
||||||
|
panic(fmt.Errorf(`Wanted room "!HCXfdvrfksxuYnIFiJ:matrix.org" to exist`))
|
||||||
|
}
|
||||||
|
if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" {
|
||||||
|
panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
fmt.Println("==PASSED==", os.Args[0])
|
fmt.Println("==PASSED==", os.Args[0])
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,10 @@ package main
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/matrix-org/dendrite/roomserver/input"
|
"github.com/matrix-org/dendrite/roomserver/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/query"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
sarama "gopkg.in/Shopify/sarama.v1"
|
sarama "gopkg.in/Shopify/sarama.v1"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -14,6 +16,7 @@ var (
|
||||||
kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",")
|
kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",")
|
||||||
inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT")
|
inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT")
|
||||||
outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT")
|
outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT")
|
||||||
|
bindAddr = os.Getenv("BIND_ADDRESS")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -44,9 +47,14 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
queryAPI := query.RoomserverQueryAPI{
|
||||||
|
DB: db,
|
||||||
|
}
|
||||||
|
|
||||||
|
queryAPI.SetupHTTP(http.DefaultServeMux)
|
||||||
|
|
||||||
fmt.Println("Started roomserver")
|
fmt.Println("Started roomserver")
|
||||||
|
|
||||||
// Wait forever.
|
|
||||||
// TODO: Implement clean shutdown.
|
// TODO: Implement clean shutdown.
|
||||||
select {}
|
http.ListenAndServe(bindAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -65,11 +64,7 @@ type eventJSONPair struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
||||||
nids := make([]int64, len(eventNIDs))
|
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||||
for i := range eventNIDs {
|
|
||||||
nids[i] = int64(eventNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventsSchema = `
|
const eventsSchema = `
|
||||||
|
@ -83,6 +84,9 @@ const bulkSelectStateAtEventAndReferenceSQL = "" +
|
||||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
|
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
|
||||||
" FROM events WHERE event_nid = ANY($1)"
|
" FROM events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
|
const bulkSelectEventReferenceSQL = "" +
|
||||||
|
"SELECT event_id, reference_sha256 FROM events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
|
@ -93,6 +97,7 @@ type eventStatements struct {
|
||||||
updateEventSentToOutputStmt *sql.Stmt
|
updateEventSentToOutputStmt *sql.Stmt
|
||||||
selectEventIDStmt *sql.Stmt
|
selectEventIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||||
|
bulkSelectEventReferenceStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
@ -127,6 +132,9 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.bulkSelectStateAtEventAndReferenceStmt, err = db.Prepare(bulkSelectStateAtEventAndReferenceSQL); err != nil {
|
if s.bulkSelectStateAtEventAndReferenceStmt, err = db.Prepare(bulkSelectStateAtEventAndReferenceSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.bulkSelectEventReferenceStmt, err = db.Prepare(bulkSelectEventReferenceSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,15 +144,11 @@ func (s *eventStatements) insertEvent(
|
||||||
referenceSHA256 []byte,
|
referenceSHA256 []byte,
|
||||||
authEventNIDs []types.EventNID,
|
authEventNIDs []types.EventNID,
|
||||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
nids := make([]int64, len(authEventNIDs))
|
|
||||||
for i := range authEventNIDs {
|
|
||||||
nids[i] = int64(authEventNIDs[i])
|
|
||||||
}
|
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
var stateNID int64
|
var stateNID int64
|
||||||
err := s.insertEventStmt.QueryRow(
|
err := s.insertEventStmt.QueryRow(
|
||||||
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
||||||
pq.Int64Array(nids),
|
eventNIDsAsArray(authEventNIDs),
|
||||||
).Scan(&eventNID, &stateNID)
|
).Scan(&eventNID, &stateNID)
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
@ -238,11 +242,7 @@ func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
||||||
nids := make([]int64, len(eventNIDs))
|
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
||||||
for i := range eventNIDs {
|
|
||||||
nids[i] = int64(eventNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -276,3 +276,31 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) {
|
||||||
|
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
result := &results[i]
|
||||||
|
if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i != len(eventNIDs) {
|
||||||
|
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
|
||||||
|
nids := make([]int64, len(eventNIDs))
|
||||||
|
for i := range eventNIDs {
|
||||||
|
nids[i] = int64(eventNIDs[i])
|
||||||
|
}
|
||||||
|
return nids
|
||||||
|
}
|
||||||
|
|
|
@ -32,16 +32,20 @@ const selectRoomNIDSQL = "" +
|
||||||
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
"SELECT room_nid FROM rooms WHERE room_id = $1"
|
||||||
|
|
||||||
const selectLatestEventNIDsSQL = "" +
|
const selectLatestEventNIDsSQL = "" +
|
||||||
|
"SELECT latest_event_nids FROM rooms WHERE room_nid = $1"
|
||||||
|
|
||||||
|
const selectLatestEventNIDsForUpdateSQL = "" +
|
||||||
"SELECT latest_event_nids, last_event_sent_nid FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
"SELECT latest_event_nids, last_event_sent_nid FROM rooms WHERE room_nid = $1 FOR UPDATE"
|
||||||
|
|
||||||
const updateLatestEventNIDsSQL = "" +
|
const updateLatestEventNIDsSQL = "" +
|
||||||
"UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3 WHERE room_nid = $1"
|
"UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3 WHERE room_nid = $1"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||||
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
@ -58,6 +62,9 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectLatestEventNIDsStmt, err = db.Prepare(selectLatestEventNIDsSQL); err != nil {
|
if s.selectLatestEventNIDsStmt, err = db.Prepare(selectLatestEventNIDsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectLatestEventNIDsForUpdateStmt, err = db.Prepare(selectLatestEventNIDsForUpdateSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.updateLatestEventNIDsStmt, err = db.Prepare(updateLatestEventNIDsSQL); err != nil {
|
if s.updateLatestEventNIDsStmt, err = db.Prepare(updateLatestEventNIDsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -76,10 +83,23 @@ func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) {
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, error) {
|
||||||
|
var nids pq.Int64Array
|
||||||
|
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventNIDs := make([]types.EventNID, len(nids))
|
||||||
|
for i := range nids {
|
||||||
|
eventNIDs[i] = types.EventNID(nids[i])
|
||||||
|
}
|
||||||
|
return eventNIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, error) {
|
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
var lastEventSentNID int64
|
var lastEventSentNID int64
|
||||||
err := txn.Stmt(s.selectLatestEventNIDsStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID)
|
err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
@ -91,10 +111,6 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) updateLatestEventNIDs(txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID) error {
|
func (s *roomStatements) updateLatestEventNIDs(txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID) error {
|
||||||
nids := make([]int64, len(eventNIDs))
|
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID))
|
||||||
for i := range eventNIDs {
|
|
||||||
nids[i] = int64(eventNIDs[i])
|
|
||||||
}
|
|
||||||
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, pq.Int64Array(nids), int64(lastEventSentNID))
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -280,3 +280,21 @@ func (u *roomRecentEventsUpdater) Commit() error {
|
||||||
func (u *roomRecentEventsUpdater) Rollback() error {
|
func (u *roomRecentEventsUpdater) Rollback() error {
|
||||||
return u.txn.Rollback()
|
return u.txn.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoomNID implements query.RoomserverQueryAPIDB
|
||||||
|
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||||
|
roomNID, err := d.statements.selectRoomNID(roomID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return roomNID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LatestEventIDs implements query.RoomserverQueryAPIDB
|
||||||
|
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) {
|
||||||
|
eventNIDs, err := d.statements.selectLatestEventNIDs(roomNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.statements.bulkSelectEventReference(eventNIDs)
|
||||||
|
}
|
||||||
|
|
|
@ -98,7 +98,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/util",
|
"importpath": "github.com/matrix-org/util",
|
||||||
"repository": "https://github.com/matrix-org/util",
|
"repository": "https://github.com/matrix-org/util",
|
||||||
"revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828",
|
"revision": "28bd7491c8aafbf346ca23821664f0f9911ef52b",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -25,11 +25,13 @@ func GetRequestID(ctx context.Context) string {
|
||||||
// ctxValueLogger is the key to extract the logrus Logger.
|
// ctxValueLogger is the key to extract the logrus Logger.
|
||||||
const ctxValueLogger = contextKeys("logger")
|
const ctxValueLogger = contextKeys("logger")
|
||||||
|
|
||||||
// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger.
|
// GetLogger retrieves the logrus logger from the supplied context. Always returns a logger,
|
||||||
|
// even if there wasn't one originally supplied.
|
||||||
func GetLogger(ctx context.Context) *log.Entry {
|
func GetLogger(ctx context.Context) *log.Entry {
|
||||||
l := ctx.Value(ctxValueLogger)
|
l := ctx.Value(ctxValueLogger)
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return nil
|
// Always return a logger so callers don't need to constantly nil check.
|
||||||
|
return log.WithField("context", "missing")
|
||||||
}
|
}
|
||||||
return l.(*log.Entry)
|
return l.(*log.Entry)
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,21 @@ type JSONRequestHandler interface {
|
||||||
OnIncomingRequest(req *http.Request) JSONResponse
|
OnIncomingRequest(req *http.Request) JSONResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler
|
||||||
|
type jsonRequestHandlerWrapper struct {
|
||||||
|
function func(req *http.Request) JSONResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnIncomingRequest implements util.JSONRequestHandler
|
||||||
|
func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) JSONResponse {
|
||||||
|
return r.function(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJSONRequestHandler converts the given OnIncomingRequest function into a JSONRequestHandler
|
||||||
|
func NewJSONRequestHandler(f func(req *http.Request) JSONResponse) JSONRequestHandler {
|
||||||
|
return &jsonRequestHandlerWrapper{f}
|
||||||
|
}
|
||||||
|
|
||||||
// Protect panicking HTTP requests from taking down the entire process, and log them using
|
// Protect panicking HTTP requests from taking down the entire process, and log them using
|
||||||
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the
|
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the
|
||||||
// connection. The http.Request MUST have a ctxValueLogger.
|
// connection. The http.Request MUST have a ctxValueLogger.
|
||||||
|
|
|
@ -164,8 +164,8 @@ func TestGetLogger(t *testing.T) {
|
||||||
|
|
||||||
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
|
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
ctxLogger = GetLogger(noLoggerInReq.Context())
|
ctxLogger = GetLogger(noLoggerInReq.Context())
|
||||||
if ctxLogger != nil {
|
if ctxLogger == nil {
|
||||||
t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger)
|
t.Errorf("TestGetLogger wanted logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue