Finish inbound E2E device lists (#1243)

* Add tests for device list updates

* Add stale_device_lists table and use db before asking remote for device keys

* Fetch remote keys if all devices are requested

* Add display_name col to store remote device names

Few other tweaks to make `Server correctly handles incoming m.device_list_update`
pass.

* Fix sqlite otk bug

* Unbuffered channel to block /send causing sytest to not race anymore

* Linting and fix bug whereby we didn't send updated dl tokens to the client causing a tightloop on /sync sometimes

* No longer assert staleness as Update blocks on workers now

* Back out tweaks

* Bugfixes
main
Kegsay 2020-08-07 17:32:13 +01:00 committed by GitHub
parent 30c2325eaf
commit f371783da7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 639 additions and 48 deletions

View File

@ -23,7 +23,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -65,7 +64,7 @@ type DeviceListUpdater struct {
mu *sync.Mutex // protects UserIDToMutex mu *sync.Mutex // protects UserIDToMutex
db DeviceListUpdaterDatabase db DeviceListUpdaterDatabase
producer *producers.KeyChange producer KeyChangeProducer
fedClient *gomatrixserverlib.FederationClient fedClient *gomatrixserverlib.FederationClient
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
} }
@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface {
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
} }
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
type KeyChangeProducer interface {
ProduceKeyChanges(keys []api.DeviceMessage) error
}
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater( func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient, db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient,
numWorkers int, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
if err != nil { if err != nil {
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
} }
// if this is the first time we're hearing about this user, sync the device list manually.
if len(event.PrevID) == 0 {
exists = false
}
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"prev_ids_exist": exists, "prev_ids_exist": exists,
"user_id": event.UserID, "user_id": event.UserID,
"device_id": event.DeviceID, "device_id": event.DeviceID,
"stream_id": event.StreamID, "stream_id": event.StreamID,
"prev_ids": event.PrevID, "prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
}).Info("DeviceListUpdater.Update") }).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users // if we haven't missed anything update the database and notify users
@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
hasFailures = true hasFailures = true
continue continue
} }
err = u.updateDeviceList(ctx, &res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it") logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
hasFailures = true hasFailures = true
} }
} }
return hasFailures return hasFailures
} }
func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error { func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
keys := make([]api.DeviceMessage, len(res.Devices)) keys := make([]api.DeviceMessage, len(res.Devices))
for i, device := range res.Devices { for i, device := range res.Devices {
keyJSON, err := json.Marshal(device.Keys) keyJSON, err := json.Marshal(device.Keys)
@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs
} }
err := u.db.StoreRemoteDeviceKeys(ctx, keys) err := u.db.StoreRemoteDeviceKeys(ctx, keys)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to store remote device keys: %w", err)
} }
return u.db.MarkDeviceListStale(ctx, res.UserID, false) err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
if err != nil {
return fmt.Errorf("failed to mark device list as fresh: %w", err)
}
err = u.producer.ProduceKeyChanges(keys)
if err != nil {
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
}
return nil
} }

View File

@ -0,0 +1,242 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"crypto/ed25519"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
ctx = context.Background()
)
type mockKeyChangeProducer struct {
events []api.DeviceMessage
}
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
p.events = append(p.events, keys...)
return nil
}
type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool
prevIDsExist func(string, []int) bool
storedKeys []api.DeviceMessage
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
var result []string
for userID := range d.staleUsers {
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil, err
}
if len(domains) == 0 {
result = append(result, userID)
continue
}
for _, d := range domains {
if remoteServer == d {
result = append(result, userID)
break
}
}
}
return result, nil
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
d.staleUsers[userID] = isStale
return nil
}
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
d.storedKeys = append(d.storedKeys, keys...)
return nil
}
// PrevIDsExists returns true if all prev IDs exist for this user.
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
return d.prevIDsExist(userID, prevIDs), nil
}
type roundTripper struct {
fn func(*http.Request) (*http.Response, error)
}
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.fn(req)
}
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey,
)
fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper})
return fedClient
}
// Test that the device keys get persisted and emitted if we have the previous IDs.
func TestUpdateHavePrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return true
},
}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
DeviceID: "FOO",
Keys: []byte(`{"key":"value"}`),
PrevID: []int{0},
StreamID: 1,
UserID: "@alice:localhost",
}
err := updater.Update(ctx, event)
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
want := api.DeviceMessage{
StreamID: event.StreamID,
DeviceKeys: api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys,
UserID: event.UserID,
},
}
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
}
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
}
if db.staleUsers[event.UserID] {
t.Errorf("%s incorrectly marked as stale", event.UserID)
}
}
// Test that device keys are fetched from the remote server if we are missing prev IDs
// and that the user's devices are marked as stale until it succeeds.
func TestUpdateNoPrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return false
},
}
producer := &mockKeyChangeProducer{}
remoteUserID := "@alice:example.somewhere"
var wg sync.WaitGroup
wg.Add(1)
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
defer wg.Done()
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
}
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`
{
"user_id": "` + remoteUserID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
}, nil
})
updater := NewDeviceListUpdater(db, producer, fedClient, 2)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Mobile Phone",
Deleted: false,
DeviceID: "another_device_id",
Keys: []byte(`{"key":"value"}`),
PrevID: []int{3},
StreamID: 4,
UserID: remoteUserID,
}
err := updater.Update(ctx, event)
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
// At this point we show have this device list marked as stale and not store the keys or emitted anything
if !db.staleUsers[event.UserID] {
t.Errorf("%s not marked as stale", event.UserID)
}
if len(producer.events) > 0 {
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
}
if len(db.storedKeys) > 0 {
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
}
t.Log("waiting for /users/devices to be called...")
wg.Wait()
// wait a bit for db to be updated...
time.Sleep(100 * time.Millisecond)
want := api.DeviceMessage{
StreamID: 5,
DeviceKeys: api.DeviceKeys{
DeviceID: "JLAFKJWSCS",
DisplayName: "Mobile Phone",
UserID: remoteUserID,
KeyJSON: []byte(keyJSON),
},
}
// Now we should have a fresh list and the keys and emitted something
if db.staleUsers[event.UserID] {
t.Errorf("%s still marked as stale", event.UserID)
}
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
}
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
}
}

