diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go b/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go index 2c50103b..b433d707 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/consumer.go @@ -3,9 +3,11 @@ package input import ( "encoding/json" + "fmt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" sarama "gopkg.in/Shopify/sarama.v1" + "sync/atomic" ) // A ConsumerDatabase has the storage APIs needed by the consumer. @@ -45,6 +47,13 @@ type Consumer struct { // The ErrorLogger for this consumer. // If left as nil then the consumer will panic when it encounters an error ErrorLogger ErrorLogger + // If non-nil then the consumer will stop processing messages after this + // many messages and will shutdown. Malformed messages are included in the count. + StopProcessingAfter *int64 + // If not-nil then the consumer will call this to shutdown the server. + ShutdownCallback func(reason string) + // How many messages the consumer has processed. + processed int64 } // WriteOutputRoomEvent implements OutputRoomEventWriter @@ -126,6 +135,19 @@ func (c *Consumer) consumePartition(pc sarama.PartitionConsumer) { if err := c.DB.SetPartitionOffset(c.InputRoomEventTopic, message.Partition, message.Offset); err != nil { c.logError(message, err) } + // Update the number of processed messages using atomic addition because it is accessed from multiple goroutines. + processed := atomic.AddInt64(&c.processed, 1) + // Check if we should stop processing. + // Note that since we have multiple goroutines it's quite likely that we'll overshoot by a few messages. + // If we try to stop processing after M message and we have N goroutines then we will process somewhere + // between M and (N + M) messages because the N goroutines could all try to process what they think will be the + // last message. We could be more careful here but this is good enough for getting rough benchmarks. + if c.StopProcessingAfter != nil && processed >= int64(*c.StopProcessingAfter) { + if c.ShutdownCallback != nil { + c.ShutdownCallback(fmt.Sprintf("Stopping processing after %d messages", c.processed)) + } + return + } } } diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go index 689fb48d..e865b5be 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver/roomserver.go @@ -7,7 +7,9 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" sarama "gopkg.in/Shopify/sarama.v1" "net/http" + _ "net/http/pprof" "os" + "strconv" "strings" ) @@ -17,6 +19,10 @@ var ( inputRoomEventTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT") outputRoomEventTopic = os.Getenv("TOPIC_OUTPUT_ROOM_EVENT") bindAddr = os.Getenv("BIND_ADDRESS") + // Shuts the roomserver down after processing a given number of messages. + // This is useful for running benchmarks for seeing how quickly the server + // can process a given number of messages. + stopProcessingAfter = os.Getenv("STOP_AFTER") ) func main() { @@ -43,6 +49,18 @@ func main() { OutputRoomEventTopic: outputRoomEventTopic, } + if stopProcessingAfter != "" { + count, err := strconv.ParseInt(stopProcessingAfter, 10, 64) + if err != nil { + panic(err) + } + consumer.StopProcessingAfter = &count + consumer.ShutdownCallback = func(message string) { + fmt.Println("Stopping roomserver", message) + os.Exit(0) + } + } + if err = consumer.Start(); err != nil { panic(err) } @@ -56,5 +74,7 @@ func main() { fmt.Println("Started roomserver") // TODO: Implement clean shutdown. - http.ListenAndServe(bindAddr, nil) + if err := http.ListenAndServe(bindAddr, nil); err != nil { + panic(err) + } }