Add type field to DeviceMessage, allow fields to be nullable (#1969)

main
Neil Alexander 2021-08-11 09:44:14 +01:00 committed by GitHub
parent b1377d991a
commit 125ea75b24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 41 additions and 14 deletions

View File

@ -53,9 +53,17 @@ func (k *KeyError) Error() string {
return k.Err return k.Err
} }
type DeviceMessageType int
const (
TypeDeviceKeyUpdate DeviceMessageType = iota
TypeCrossSigningUpdate
)
// DeviceMessage represents the message produced into Kafka by the key server. // DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct { type DeviceMessage struct {
DeviceKeys Type DeviceMessageType `json:"Type,omitempty"`
*DeviceKeys `json:"DeviceKeys,omitempty"`
// A monotonically increasing number which represents device changes for this user. // A monotonically increasing number which represents device changes for this user.
StreamID int StreamID int
} }
@ -76,7 +84,7 @@ type DeviceKeys struct {
// WithStreamID returns a copy of this device message with the given stream ID // WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
return DeviceMessage{ return DeviceMessage{
DeviceKeys: *k, DeviceKeys: k,
StreamID: streamID, StreamID: streamID,
} }
} }

View File

@ -231,7 +231,8 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
} }
keys := []api.DeviceMessage{ keys := []api.DeviceMessage{
{ {
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: event.DeviceID, DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName, DisplayName: event.DeviceDisplayName,
KeyJSON: k, KeyJSON: k,
@ -417,8 +418,9 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
continue continue
} }
keys[i] = api.DeviceMessage{ keys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: res.StreamID, StreamID: res.StreamID,
DeviceKeys: api.DeviceKeys{ DeviceKeys: &api.DeviceKeys{
DeviceID: device.DeviceID, DeviceID: device.DeviceID,
DisplayName: device.DisplayName, DisplayName: device.DisplayName,
UserID: res.UserID, UserID: res.UserID,
@ -426,7 +428,8 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
}, },
} }
existingKeys[i] = api.DeviceMessage{ existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: res.UserID, UserID: res.UserID,
DeviceID: device.DeviceID, DeviceID: device.DeviceID,
}, },

View File

@ -146,8 +146,9 @@ func TestUpdateHavePrevID(t *testing.T) {
t.Fatalf("Update returned an error: %s", err) t.Fatalf("Update returned an error: %s", err)
} }
want := api.DeviceMessage{ want := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: event.StreamID, StreamID: event.StreamID,
DeviceKeys: api.DeviceKeys{ DeviceKeys: &api.DeviceKeys{
DeviceID: event.DeviceID, DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName, DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys, KeyJSON: event.Keys,
@ -224,8 +225,9 @@ func TestUpdateNoPrevID(t *testing.T) {
// wait a bit for db to be updated... // wait a bit for db to be updated...
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
want := api.DeviceMessage{ want := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: 5, StreamID: 5,
DeviceKeys: api.DeviceKeys{ DeviceKeys: &api.DeviceKeys{
DeviceID: "JLAFKJWSCS", DeviceID: "JLAFKJWSCS",
DisplayName: "Mobile Phone", DisplayName: "Mobile Phone",
UserID: remoteUserID, UserID: remoteUserID,

View File

@ -573,7 +573,8 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
existingKeys := make([]api.DeviceMessage, len(keysToStore)) existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore { for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{ existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID, UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID, DeviceID: keysToStore[i].DeviceID,
}, },

View File

@ -114,6 +114,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid { if displayName.Valid {
@ -179,7 +180,10 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
} }
var result []api.DeviceMessage var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceMessage dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{},
}
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int

View File

@ -113,7 +113,11 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceMessage dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{},
}
dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int
@ -144,6 +148,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid { if displayName.Valid {

View File

@ -105,7 +105,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
bob := "@bob:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{ msgs := []api.DeviceMessage{
{ {
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA", DeviceID: "AAA",
UserID: alice, UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`), KeyJSON: []byte(`{"key":"v1"}`),
@ -113,7 +114,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// StreamID: 1 // StreamID: 1
}, },
{ {
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA", DeviceID: "AAA",
UserID: bob, UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`), KeyJSON: []byte(`{"key":"v1"}`),
@ -121,7 +123,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// StreamID: 1 as this is a different user // StreamID: 1 as this is a different user
}, },
{ {
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device", DeviceID: "another_device",
UserID: alice, UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`), KeyJSON: []byte(`{"key":"v1"}`),
@ -143,7 +146,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// updating a device sets the next stream ID for that user // updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{ msgs = []api.DeviceMessage{
{ {
DeviceKeys: api.DeviceKeys{ Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA", DeviceID: "AAA",
UserID: alice, UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`), KeyJSON: []byte(`{"key":"v2"}`),