View File

@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
if len(dk.KeyJSON) == 0 { if len(dk.KeyJSON) == 0 {
continue // don't include blank keys continue // don't include blank keys
} }
// inject display name if known // inject display name if known (either locally or remotely)
displayName := dk.DisplayName
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
}
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"` DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) }{displayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
} }
} else { } else {
@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
} }
} }
// TODO: set device display names when they are known
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
if len(domainToDeviceKeys) == 0 {
return // nothing to query
}
// perform key queries for remote devices // perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
} }
func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
for userID, deviceIDs := range userToDeviceMap {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
if _, ok := fetchRemote[domain]; !ok {
fetchRemote[domain] = make(map[string][]string)
}
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
continue
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, key := range keys {
// inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
}
}
}
return fetchRemote
}
func (a *KeyInternalAPI) queryRemoteKeys( func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
) { ) {

View File

@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
-- required in the spec because in the event of a missed update the server fetches the entire -- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4, stream_id = $5" " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
} }
return nil return nil
} }
@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
return err return err
@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)

View File

@ -0,0 +1,118 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
if err != nil { if err != nil {
return nil, err return nil, err
} }
sdl, err := NewPostgresStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil }, nil
} }

View File

@ -26,10 +26,11 @@ import (
) )
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
} }
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset,
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned. // If no domains are given, all user IDs with stale device lists are returned.
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
return nil, nil // TODO return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
} }
// MarkDeviceListStale sets the stale bit for this user to isStale. // MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return nil // TODO return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
} }

View File

@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id) UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id)" + " ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4, stream_id = $5" " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" + const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
var streamID int var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)
@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
var streamID int var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
} }
return nil return nil
} }
@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix() now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
) )
if err != nil { if err != nil {
return err return err

View File

@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return err return err
}) })
if keyJSON == "" {
return nil, nil
}
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err

View File

@ -0,0 +1,118 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil }, nil
} }

View File

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type OneTimeKeys interface { type OneTimeKeys interface {
@ -45,3 +46,8 @@ type KeyChanges interface {
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
} }
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
}

View File

@ -46,6 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information. // be already filled in with join/leave information.
// nolint:gocyclo
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
userID string, res *types.Response, from, to types.StreamingToken, userID string, res *types.Response, from, to types.StreamingToken,
@ -68,22 +69,20 @@ func DeviceListCatchup(
var partition int32 var partition int32
var offset int64 var offset int64
partition = -1
offset = sarama.OffsetOldest
// Extract partition/offset from sync token // Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := from.Log(DeviceListLogName) logOffset := from.Log(DeviceListLogName)
if logOffset != nil { if logOffset != nil {
partition = logOffset.Partition partition = logOffset.Partition
offset = logOffset.Offset offset = logOffset.Offset
} else {
partition = -1
offset = sarama.OffsetOldest
} }
var toOffset int64 var toOffset int64
toOffset = sarama.OffsetNewest
toLog := to.Log(DeviceListLogName) toLog := to.Log(DeviceListLogName)
if toLog != nil { if toLog != nil && toLog.Offset > 0 {
toOffset = toLog.Offset toOffset = toLog.Offset
} else {
toOffset = sarama.OffsetNewest
} }
var queryRes api.QueryKeyChangesResponse var queryRes api.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
@ -96,6 +95,10 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return hasNew, nil return hasNew, nil
} }
util.GetLogger(ctx).Debugf(
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
)
userSet := make(map[string]bool) userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed { for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true userSet[userID] = true
@ -116,6 +119,13 @@ func DeviceListCatchup(
userSet[userID] = true userSet[userID] = true
} }
} }
// set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{
Partition: queryRes.Partition,
Offset: queryRes.Offset,
})
res.NextBatch = to.String()
return hasNew, nil return hasNew, nil
} }

View File

@ -112,6 +112,9 @@ type StreamingToken struct {
} }
func (t *StreamingToken) SetLog(name string, lp *LogPosition) { func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
if t.logs == nil {
t.logs = make(map[string]*LogPosition)
}
t.logs[name] = lp t.logs[name] = lp
} }
@ -173,12 +176,14 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
} }
ret.Positions[i] = other.Positions[i] ret.Positions[i] = other.Positions[i]
} }
ret.logs = make(map[string]*LogPosition)
for name := range t.logs { for name := range t.logs {
otherLog := other.Log(name) otherLog := other.Log(name)
if otherLog == nil { if otherLog == nil {
continue continue
} }
t.logs[name] = otherLog copy := *otherLog
ret.logs[name] = &copy
} }
return ret return ret
} }

View File

@ -138,6 +138,7 @@ Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id Local device key changes get to remote servers with correct prev_id
#Server correctly handles incoming m.device_list_update
Can add account data Can add account data
Can add account data to room Can add account data to room
Can get account data without syncing Can get account data without syncing