diff --git a/clientapi/auth/storage/accounts/account_data_table.go b/clientapi/auth/storage/accounts/account_data_table.go index 0d73cb31..0d6ad093 100644 --- a/clientapi/auth/storage/accounts/account_data_table.go +++ b/clientapi/auth/storage/accounts/account_data_table.go @@ -120,28 +120,17 @@ func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data []gomatrixserverlib.ClientEvent, err error) { - data = []gomatrixserverlib.ClientEvent{} - +) (data *gomatrixserverlib.ClientEvent, err error) { stmt := s.selectAccountDataByTypeStmt - rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType) - if err != nil { + var content []byte + + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { return } - for rows.Next() { - var content []byte - - if err = rows.Scan(&content); err != nil { - return - } - - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - data = append(data, ac) + data = &gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, } return diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go index 41d75daa..020a3837 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/clientapi/auth/storage/accounts/storage.go @@ -263,11 +263,11 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // GetAccountDataByType returns account data matching a given // localpart, room ID and type. -// If no account data could be found, returns an empty array +// If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data []gomatrixserverlib.ClientEvent, err error) { +) (data *gomatrixserverlib.ClientEvent, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 6e7324cd..487081c5 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -59,7 +59,7 @@ func GetTags( return httputil.LogThenError(req, err) } - if len(data) == 0 { + if data == nil { return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, @@ -68,7 +68,7 @@ func GetTags( return util.JSONResponse{ Code: http.StatusOK, - JSON: data[0].Content, + JSON: data.Content, } } @@ -103,8 +103,8 @@ func PutTag( } var tagContent gomatrix.TagContent - if len(data) > 0 { - if err = json.Unmarshal(data[0].Content, &tagContent); err != nil { + if data != nil { + if err = json.Unmarshal(data.Content, &tagContent); err != nil { return httputil.LogThenError(req, err) } } else { @@ -155,7 +155,7 @@ func DeleteTag( } // If there are no tags in the database, exit - if len(data) == 0 { + if data == nil { // Spec only defines 200 responses for this endpoint so we don't return anything else. return util.JSONResponse{ Code: http.StatusOK, @@ -164,7 +164,7 @@ func DeleteTag( } var tagContent gomatrix.TagContent - err = json.Unmarshal(data[0].Content, &tagContent) + err = json.Unmarshal(data.Content, &tagContent) if err != nil { return httputil.LogThenError(req, err) } @@ -204,7 +204,7 @@ func obtainSavedTags( userID string, roomID string, accountDB *accounts.Database, -) (string, []gomatrixserverlib.ClientEvent, error) { +) (string, *gomatrixserverlib.ClientEvent, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return "", nil, err diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 6b95f469..94a36900 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -196,13 +196,13 @@ func (rp *RequestPool) appendAccountData( events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - evs, err := rp.accountDB.GetAccountDataByType( + event, err := rp.accountDB.GetAccountDataByType( req.ctx, localpart, roomID, dataType, ) if err != nil { return nil, err } - events = append(events, evs...) + events = append(events, *event) } // Append the data to the response