Refactor account data (#1150)

* Refactor account data

* Tweak database fetching

* Tweaks

* Restore syncProducer notification

* Various tweaks, update tag behaviour

* Fix initial sync
main
Neil Alexander 2020-06-18 18:36:03 +01:00 committed by GitHub
parent 3547a1768c
commit dc0bac85d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 222 deletions

View File

@ -16,21 +16,20 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData( func GetAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string, userID string, roomID string, dataType string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -40,18 +39,28 @@ func GetAccountData(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID) dataReq := api.QueryAccountDataRequest{
if err != nil { UserID: userID,
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") DataType: dataType,
return jsonerror.InternalServerError() RoomID: roomID,
}
dataRes := api.QueryAccountDataResponse{}
if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err))
} }
if data, err := accountDB.GetAccountDataByType( var data json.RawMessage
req.Context(), localpart, roomID, dataType, var ok bool
); err == nil { if roomID != "" {
data, ok = dataRes.RoomAccountData[roomID][dataType]
} else {
data, ok = dataRes.GlobalAccountData[dataType]
}
if ok {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: data.Content, JSON: data,
} }
} }
@ -63,7 +72,7 @@ func GetAccountData(
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData( func SaveAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device, req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
@ -73,12 +82,6 @@ func SaveAccountData(
} }
} }
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
defer req.Body.Close() // nolint: errcheck defer req.Body.Close() // nolint: errcheck
if req.Body == http.NoBody { if req.Body == http.NoBody {
@ -101,13 +104,19 @@ func SaveAccountData(
} }
} }
if err := accountDB.SaveAccountData( dataReq := api.InputAccountDataRequest{
req.Context(), localpart, roomID, dataType, string(body), UserID: userID,
); err != nil { DataType: dataType,
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") RoomID: roomID,
return jsonerror.InternalServerError() AccountData: json.RawMessage(body),
}
dataRes := api.InputAccountDataResponse{}
if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(err)
} }
// TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()

View File

@ -24,23 +24,14 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// newTag creates and returns a new gomatrix.TagContent
func newTag() gomatrix.TagContent {
return gomatrix.TagContent{
Tags: make(map[string]gomatrix.TagProperties),
}
}
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags( func GetTags(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -54,22 +45,15 @@ func GetTags(
} }
} }
_, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if data == nil {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: data.Content, JSON: tagContent,
} }
} }
@ -78,7 +62,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB // the tag to the "map" and saving the new "map" to the DB
func PutTag( func PutTag(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -98,34 +82,25 @@ func PutTag(
return *reqErr return *reqErr
} }
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
var tagContent gomatrix.TagContent if tagContent.Tags == nil {
if data != nil { tagContent.Tags = make(map[string]gomatrix.TagProperties)
if err = json.Unmarshal(data.Content, &tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
} else {
tagContent = newTag()
} }
tagContent.Tags[tag] = properties tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
// Run in a goroutine in order to prevent blocking the tag request response logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
go func() { }
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -138,7 +113,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB // the "map" and then saving the new "map" in the DB
func DeleteTag( func DeleteTag(
req *http.Request, req *http.Request,
accountDB accounts.Database, userAPI api.UserInternalAPI,
device *api.Device, device *api.Device,
userID string, userID string,
roomID string, roomID string,
@ -153,28 +128,12 @@ func DeleteTag(
} }
} }
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// If there are no tags in the database, exit
if data == nil {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
var tagContent gomatrix.TagContent
err = json.Unmarshal(data.Content, &tagContent)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
// Check whether the tag to be deleted exists // Check whether the tag to be deleted exists
if _, ok := tagContent.Tags[tag]; ok { if _, ok := tagContent.Tags[tag]; ok {
delete(tagContent.Tags, tag) delete(tagContent.Tags, tag)
@ -185,18 +144,16 @@ func DeleteTag(
JSON: struct{}{}, JSON: struct{}{},
} }
} }
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// Send data to syncProducer in order to inform clients of changes // TODO: user API should do this since it's account data
// Run in a goroutine in order to prevent blocking the tag request response if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
go func() { logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { }
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -210,32 +167,46 @@ func obtainSavedTags(
req *http.Request, req *http.Request,
userID string, userID string,
roomID string, roomID string,
accountDB accounts.Database, userAPI api.UserInternalAPI,
) (string, *gomatrixserverlib.ClientEvent, error) { ) (tags gomatrix.TagContent, err error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) dataReq := api.QueryAccountDataRequest{
if err != nil { UserID: userID,
return "", nil, err RoomID: roomID,
DataType: "m.tag",
} }
dataRes := api.QueryAccountDataResponse{}
data, err := accountDB.GetAccountDataByType( err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes)
req.Context(), localpart, roomID, "m.tag", if err != nil {
) return
}
return localpart, data, err data, ok := dataRes.RoomAccountData[roomID]["m.tag"]
if !ok {
return
}
if err = json.Unmarshal(data, &tags); err != nil {
return
}
return tags, nil
} }
// saveTagData saves the provided tag data into the database // saveTagData saves the provided tag data into the database
func saveTagData( func saveTagData(
req *http.Request, req *http.Request,
localpart string, userID string,
roomID string, roomID string,
accountDB accounts.Database, userAPI api.UserInternalAPI,
Tag gomatrix.TagContent, Tag gomatrix.TagContent,
) error { ) error {
newTagData, err := json.Marshal(Tag) newTagData, err := json.Marshal(Tag)
if err != nil { if err != nil {
return err return err
} }
dataReq := api.InputAccountDataRequest{
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) UserID: userID,
RoomID: roomID,
DataType: "m.tag",
AccountData: json.RawMessage(newTagData),
}
dataRes := api.InputAccountDataResponse{}
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
} }

