Use a Postgres database rather than Memory for Naffka (#337)
* Update naffka dep * User Postgres database rather than Memory for Naffkamain
parent
bdc44c4bde
commit
8599a36fa6
|
@ -16,6 +16,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"flag"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -199,7 +200,21 @@ func (m *monolith) setupFederation() {
|
||||||
|
|
||||||
func (m *monolith) setupKafka() {
|
func (m *monolith) setupKafka() {
|
||||||
if m.cfg.Kafka.UseNaffka {
|
if m.cfg.Kafka.UseNaffka {
|
||||||
naff, err := naffka.New(&naffka.MemoryDatabase{})
|
db, err := sql.Open("postgres", string(m.cfg.Database.Naffka))
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
log.ErrorKey: err,
|
||||||
|
}).Panic("Failed to open naffka database")
|
||||||
|
}
|
||||||
|
|
||||||
|
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
log.ErrorKey: err,
|
||||||
|
}).Panic("Failed to setup naffka database")
|
||||||
|
}
|
||||||
|
|
||||||
|
naff, err := naffka.New(naffkaDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
log.ErrorKey: err,
|
log.ErrorKey: err,
|
||||||
|
|
|
@ -148,6 +148,8 @@ type Dendrite struct {
|
||||||
// The PublicRoomsAPI database stores information used to compute the public
|
// The PublicRoomsAPI database stores information used to compute the public
|
||||||
// room directory. It is only accessed by the PublicRoomsAPI server.
|
// room directory. It is only accessed by the PublicRoomsAPI server.
|
||||||
PublicRoomsAPI DataSource `yaml:"public_rooms_api"`
|
PublicRoomsAPI DataSource `yaml:"public_rooms_api"`
|
||||||
|
// The Naffka database is used internally by the naffka library, if used.
|
||||||
|
Naffka DataSource `yaml:"naffka,omitempty"`
|
||||||
} `yaml:"database"`
|
} `yaml:"database"`
|
||||||
|
|
||||||
// TURN Server Config
|
// TURN Server Config
|
||||||
|
@ -386,6 +388,8 @@ func (config *Dendrite) check(monolithic bool) error {
|
||||||
if !monolithic {
|
if !monolithic {
|
||||||
problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server"))
|
problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkNotEmpty("database.naffka", string(config.Database.Naffka))
|
||||||
} else {
|
} else {
|
||||||
// If we aren't using naffka then we need to have at least one kafka
|
// If we aren't using naffka then we need to have at least one kafka
|
||||||
// server to talk to.
|
// server to talk to.
|
||||||
|
|
|
@ -141,7 +141,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/naffka",
|
"importpath": "github.com/matrix-org/naffka",
|
||||||
"repository": "https://github.com/matrix-org/naffka",
|
"repository": "https://github.com/matrix-org/naffka",
|
||||||
"revision": "d28656e34f96a8eeaab53e3b7678c9ce14af5786",
|
"revision": "662bfd0841d0194bfe0a700d54226bb96eac574d",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -8,7 +8,8 @@ import (
|
||||||
// A MemoryDatabase stores the message history as arrays in memory.
|
// A MemoryDatabase stores the message history as arrays in memory.
|
||||||
// It can be used to run unit tests.
|
// It can be used to run unit tests.
|
||||||
// If the process is stopped then any messages that haven't been
|
// If the process is stopped then any messages that haven't been
|
||||||
// processed by a consumer are lost forever.
|
// processed by a consumer are lost forever and all offsets become
|
||||||
|
// invalid.
|
||||||
type MemoryDatabase struct {
|
type MemoryDatabase struct {
|
||||||
topicsMutex sync.Mutex
|
topicsMutex sync.Mutex
|
||||||
topics map[string]*memoryDatabaseTopic
|
topics map[string]*memoryDatabaseTopic
|
||||||
|
@ -58,10 +59,7 @@ func (m *MemoryDatabase) getTopic(topicName string) *memoryDatabaseTopic {
|
||||||
|
|
||||||
// StoreMessages implements Database
|
// StoreMessages implements Database
|
||||||
func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error {
|
func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error {
|
||||||
if err := m.getTopic(topic).addMessages(messages); err != nil {
|
return m.getTopic(topic).addMessages(messages)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchMessages implements Database
|
// FetchMessages implements Database
|
||||||
|
@ -73,10 +71,10 @@ func (m *MemoryDatabase) FetchMessages(topic string, startOffset, endOffset int6
|
||||||
if startOffset >= endOffset {
|
if startOffset >= endOffset {
|
||||||
return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset)
|
return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset)
|
||||||
}
|
}
|
||||||
if startOffset < -1 {
|
if startOffset < 0 {
|
||||||
return nil, fmt.Errorf("start offset %d less than -1", startOffset)
|
return nil, fmt.Errorf("start offset %d less than 0", startOffset)
|
||||||
}
|
}
|
||||||
return messages[startOffset+1 : endOffset], nil
|
return messages[startOffset:endOffset], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxOffsets implements Database
|
// MaxOffsets implements Database
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
// single go process. It implements both the sarama.SyncProducer and the
|
// single go process. It implements both the sarama.SyncProducer and the
|
||||||
// sarama.Consumer interfaces. This means it can act as a drop in replacement
|
// sarama.Consumer interfaces. This means it can act as a drop in replacement
|
||||||
// for kafka for testing or single instance deployment.
|
// for kafka for testing or single instance deployment.
|
||||||
|
// Does not support multiple partitions.
|
||||||
type Naffka struct {
|
type Naffka struct {
|
||||||
db Database
|
db Database
|
||||||
topicsMutex sync.Mutex
|
topicsMutex sync.Mutex
|
||||||
|
@ -28,6 +29,7 @@ func New(db Database) (*Naffka, error) {
|
||||||
}
|
}
|
||||||
for topicName, offset := range maxOffsets {
|
for topicName, offset := range maxOffsets {
|
||||||
n.topics[topicName] = &topic{
|
n.topics[topicName] = &topic{
|
||||||
|
db: db,
|
||||||
topicName: topicName,
|
topicName: topicName,
|
||||||
nextOffset: offset + 1,
|
nextOffset: offset + 1,
|
||||||
}
|
}
|
||||||
|
@ -64,7 +66,7 @@ type Database interface {
|
||||||
// So for a given topic the message with offset n+1 is stored after the
|
// So for a given topic the message with offset n+1 is stored after the
|
||||||
// the message with offset n.
|
// the message with offset n.
|
||||||
StoreMessages(topic string, messages []Message) error
|
StoreMessages(topic string, messages []Message) error
|
||||||
// FetchMessages fetches all messages with an offset greater than but not
|
// FetchMessages fetches all messages with an offset greater than and
|
||||||
// including startOffset and less than but not including endOffset.
|
// including startOffset and less than but not including endOffset.
|
||||||
// The range of offsets requested must not overlap with those stored by a
|
// The range of offsets requested must not overlap with those stored by a
|
||||||
// concurrent StoreMessages. The message offsets within the requested range
|
// concurrent StoreMessages. The message offsets within the requested range
|
||||||
|
@ -138,6 +140,7 @@ func (n *Naffka) Partitions(topic string) ([]int32, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConsumePartition implements sarama.Consumer
|
// ConsumePartition implements sarama.Consumer
|
||||||
|
// Note: offset is *inclusive*, i.e. it will include the message with that offset.
|
||||||
func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
|
func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
|
||||||
if partition != 0 {
|
if partition != 0 {
|
||||||
return nil, fmt.Errorf("Unknown partition ID %d", partition)
|
return nil, fmt.Errorf("Unknown partition ID %d", partition)
|
||||||
|
@ -166,13 +169,16 @@ func (n *Naffka) Close() error {
|
||||||
|
|
||||||
const channelSize = 1024
|
const channelSize = 1024
|
||||||
|
|
||||||
|
// partitionConsumer ensures that all messages written to a particular
|
||||||
|
// topic, from an offset, get sent in order to a channel.
|
||||||
|
// Implements sarama.PartitionConsumer
|
||||||
type partitionConsumer struct {
|
type partitionConsumer struct {
|
||||||
topic *topic
|
topic *topic
|
||||||
messages chan *sarama.ConsumerMessage
|
messages chan *sarama.ConsumerMessage
|
||||||
// Whether the consumer is ready for new messages or whether it
|
// Whether the consumer is in "catchup" mode or not.
|
||||||
// is catching up on historic messages.
|
// See "catchup" function for details.
|
||||||
// Reads and writes to this field are proctected by the topic mutex.
|
// Reads and writes to this field are proctected by the topic mutex.
|
||||||
ready bool
|
catchingUp bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// AsyncClose implements sarama.PartitionConsumer
|
// AsyncClose implements sarama.PartitionConsumer
|
||||||
|
@ -201,66 +207,101 @@ func (c *partitionConsumer) HighWaterMarkOffset() int64 {
|
||||||
return c.topic.highwaterMark()
|
return c.topic.highwaterMark()
|
||||||
}
|
}
|
||||||
|
|
||||||
// block writes the message to the consumer blocking until the consumer is ready
|
// catchup makes the consumer go into "catchup" mode, where messages are read
|
||||||
// to add the message to the channel. Once the message is successfully added to
|
// from the database instead of directly from producers.
|
||||||
// the channel it will catch up by pulling historic messsages from the database.
|
// Once the consumer is up to date, i.e. no new messages in the database, then
|
||||||
func (c *partitionConsumer) block(cmsg *sarama.ConsumerMessage) {
|
// the consumer will go back into normal mode where new messages are written
|
||||||
c.messages <- cmsg
|
// directly to the channel.
|
||||||
c.catchup(cmsg.Offset)
|
// Must be called with the c.topic.mutex lock
|
||||||
|
func (c *partitionConsumer) catchup(fromOffset int64) {
|
||||||
|
// If we're already in catchup mode or up to date, noop
|
||||||
|
if c.catchingUp || fromOffset == c.topic.nextOffset {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.catchingUp = true
|
||||||
|
|
||||||
|
// Due to the checks above there can only be one of these goroutines
|
||||||
|
// running at a time
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
// Check if we're up to date yet. If we are we exit catchup mode.
|
||||||
|
c.topic.mutex.Lock()
|
||||||
|
nextOffset := c.topic.nextOffset
|
||||||
|
if fromOffset == nextOffset {
|
||||||
|
c.catchingUp = false
|
||||||
|
c.topic.mutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.topic.mutex.Unlock()
|
||||||
|
|
||||||
|
// Limit the number of messages we request from the database to be the
|
||||||
|
// capacity of the channel.
|
||||||
|
if nextOffset > fromOffset+int64(cap(c.messages)) {
|
||||||
|
nextOffset = fromOffset + int64(cap(c.messages))
|
||||||
|
}
|
||||||
|
// Fetch the messages from the database.
|
||||||
|
msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: Add option to write consumer errors to an errors channel
|
||||||
|
// as an alternative to logging the errors.
|
||||||
|
log.Print("Error reading messages: ", err)
|
||||||
|
// Wait before retrying.
|
||||||
|
// TODO: Maybe use an exponentional backoff scheme here.
|
||||||
|
// TODO: This timeout should take account of all the other goroutines
|
||||||
|
// that might be doing the same thing. (If there are a 10000 consumers
|
||||||
|
// then we don't want to end up retrying every millisecond)
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
// This should only happen if the database is corrupted and has lost the
|
||||||
|
// messages between the requested offsets.
|
||||||
|
log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the messages into the consumer channel.
|
||||||
|
// Blocking each write until the channel has enough space for the message.
|
||||||
|
for i := range msgs {
|
||||||
|
c.messages <- msgs[i].consumerMessage(c.topic.topicName)
|
||||||
|
}
|
||||||
|
// Update our the offset for the next loop iteration.
|
||||||
|
fromOffset = msgs[len(msgs)-1].Offset + 1
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// catchup reads historic messages from the database until the consumer has caught
|
// notifyNewMessage tells the consumer about a new message
|
||||||
// up on all the historic messages.
|
// Must be called with the c.topic.mutex lock
|
||||||
func (c *partitionConsumer) catchup(fromOffset int64) {
|
func (c *partitionConsumer) notifyNewMessage(cmsg *sarama.ConsumerMessage) {
|
||||||
for {
|
// If we're in "catchup" mode then the catchup routine will send the
|
||||||
// First check if we have caught up.
|
// message later, since cmsg has already been written to the database
|
||||||
caughtUp, nextOffset := c.topic.hasCaughtUp(c, fromOffset)
|
if c.catchingUp {
|
||||||
if caughtUp {
|
return
|
||||||
return
|
}
|
||||||
}
|
|
||||||
// Limit the number of messages we request from the database to be the
|
|
||||||
// capacity of the channel.
|
|
||||||
if nextOffset > fromOffset+int64(cap(c.messages)) {
|
|
||||||
nextOffset = fromOffset + int64(cap(c.messages))
|
|
||||||
}
|
|
||||||
// Fetch the messages from the database.
|
|
||||||
msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: Add option to write consumer errors to an errors channel
|
|
||||||
// as an alternative to logging the errors.
|
|
||||||
log.Print("Error reading messages: ", err)
|
|
||||||
// Wait before retrying.
|
|
||||||
// TODO: Maybe use an exponentional backoff scheme here.
|
|
||||||
// TODO: This timeout should take account of all the other goroutines
|
|
||||||
// that might be doing the same thing. (If there are a 10000 consumers
|
|
||||||
// then we don't want to end up retrying every millisecond)
|
|
||||||
time.Sleep(10 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(msgs) == 0 {
|
|
||||||
// This should only happen if the database is corrupted and has lost the
|
|
||||||
// messages between the requested offsets.
|
|
||||||
log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass the messages into the consumer channel.
|
// Otherwise, lets try writing the message directly to the channel
|
||||||
// Blocking each write until the channel has enough space for the message.
|
select {
|
||||||
for i := range msgs {
|
case c.messages <- cmsg:
|
||||||
c.messages <- msgs[i].consumerMessage(c.topic.topicName)
|
default:
|
||||||
}
|
// The messages channel has filled up, so lets go into catchup
|
||||||
// Update our the offset for the next loop iteration.
|
// mode. Once the channel starts being read from again messages
|
||||||
fromOffset = msgs[len(msgs)-1].Offset
|
// will be read from the database
|
||||||
|
c.catchup(cmsg.Offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type topic struct {
|
type topic struct {
|
||||||
db Database
|
db Database
|
||||||
topicName string
|
topicName string
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
consumers []*partitionConsumer
|
consumers []*partitionConsumer
|
||||||
|
// nextOffset is the offset that will be assigned to the next message in
|
||||||
|
// this topic, i.e. one greater than the last message offset.
|
||||||
nextOffset int64
|
nextOffset int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// send writes messages to a topic.
|
||||||
func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
|
func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
|
||||||
var err error
|
var err error
|
||||||
// Encode the message keys and values.
|
// Encode the message keys and values.
|
||||||
|
@ -298,21 +339,10 @@ func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
|
||||||
t.nextOffset = offset
|
t.nextOffset = offset
|
||||||
|
|
||||||
// Now notify the consumers about the messages.
|
// Now notify the consumers about the messages.
|
||||||
for i := range msgs {
|
for _, msg := range msgs {
|
||||||
cmsg := msgs[i].consumerMessage(t.topicName)
|
cmsg := msg.consumerMessage(t.topicName)
|
||||||
for _, c := range t.consumers {
|
for _, c := range t.consumers {
|
||||||
if c.ready {
|
c.notifyNewMessage(cmsg)
|
||||||
select {
|
|
||||||
case c.messages <- cmsg:
|
|
||||||
default:
|
|
||||||
// The consumer wasn't ready to receive a message because
|
|
||||||
// the channel buffer was full.
|
|
||||||
// Fork a goroutine to send the message so that we don't
|
|
||||||
// block sending messages to the other consumers.
|
|
||||||
c.ready = false
|
|
||||||
go c.block(cmsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -330,27 +360,17 @@ func (t *topic) consume(offset int64) *partitionConsumer {
|
||||||
offset = t.nextOffset
|
offset = t.nextOffset
|
||||||
}
|
}
|
||||||
if offset == sarama.OffsetOldest {
|
if offset == sarama.OffsetOldest {
|
||||||
offset = -1
|
offset = 0
|
||||||
}
|
}
|
||||||
c.messages = make(chan *sarama.ConsumerMessage, channelSize)
|
c.messages = make(chan *sarama.ConsumerMessage, channelSize)
|
||||||
t.consumers = append(t.consumers, c)
|
t.consumers = append(t.consumers, c)
|
||||||
// Start catching up on historic messages in the background.
|
|
||||||
go c.catchup(offset)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *topic) hasCaughtUp(c *partitionConsumer, offset int64) (bool, int64) {
|
// If we're not streaming from the latest offset we need to go into
|
||||||
t.mutex.Lock()
|
// "catchup" mode
|
||||||
defer t.mutex.Unlock()
|
if offset != t.nextOffset {
|
||||||
// Check if we have caught up while holding a lock on the topic so there
|
c.catchup(offset)
|
||||||
// isn't a way for our check to race with a new message being sent on the topic.
|
|
||||||
if offset+1 == t.nextOffset {
|
|
||||||
// We've caught up, the consumer can now receive messages as they are
|
|
||||||
// sent rather than fetching them from the database.
|
|
||||||
c.ready = true
|
|
||||||
return true, t.nextOffset
|
|
||||||
}
|
}
|
||||||
return false, t.nextOffset
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) highwaterMark() int64 {
|
func (t *topic) highwaterMark() int64 {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package naffka
|
package naffka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -84,3 +85,142 @@ func TestDelayedReceive(t *testing.T) {
|
||||||
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
|
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCatchup(t *testing.T) {
|
||||||
|
naffka, err := New(&MemoryDatabase{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
producer := sarama.SyncProducer(naffka)
|
||||||
|
consumer := sarama.Consumer(naffka)
|
||||||
|
|
||||||
|
const topic = "testTopic"
|
||||||
|
const value = "Hello, World"
|
||||||
|
|
||||||
|
message := sarama.ProducerMessage{
|
||||||
|
Value: sarama.StringEncoder(value),
|
||||||
|
Topic: topic,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err = producer.SendMessage(&message); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *sarama.ConsumerMessage
|
||||||
|
select {
|
||||||
|
case result = <-c.Messages():
|
||||||
|
case _ = <-time.NewTimer(10 * time.Second).C:
|
||||||
|
t.Fatal("expected to receive a message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Value) != value {
|
||||||
|
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
|
||||||
|
}
|
||||||
|
|
||||||
|
currOffset := result.Offset
|
||||||
|
|
||||||
|
const value2 = "Hello, World2"
|
||||||
|
const value3 = "Hello, World3"
|
||||||
|
|
||||||
|
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
|
||||||
|
Value: sarama.StringEncoder(value2),
|
||||||
|
Topic: topic,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
|
||||||
|
Value: sarama.StringEncoder(value3),
|
||||||
|
Topic: topic,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Streaming from %q", currOffset+1)
|
||||||
|
|
||||||
|
c2, err := consumer.ConsumePartition(topic, 0, currOffset+1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result2 *sarama.ConsumerMessage
|
||||||
|
select {
|
||||||
|
case result2 = <-c2.Messages():
|
||||||
|
case _ = <-time.NewTimer(10 * time.Second).C:
|
||||||
|
t.Fatal("expected to receive a message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result2.Value) != value2 {
|
||||||
|
t.Fatalf("wrong value: wanted %q got %q", value2, string(result2.Value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannelSaturation(t *testing.T) {
|
||||||
|
// The channel returned by c.Messages() has a fixed capacity
|
||||||
|
|
||||||
|
naffka, err := New(&MemoryDatabase{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
producer := sarama.SyncProducer(naffka)
|
||||||
|
consumer := sarama.Consumer(naffka)
|
||||||
|
const topic = "testTopic"
|
||||||
|
const baseValue = "testValue: "
|
||||||
|
|
||||||
|
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
channelSize := cap(c.Messages())
|
||||||
|
|
||||||
|
// We want to send enough messages to fill up the channel, so lets double
|
||||||
|
// the size of the channel. And add three in case its a zero sized channel
|
||||||
|
numberMessagesToSend := 2*channelSize + 3
|
||||||
|
|
||||||
|
var sentMessages []string
|
||||||
|
|
||||||
|
for i := 0; i < numberMessagesToSend; i++ {
|
||||||
|
value := baseValue + strconv.Itoa(i)
|
||||||
|
|
||||||
|
message := sarama.ProducerMessage{
|
||||||
|
Topic: topic,
|
||||||
|
Value: sarama.StringEncoder(value),
|
||||||
|
}
|
||||||
|
|
||||||
|
sentMessages = append(sentMessages, value)
|
||||||
|
|
||||||
|
if _, _, err = producer.SendMessage(&message); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *sarama.ConsumerMessage
|
||||||
|
|
||||||
|
j := 0
|
||||||
|
for ; j < numberMessagesToSend; j++ {
|
||||||
|
select {
|
||||||
|
case result = <-c.Messages():
|
||||||
|
case _ = <-time.NewTimer(10 * time.Second).C:
|
||||||
|
t.Fatalf("failed to receive message %d out of %d", j+1, numberMessagesToSend)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedValue := sentMessages[j]
|
||||||
|
if string(result.Value) != expectedValue {
|
||||||
|
t.Fatalf("wrong value: wanted %q got %q", expectedValue, string(result.Value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result = <-c.Messages():
|
||||||
|
t.Fatalf("expected to only receive %d messages", numberMessagesToSend)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,296 @@
|
||||||
|
package naffka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const postgresqlSchema = `
|
||||||
|
-- The topic table assigns each topic a unique numeric ID.
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS naffka_topic_nid_seq;
|
||||||
|
CREATE TABLE IF NOT EXISTS naffka_topics (
|
||||||
|
topic_name TEXT PRIMARY KEY,
|
||||||
|
topic_nid BIGINT NOT NULL DEFAULT nextval('naffka_topic_nid_seq')
|
||||||
|
);
|
||||||
|
|
||||||
|
-- The messages table contains the actual messages.
|
||||||
|
CREATE TABLE IF NOT EXISTS naffka_messages (
|
||||||
|
topic_nid BIGINT NOT NULL,
|
||||||
|
message_offset BIGINT NOT NULL,
|
||||||
|
message_key BYTEA NOT NULL,
|
||||||
|
message_value BYTEA NOT NULL,
|
||||||
|
message_timestamp_ns BIGINT NOT NULL,
|
||||||
|
UNIQUE (topic_nid, message_offset)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertTopicSQL = "" +
|
||||||
|
"INSERT INTO naffka_topics (topic_name) VALUES ($1)" +
|
||||||
|
" ON CONFLICT DO NOTHING" +
|
||||||
|
" RETURNING (topic_nid)"
|
||||||
|
|
||||||
|
const selectTopicSQL = "" +
|
||||||
|
"SELECT topic_nid FROM naffka_topics WHERE topic_name = $1"
|
||||||
|
|
||||||
|
const selectTopicsSQL = "" +
|
||||||
|
"SELECT topic_name, topic_nid FROM naffka_topics"
|
||||||
|
|
||||||
|
const insertMessageSQL = "" +
|
||||||
|
"INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
const selectMessagesSQL = "" +
|
||||||
|
"SELECT message_offset, message_key, message_value, message_timestamp_ns" +
|
||||||
|
" FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" +
|
||||||
|
" ORDER BY message_offset ASC"
|
||||||
|
|
||||||
|
const selectMaxOffsetSQL = "" +
|
||||||
|
"SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" +
|
||||||
|
" ORDER BY message_offset DESC LIMIT 1"
|
||||||
|
|
||||||
|
type postgresqlDatabase struct {
|
||||||
|
db *sql.DB
|
||||||
|
topicsMutex sync.Mutex
|
||||||
|
topicNIDs map[string]int64
|
||||||
|
insertTopicStmt *sql.Stmt
|
||||||
|
selectTopicStmt *sql.Stmt
|
||||||
|
selectTopicsStmt *sql.Stmt
|
||||||
|
insertMessageStmt *sql.Stmt
|
||||||
|
selectMessagesStmt *sql.Stmt
|
||||||
|
selectMaxOffsetStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostgresqlDatabase creates a new naffka database using a postgresql database.
|
||||||
|
// Returns an error if there was a problem setting up the database.
|
||||||
|
func NewPostgresqlDatabase(db *sql.DB) (Database, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
p := &postgresqlDatabase{
|
||||||
|
db: db,
|
||||||
|
topicNIDs: map[string]int64{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = db.Exec(postgresqlSchema); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range []struct {
|
||||||
|
sql string
|
||||||
|
stmt **sql.Stmt
|
||||||
|
}{
|
||||||
|
{insertTopicSQL, &p.insertTopicStmt},
|
||||||
|
{selectTopicSQL, &p.selectTopicStmt},
|
||||||
|
{selectTopicsSQL, &p.selectTopicsStmt},
|
||||||
|
{insertMessageSQL, &p.insertMessageStmt},
|
||||||
|
{selectMessagesSQL, &p.selectMessagesStmt},
|
||||||
|
{selectMaxOffsetSQL, &p.selectMaxOffsetStmt},
|
||||||
|
} {
|
||||||
|
*s.stmt, err = db.Prepare(s.sql)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreMessages implements Database.
|
||||||
|
func (p *postgresqlDatabase) StoreMessages(topic string, messages []Message) error {
|
||||||
|
// Store the messages inside a single database transaction.
|
||||||
|
return withTransaction(p.db, func(txn *sql.Tx) error {
|
||||||
|
s := txn.Stmt(p.insertMessageStmt)
|
||||||
|
topicNID, err := p.assignTopicNID(txn, topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, m := range messages {
|
||||||
|
_, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchMessages implements Database.
|
||||||
|
func (p *postgresqlDatabase) FetchMessages(topic string, startOffset, endOffset int64) (messages []Message, err error) {
|
||||||
|
topicNID, err := p.getTopicNID(nil, topic)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rows, err := p.selectMessagesStmt.Query(topicNID, startOffset, endOffset)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
offset int64
|
||||||
|
key []byte
|
||||||
|
value []byte
|
||||||
|
timestampNano int64
|
||||||
|
)
|
||||||
|
if err = rows.Scan(&offset, &key, &value, ×tampNano); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages = append(messages, Message{
|
||||||
|
Offset: offset,
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Timestamp: time.Unix(0, timestampNano),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxOffsets implements Database.
|
||||||
|
func (p *postgresqlDatabase) MaxOffsets() (map[string]int64, error) {
|
||||||
|
topicNames, err := p.selectTopics()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := map[string]int64{}
|
||||||
|
for topicName, topicNID := range topicNames {
|
||||||
|
// Lookup the maximum offset.
|
||||||
|
maxOffset, err := p.selectMaxOffset(topicNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if maxOffset > -1 {
|
||||||
|
// Don't include the topic if we haven't sent any messages on it.
|
||||||
|
result[topicName] = maxOffset
|
||||||
|
}
|
||||||
|
// Prefill the numeric ID cache.
|
||||||
|
p.addTopicNIDToCache(topicName, topicNID)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectTopics fetches the names and numeric IDs for all the topics the
|
||||||
|
// database is aware of.
|
||||||
|
func (p *postgresqlDatabase) selectTopics() (map[string]int64, error) {
|
||||||
|
rows, err := p.selectTopicsStmt.Query()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
result := map[string]int64{}
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
topicName string
|
||||||
|
topicNID int64
|
||||||
|
)
|
||||||
|
if err = rows.Scan(&topicName, &topicNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[topicName] = topicNID
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectMaxOffset selects the maximum offset for a topic.
|
||||||
|
// Returns -1 if there aren't any messages for that topic.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (p *postgresqlDatabase) selectMaxOffset(topicNID int64) (maxOffset int64, err error) {
|
||||||
|
err = p.selectMaxOffsetStmt.QueryRow(topicNID).Scan(&maxOffset)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return -1, nil
|
||||||
|
}
|
||||||
|
return maxOffset, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTopicNID finds the numeric ID for a topic.
|
||||||
|
// The txn argument is optional, this can be used outside a transaction
|
||||||
|
// by setting the txn argument to nil.
|
||||||
|
func (p *postgresqlDatabase) getTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
|
||||||
|
// Get from the cache.
|
||||||
|
topicNID = p.getTopicNIDFromCache(topicName)
|
||||||
|
if topicNID != 0 {
|
||||||
|
return topicNID, nil
|
||||||
|
}
|
||||||
|
// Get from the database
|
||||||
|
s := p.selectTopicStmt
|
||||||
|
if txn != nil {
|
||||||
|
s = txn.Stmt(s)
|
||||||
|
}
|
||||||
|
err = s.QueryRow(topicName).Scan(&topicNID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Update the shared cache.
|
||||||
|
p.addTopicNIDToCache(topicName, topicNID)
|
||||||
|
return topicNID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignTopicNID assigns a new numeric ID to a topic.
|
||||||
|
// The txn argument is mandatory, this is always called inside a transaction.
|
||||||
|
func (p *postgresqlDatabase) assignTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
|
||||||
|
// Check if we already have a numeric ID for the topic name.
|
||||||
|
topicNID, err = p.getTopicNID(txn, topicName)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if topicNID != 0 {
|
||||||
|
return topicNID, err
|
||||||
|
}
|
||||||
|
// We don't have a numeric ID for the topic name so we add an entry to the
|
||||||
|
// topics table. If the insert stmt succeeds then it will return the ID.
|
||||||
|
err = txn.Stmt(p.insertTopicStmt).QueryRow(topicName).Scan(&topicNID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// If the insert stmt succeeded, but didn't return any rows then it
|
||||||
|
// means that someone has added a row for the topic name between us
|
||||||
|
// selecting it the first time and us inserting our own row.
|
||||||
|
// (N.B. postgres only returns modified rows when using "RETURNING")
|
||||||
|
// So we can now just select the row that someone else added.
|
||||||
|
// TODO: This is probably unnecessary since naffka writes to a topic
|
||||||
|
// from a single thread.
|
||||||
|
return p.getTopicNID(txn, topicName)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Update the cache.
|
||||||
|
p.addTopicNIDToCache(topicName, topicNID)
|
||||||
|
return topicNID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTopicNIDFromCache returns the topicNID from the cache or returns 0 if the
|
||||||
|
// topic is not in the cache.
|
||||||
|
func (p *postgresqlDatabase) getTopicNIDFromCache(topicName string) (topicNID int64) {
|
||||||
|
p.topicsMutex.Lock()
|
||||||
|
defer p.topicsMutex.Unlock()
|
||||||
|
return p.topicNIDs[topicName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// addTopicNIDToCache adds the numeric ID for the topic to the cache.
|
||||||
|
func (p *postgresqlDatabase) addTopicNIDToCache(topicName string, topicNID int64) {
|
||||||
|
p.topicsMutex.Lock()
|
||||||
|
defer p.topicsMutex.Unlock()
|
||||||
|
p.topicNIDs[topicName] = topicNID
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTransaction runs a block of code passing in an SQL transaction
|
||||||
|
// If the code returns an error or panics then the transactions is rolledback
|
||||||
|
// Otherwise the transaction is committed.
|
||||||
|
func withTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
|
txn, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
panic(r)
|
||||||
|
} else if err != nil {
|
||||||
|
txn.Rollback()
|
||||||
|
} else {
|
||||||
|
err = txn.Commit()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = fn(txn)
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue