2020-07-28 16:38:30 +00:00
|
|
|
package storage
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
2020-07-30 13:52:21 +00:00
|
|
|
|
|
|
|
"github.com/Shopify/sarama"
|
2020-08-03 16:07:06 +00:00
|
|
|
"github.com/matrix-org/dendrite/keyserver/api"
|
2020-07-28 16:38:30 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
var ctx = context.Background()
|
|
|
|
|
|
|
|
func MustNotError(t *testing.T, err error) {
|
|
|
|
t.Helper()
|
|
|
|
if err == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
t.Fatalf("operation failed: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestKeyChanges(t *testing.T) {
|
|
|
|
db, err := NewDatabase("file::memory:", nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
|
|
}
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
2020-07-30 13:52:21 +00:00
|
|
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest)
|
2020-07-28 16:38:30 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
|
|
|
}
|
|
|
|
if latest != 2 {
|
|
|
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
|
|
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestKeyChangesNoDupes(t *testing.T) {
|
|
|
|
db, err := NewDatabase("file::memory:", nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
|
|
}
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
2020-07-30 13:52:21 +00:00
|
|
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest)
|
2020-07-28 16:38:30 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
|
|
|
}
|
|
|
|
if latest != 2 {
|
|
|
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
|
|
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
|
|
|
}
|
|
|
|
}
|
2020-07-30 13:52:21 +00:00
|
|
|
|
|
|
|
func TestKeyChangesUpperLimit(t *testing.T) {
|
|
|
|
db, err := NewDatabase("file::memory:", nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
|
|
}
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
|
|
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
|
|
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
|
|
|
}
|
|
|
|
if latest != 1 {
|
|
|
|
t.Fatalf("KeyChanges: got latest=%d want 1", latest)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
|
|
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
|
|
|
}
|
|
|
|
}
|
2020-08-03 16:07:06 +00:00
|
|
|
|
|
|
|
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
|
|
|
// and that they are returned correctly when querying for device keys.
|
|
|
|
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|
|
|
db, err := NewDatabase("file::memory:", nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
|
|
}
|
|
|
|
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
|
|
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
|
|
|
msgs := []api.DeviceMessage{
|
|
|
|
{
|
|
|
|
DeviceKeys: api.DeviceKeys{
|
|
|
|
DeviceID: "AAA",
|
|
|
|
UserID: alice,
|
|
|
|
KeyJSON: []byte(`{"key":"v1"}`),
|
|
|
|
},
|
|
|
|
// StreamID: 1
|
|
|
|
},
|
|
|
|
{
|
|
|
|
DeviceKeys: api.DeviceKeys{
|
|
|
|
DeviceID: "AAA",
|
|
|
|
UserID: bob,
|
|
|
|
KeyJSON: []byte(`{"key":"v1"}`),
|
|
|
|
},
|
|
|
|
// StreamID: 1 as this is a different user
|
|
|
|
},
|
|
|
|
{
|
|
|
|
DeviceKeys: api.DeviceKeys{
|
|
|
|
DeviceID: "another_device",
|
|
|
|
UserID: alice,
|
|
|
|
KeyJSON: []byte(`{"key":"v1"}`),
|
|
|
|
},
|
|
|
|
// StreamID: 2 as this is a 2nd device key
|
|
|
|
},
|
|
|
|
}
|
|
|
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
|
|
|
if msgs[0].StreamID != 1 {
|
|
|
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
|
|
|
}
|
|
|
|
if msgs[1].StreamID != 1 {
|
|
|
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
|
|
|
}
|
|
|
|
if msgs[2].StreamID != 2 {
|
|
|
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
|
|
|
}
|
|
|
|
|
|
|
|
// updating a device sets the next stream ID for that user
|
|
|
|
msgs = []api.DeviceMessage{
|
|
|
|
{
|
|
|
|
DeviceKeys: api.DeviceKeys{
|
|
|
|
DeviceID: "AAA",
|
|
|
|
UserID: alice,
|
|
|
|
KeyJSON: []byte(`{"key":"v2"}`),
|
|
|
|
},
|
|
|
|
// StreamID: 3
|
|
|
|
},
|
|
|
|
}
|
|
|
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
|
|
|
if msgs[0].StreamID != 3 {
|
|
|
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Querying for device keys returns the latest stream IDs
|
|
|
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
|
|
|
}
|
|
|
|
wantStreamIDs := map[string]int{
|
|
|
|
"AAA": 3,
|
|
|
|
"another_device": 2,
|
|
|
|
}
|
|
|
|
if len(msgs) != len(wantStreamIDs) {
|
|
|
|
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
|
|
|
}
|
|
|
|
for _, m := range msgs {
|
|
|
|
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
|
|
|
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|