View File

@ -476,7 +476,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -486,7 +486,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -496,7 +496,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"])
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
@ -506,7 +506,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"])
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
@ -604,7 +604,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
@ -614,7 +614,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
@ -624,7 +624,7 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}), }),
).Methods(http.MethodDelete, http.MethodOptions) ).Methods(http.MethodDelete, http.MethodOptions)

View File

@ -205,22 +205,34 @@ func (rp *RequestPool) appendAccountData(
if req.since == nil { if req.since == nil {
// If this is the initial sync, we don't need to check if a data has // If this is the initial sync, we don't need to check if a data has
// already been sent. Instead, we send the whole batch. // already been sent. Instead, we send the whole batch.
var res userapi.QueryAccountDataResponse dataReq := &userapi.QueryAccountDataRequest{
err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
UserID: userID, UserID: userID,
}, &res) }
if err != nil { dataRes := &userapi.QueryAccountDataResponse{}
if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil {
return nil, err return nil, err
} }
data.AccountData.Events = res.GlobalAccountData for datatype, databody := range dataRes.GlobalAccountData {
data.AccountData.Events = append(
data.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
}
for r, j := range data.Rooms.Join { for r, j := range data.Rooms.Join {
if len(res.RoomAccountData[r]) > 0 { for datatype, databody := range dataRes.RoomAccountData[r] {
j.AccountData.Events = res.RoomAccountData[r] j.AccountData.Events = append(
j.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: datatype,
Content: gomatrixserverlib.RawJSON(databody),
},
)
data.Rooms.Join[r] = j data.Rooms.Join[r] = j
} }
} }
return data, nil return data, nil
} }
@ -249,33 +261,42 @@ func (rp *RequestPool) appendAccountData(
// Iterate over the rooms // Iterate over the rooms
for roomID, dataTypes := range dataTypes { for roomID, dataTypes := range dataTypes {
events := []gomatrixserverlib.ClientEvent{}
// Request the missing data from the database // Request the missing data from the database
for _, dataType := range dataTypes { for _, dataType := range dataTypes {
var res userapi.QueryAccountDataResponse dataReq := userapi.QueryAccountDataRequest{
err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
UserID: userID, UserID: userID,
RoomID: roomID, RoomID: roomID,
DataType: dataType, DataType: dataType,
}, &res) }
dataRes := userapi.QueryAccountDataResponse{}
err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes)
if err != nil { if err != nil {
return nil, err continue
} }
if len(res.RoomAccountData[roomID]) > 0 { if roomID == "" {
events = append(events, res.RoomAccountData[roomID]...) if globalData, ok := dataRes.GlobalAccountData[dataType]; ok {
} else if len(res.GlobalAccountData) > 0 { data.AccountData.Events = append(
events = append(events, res.GlobalAccountData...) data.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: dataType,
Content: gomatrixserverlib.RawJSON(globalData),
},
)
}
} else {
if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
joinData := data.Rooms.Join[roomID]
joinData.AccountData.Events = append(
joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{
Type: dataType,
Content: gomatrixserverlib.RawJSON(roomData),
},
)
data.Rooms.Join[roomID] = joinData
}
} }
} }
// Append the data to the response
if len(roomID) > 0 {
jr := data.Rooms.Join[roomID]
jr.AccountData.Events = events
data.Rooms.Join[roomID] = jr
} else {
data.AccountData.Events = events
}
} }
return data, nil return data, nil

View File

