From 125ea75b2419aa5ebf58506e3fbd03a90eb06d68 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 11 Aug 2021 09:44:14 +0100 Subject: [PATCH] Add type field to DeviceMessage, allow fields to be nullable (#1969) --- keyserver/api/api.go | 12 ++++++++++-- keyserver/internal/device_list_update.go | 9 ++++++--- keyserver/internal/device_list_update_test.go | 6 ++++-- keyserver/internal/internal.go | 3 ++- keyserver/storage/postgres/device_keys_table.go | 6 +++++- keyserver/storage/sqlite3/device_keys_table.go | 7 ++++++- keyserver/storage/storage_test.go | 12 ++++++++---- 7 files changed, 41 insertions(+), 14 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index aa6df96f..490f0e41 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -53,9 +53,17 @@ func (k *KeyError) Error() string { return k.Err } +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + // DeviceMessage represents the message produced into Kafka by the key server. type DeviceMessage struct { - DeviceKeys + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` // A monotonically increasing number which represents device changes for this user. StreamID int } @@ -76,7 +84,7 @@ type DeviceKeys struct { // WithStreamID returns a copy of this device message with the given stream ID func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { return DeviceMessage{ - DeviceKeys: *k, + DeviceKeys: k, StreamID: streamID, } } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 91d4b53d..1f7c6e2a 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -231,7 +231,8 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. } keys := []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ DeviceID: event.DeviceID, DisplayName: event.DeviceDisplayName, KeyJSON: k, @@ -417,8 +418,9 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi continue } keys[i] = api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, StreamID: res.StreamID, - DeviceKeys: api.DeviceKeys{ + DeviceKeys: &api.DeviceKeys{ DeviceID: device.DeviceID, DisplayName: device.DisplayName, UserID: res.UserID, @@ -426,7 +428,8 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi }, } existingKeys[i] = api.DeviceMessage{ - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ UserID: res.UserID, DeviceID: device.DeviceID, }, diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 7c170de2..164be6be 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -146,8 +146,9 @@ func TestUpdateHavePrevID(t *testing.T) { t.Fatalf("Update returned an error: %s", err) } want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, StreamID: event.StreamID, - DeviceKeys: api.DeviceKeys{ + DeviceKeys: &api.DeviceKeys{ DeviceID: event.DeviceID, DisplayName: event.DeviceDisplayName, KeyJSON: event.Keys, @@ -224,8 +225,9 @@ func TestUpdateNoPrevID(t *testing.T) { // wait a bit for db to be updated... time.Sleep(100 * time.Millisecond) want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, StreamID: 5, - DeviceKeys: api.DeviceKeys{ + DeviceKeys: &api.DeviceKeys{ DeviceID: "JLAFKJWSCS", DisplayName: "Mobile Phone", UserID: remoteUserID, diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index de269911..47eda179 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -573,7 +573,8 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { existingKeys[i] = api.DeviceMessage{ - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ UserID: keysToStore[i].UserID, DeviceID: keysToStore[i].DeviceID, }, diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 95064fc8..e5f68fd0 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -114,6 +114,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return err } // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID if displayName.Valid { @@ -179,7 +180,10 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID } var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceMessage + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{}, + } dk.UserID = userID var keyJSON string var streamID int diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 9112fc6e..ca7ed9cf 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -113,7 +113,11 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceMessage + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{}, + } + dk.Type = api.TypeDeviceKeyUpdate dk.UserID = userID var keyJSON string var streamID int @@ -144,6 +148,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] return err } // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID if displayName.Valid { diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index afdb086d..4e0a8af1 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -105,7 +105,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), @@ -113,7 +114,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 1 }, { - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: bob, KeyJSON: []byte(`{"key":"v1"}`), @@ -121,7 +123,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 1 as this is a different user }, { - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ DeviceID: "another_device", UserID: alice, KeyJSON: []byte(`{"key":"v1"}`), @@ -143,7 +146,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // updating a device sets the next stream ID for that user msgs = []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ DeviceID: "AAA", UserID: alice, KeyJSON: []byte(`{"key":"v2"}`),