@ -16,12 +16,14 @@ package api
import ( import (
"context" "context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface { type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -30,6 +32,18 @@ type UserInternalAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData
type InputAccountDataRequest struct {
UserID string // required: the user to set account data for
RoomID string // optional: the room to associate the account data with
DataType string // optional: the data type of the data
AccountData json.RawMessage // required: the message content
}
// InputAccountDataResponse is the response for InputAccountData
type InputAccountDataResponse struct {
}
// QueryAccessTokenRequest is the request for QueryAccessToken // QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct { type QueryAccessTokenRequest struct {
AccessToken string AccessToken string
@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct {
// QueryAccountDataRequest is the request for QueryAccountData // QueryAccountDataRequest is the request for QueryAccountData
type QueryAccountDataRequest struct { type QueryAccountDataRequest struct {
UserID string // required: the user to get account data for. UserID string // required: the user to get account data for.
// TODO: This is a terribly confusing API shape :/ RoomID string // optional: the room ID, or global account data if not specified.
DataType string // optional: if specified returns only a single event matching this data type. DataType string // optional: the data type, or all types if not specified.
// optional: Only used if DataType is set. If blank returns global account data matching the data type.
// If set, returns only room account data matching this data type.
RoomID string
} }
// QueryAccountDataResponse is the response for QueryAccountData // QueryAccountDataResponse is the response for QueryAccountData
type QueryAccountDataResponse struct { type QueryAccountDataResponse struct {
GlobalAccountData []gomatrixserverlib.ClientEvent GlobalAccountData map[string]json.RawMessage // type -> data
RoomAccountData map[string][]gomatrixserverlib.ClientEvent RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data
} }
// QueryDevicesRequest is the request for QueryDevices // QueryDevicesRequest is the request for QueryDevices

View File

@ -17,6 +17,7 @@ package internal
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -38,6 +39,20 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService AppServices []config.ApplicationService
} }
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
}
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
if req.AccountType == api.AccountTypeGuest { if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx) acc, err := a.AccountDB.CreateGuestAccount(ctx)
@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
} }
if req.DataType != "" { if req.DataType != "" {
var event *gomatrixserverlib.ClientEvent var data json.RawMessage
event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
if event != nil { res.RoomAccountData = make(map[string]map[string]json.RawMessage)
res.GlobalAccountData = make(map[string]json.RawMessage)
if data != nil {
if req.RoomID != "" { if req.RoomID != "" {
res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) if _, ok := res.RoomAccountData[req.RoomID]; !ok {
res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
}
res.RoomAccountData[req.RoomID][req.DataType] = data
} else { } else {
res.GlobalAccountData = append(res.GlobalAccountData, *event) res.GlobalAccountData[req.DataType] = data
} }
} }
return nil return nil

View File

@ -26,6 +26,8 @@ import (
// HTTP paths for the internal HTTP APIs // HTTP paths for the internal HTTP APIs
const ( const (
InputAccountDataPath = "/userapi/inputAccountData"
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
@ -55,6 +57,14 @@ type httpUserInternalAPI struct {
httpClient *http.Client httpClient *http.Client
} }
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
defer span.Finish()
apiURL := h.apiURL + InputAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformAccountCreation( func (h *httpUserInternalAPI) PerformAccountCreation(
ctx context.Context, ctx context.Context,
request *api.PerformAccountCreationRequest, request *api.PerformAccountCreationRequest,

View File

@ -16,6 +16,7 @@ package accounts
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -39,13 +40,13 @@ type Database interface {
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)

View File

@ -17,9 +17,9 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt) stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global []gomatrixserverlib.ClientEvent, /* global */ map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, /* rooms */ map[string]map[string]json.RawMessage,
err error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return nil, nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = []gomatrixserverlib.ClientEvent{} global := map[string]json.RawMessage{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms := map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return nil, nil, err
} }
ac := gomatrixserverlib.ClientEvent{ if roomID != "" {
Type: dataType, if _, ok := rooms[roomID]; !ok {
Content: content, rooms[roomID] = map[string]json.RawMessage{}
} }
rooms[roomID][dataType] = content
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else { } else {
global = append(global, ac) global[dataType] = content
} }
} }
return global, rooms, rows.Err() return global, rooms, rows.Err()
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
var content []byte if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return return
} }
data = json.RawMessage(bytes)
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return return
} }

View File

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"strconv" "strconv"
@ -169,7 +170,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -177,7 +178,7 @@ func (d *Database) createAccount(
"sender": [], "sender": [],
"underride": [] "underride": []
} }
}`); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -295,7 +296,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(ctx, localpart) return d.accountDatas.selectAccountData(ctx, localpart)
@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType( return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )

View File

@ -17,8 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return return
@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
global []gomatrixserverlib.ClientEvent, /* global */ map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, /* rooms */ map[string]map[string]json.RawMessage,
err error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil { if err != nil {
return return nil, nil, err
} }
global = []gomatrixserverlib.ClientEvent{} global := map[string]json.RawMessage{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent) rooms := map[string]map[string]json.RawMessage{}
for rows.Next() { for rows.Next() {
var roomID string var roomID string
@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil { if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return return nil, nil, err
} }
ac := gomatrixserverlib.ClientEvent{ if roomID != "" {
Type: dataType, if _, ok := rooms[roomID]; !ok {
Content: content, rooms[roomID] = map[string]json.RawMessage{}
} }
rooms[roomID][dataType] = content
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
} else { } else {
global = append(global, ac) global[dataType] = content
} }
} }
return return global, rooms, nil
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
var content []byte if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return return
} }
data = json.RawMessage(bytes)
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
return return
} }

View File

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"strconv" "strconv"
"sync" "sync"
@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [], "sender": [],
"underride": [] "underride": []
} }
}`); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent, global map[string]json.RawMessage,
rooms map[string][]gomatrixserverlib.ClientEvent, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.accountDatas.selectAccountData(ctx, localpart) return d.accountDatas.selectAccountData(ctx, localpart)
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) { ) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType( return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, roomID, dataType,
) )