Add contexts to the public rooms database (#230)

main
Mark Haines 2017-09-14 14:46:56 +01:00 committed by GitHub
parent bfcce5bd21
commit dc5dd4c5d2
5 changed files with 98 additions and 73 deletions

View File

@ -98,5 +98,5 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return err return err
} }
return s.db.UpdateRoomFromEvents(addQueryRes.Events, remQueryRes.Events) return s.db.UpdateRoomFromEvents(context.TODO(), addQueryRes.Events, remQueryRes.Events)
} }

View File

@ -32,7 +32,7 @@ func GetVisibility(
req *http.Request, publicRoomsDatabase *storage.PublicRoomsServerDatabase, req *http.Request, publicRoomsDatabase *storage.PublicRoomsServerDatabase,
roomID string, roomID string,
) util.JSONResponse { ) util.JSONResponse {
isPublic, err := publicRoomsDatabase.GetRoomVisibility(roomID) isPublic, err := publicRoomsDatabase.GetRoomVisibility(req.Context(), roomID)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -62,7 +62,7 @@ func SetVisibility(
} }
isPublic := v.Visibility == "public" isPublic := v.Visibility == "public"
if err := publicRoomsDatabase.SetRoomVisibility(isPublic, roomID); err != nil { if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View File

@ -63,7 +63,7 @@ func GetPublicRooms(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
if response.Estimate, err = publicRoomDatabase.CountPublicRooms(); err != nil { if response.Estimate, err = publicRoomDatabase.CountPublicRooms(req.Context()); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
@ -75,7 +75,9 @@ func GetPublicRooms(
response.NextBatch = strconv.Itoa(nextIndex) response.NextBatch = strconv.Itoa(nextIndex)
} }
if response.Chunk, err = publicRoomDatabase.GetPublicRooms(offset, limit, request.Filter.SearchTerms); err != nil { if response.Chunk, err = publicRoomDatabase.GetPublicRooms(
req.Context(), offset, limit, request.Filter.SearchTerms,
); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View File

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -166,27 +167,35 @@ func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *publicRoomsStatements) countPublicRooms() (nb int64, err error) { func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) {
err = s.countPublicRoomsStmt.QueryRow().Scan(&nb) err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb)
return return
} }
func (s *publicRoomsStatements) selectPublicRooms(offset int64, limit int16, filter string) ([]types.PublicRoom, error) { func (s *publicRoomsStatements) selectPublicRooms(
ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if len(filter) > 0 { if len(filter) > 0 {
pattern := "%" + filter + "%" pattern := "%" + filter + "%"
if limit == 0 { if limit == 0 {
rows, err = s.selectPublicRoomsWithFilterStmt.Query(pattern, offset) rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext(
ctx, pattern, offset,
)
} else { } else {
rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.Query(pattern, offset, limit) rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext(
ctx, pattern, offset, limit,
)
} }
} else { } else {
if limit == 0 { if limit == 0 {
rows, err = s.selectPublicRoomsStmt.Query(offset) rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset)
} else { } else {
rows, err = s.selectPublicRoomsWithLimitStmt.Query(offset, limit) rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext(
ctx, offset, limit,
)
} }
} }
@ -207,10 +216,7 @@ func (s *publicRoomsStatements) selectPublicRooms(offset int64, limit int16, fil
return rooms, err return rooms, err
} }
r.Aliases = make([]string, len(aliases)) r.Aliases = aliases
for i := range aliases {
r.Aliases[i] = aliases[i]
}
rooms = append(rooms, r) rooms = append(rooms, r)
} }
@ -218,51 +224,53 @@ func (s *publicRoomsStatements) selectPublicRooms(offset int64, limit int16, fil
return rooms, nil return rooms, nil
} }
func (s *publicRoomsStatements) selectRoomVisibility(roomID string) (v bool, err error) { func (s *publicRoomsStatements) selectRoomVisibility(
err = s.selectRoomVisibilityStmt.QueryRow(roomID).Scan(&v) ctx context.Context, roomID string,
) (v bool, err error) {
err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v)
return return
} }
func (s *publicRoomsStatements) insertNewRoom(roomID string) error { func (s *publicRoomsStatements) insertNewRoom(
_, err := s.insertNewRoomStmt.Exec(roomID) ctx context.Context, roomID string,
) error {
_, err := s.insertNewRoomStmt.ExecContext(ctx, roomID)
return err return err
} }
func (s *publicRoomsStatements) incrementJoinedMembersInRoom(roomID string) error { func (s *publicRoomsStatements) incrementJoinedMembersInRoom(
_, err := s.incrementJoinedMembersInRoomStmt.Exec(roomID) ctx context.Context, roomID string,
) error {
_, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err return err
} }
func (s *publicRoomsStatements) decrementJoinedMembersInRoom(roomID string) error { func (s *publicRoomsStatements) decrementJoinedMembersInRoom(
_, err := s.decrementJoinedMembersInRoomStmt.Exec(roomID) ctx context.Context, roomID string,
) error {
_, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
return err return err
} }
func (s *publicRoomsStatements) updateRoomAttribute(attrName string, attrValue attributeValue, roomID string) error { func (s *publicRoomsStatements) updateRoomAttribute(
isEditable := false ctx context.Context, attrName string, attrValue attributeValue, roomID string,
for _, editable := range editableAttributes { ) error {
if editable == attrName { stmt, isEditable := s.updateRoomAttributeStmts[attrName]
isEditable = true
}
}
if !isEditable { if !isEditable {
return errors.New("Cannot edit " + attrName) return errors.New("Cannot edit " + attrName)
} }
var value interface{} var value interface{}
if attrName == "aliases" { switch v := attrValue.(type) {
// Aliases need a special conversion case []string:
valueAsSlice, isSlice := attrValue.([]string) value = pq.StringArray(v)
if !isSlice { case bool, string:
// attrValue isn't a slice of strings
return errors.New("New list of aliases is of the wrong type")
}
value = pq.StringArray(valueAsSlice)
} else {
value = attrValue value = attrValue
default:
return errors.New("Unsupported attribute type, must be bool, string or []string")
} }
_, err := s.updateRoomAttributeStmts[attrName].Exec(value, roomID) _, err := stmt.ExecContext(ctx, value, roomID)
return err return err
} }

View File

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -64,21 +65,25 @@ func (d *PublicRoomsServerDatabase) SetPartitionOffset(topic string, partition i
// GetRoomVisibility returns the room visibility as a boolean: true if the room // GetRoomVisibility returns the room visibility as a boolean: true if the room
// is publicly visible, false if not. // is publicly visible, false if not.
// Returns an error if the retrieval failed. // Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetRoomVisibility(roomID string) (bool, error) { func (d *PublicRoomsServerDatabase) GetRoomVisibility(
return d.statements.selectRoomVisibility(roomID) ctx context.Context, roomID string,
) (bool, error) {
return d.statements.selectRoomVisibility(ctx, roomID)
} }
// SetRoomVisibility updates the visibility attribute of a room. This attribute // SetRoomVisibility updates the visibility attribute of a room. This attribute
// must be set to true if the room is publicly visible, false if not. // must be set to true if the room is publicly visible, false if not.
// Returns an error if the update failed. // Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) SetRoomVisibility(visible bool, roomID string) error { func (d *PublicRoomsServerDatabase) SetRoomVisibility(
return d.statements.updateRoomAttribute("visibility", visible, roomID) ctx context.Context, visible bool, roomID string,
) error {
return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID)
} }
// CountPublicRooms returns the number of room set as publicly visible on the server. // CountPublicRooms returns the number of room set as publicly visible on the server.
// Returns an error if the retrieval failed. // Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) CountPublicRooms() (int64, error) { func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
return d.statements.countPublicRooms() return d.statements.countPublicRooms(ctx)
} }
// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number // GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number
@ -86,8 +91,10 @@ func (d *PublicRoomsServerDatabase) CountPublicRooms() (int64, error) {
// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all // If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all
// the rooms set as publicly visible on the server. // the rooms set as publicly visible on the server.
// Returns an error if the retrieval failed. // Returns an error if the retrieval failed.
func (d *PublicRoomsServerDatabase) GetPublicRooms(offset int64, limit int16, filter string) ([]types.PublicRoom, error) { func (d *PublicRoomsServerDatabase) GetPublicRooms(
return d.statements.selectPublicRooms(offset, limit, filter) ctx context.Context, offset int64, limit int16, filter string,
) ([]types.PublicRoom, error) {
return d.statements.selectPublicRooms(ctx, offset, limit, filter)
} }
// UpdateRoomFromEvents iterate over a slice of state events and call // UpdateRoomFromEvents iterate over a slice of state events and call
@ -98,17 +105,19 @@ func (d *PublicRoomsServerDatabase) GetPublicRooms(offset int64, limit int16, fi
// If the update triggered by one of the events failed, aborts the process and // If the update triggered by one of the events failed, aborts the process and
// returns an error. // returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents( func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
eventsToAdd []gomatrixserverlib.Event, eventsToRemove []gomatrixserverlib.Event, ctx context.Context,
eventsToAdd []gomatrixserverlib.Event,
eventsToRemove []gomatrixserverlib.Event,
) error { ) error {
for _, event := range eventsToAdd { for _, event := range eventsToAdd {
if err := d.UpdateRoomFromEvent(event); err != nil { if err := d.UpdateRoomFromEvent(ctx, event); err != nil {
return err return err
} }
} }
for _, event := range eventsToRemove { for _, event := range eventsToRemove {
if event.Type() == "m.room.member" { if event.Type() == "m.room.member" {
if err := d.updateNumJoinedUsers(event, true); err != nil { if err := d.updateNumJoinedUsers(ctx, event, true); err != nil {
return err return err
} }
} }
@ -123,47 +132,49 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
// If the event doesn't match with any property used to compute the public room directory, // If the event doesn't match with any property used to compute the public room directory,
// does nothing. // does nothing.
// If something went wrong during the process, returns an error. // If something went wrong during the process, returns an error.
func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(event gomatrixserverlib.Event) error { func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(
ctx context.Context, event gomatrixserverlib.Event,
) error {
// Process the event according to its type // Process the event according to its type
switch event.Type() { switch event.Type() {
case "m.room.create": case "m.room.create":
return d.statements.insertNewRoom(event.RoomID()) return d.statements.insertNewRoom(ctx, event.RoomID())
case "m.room.member": case "m.room.member":
return d.updateNumJoinedUsers(event, false) return d.updateNumJoinedUsers(ctx, event, false)
case "m.room.aliases": case "m.room.aliases":
return d.updateRoomAliases(event) return d.updateRoomAliases(ctx, event)
case "m.room.canonical_alias": case "m.room.canonical_alias":
var content common.CanonicalAliasContent var content common.CanonicalAliasContent
field := &(content.Alias) field := &(content.Alias)
attrName := "canonical_alias" attrName := "canonical_alias"
return d.updateStringAttribute(attrName, event, &content, field) return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.name": case "m.room.name":
var content common.NameContent var content common.NameContent
field := &(content.Name) field := &(content.Name)
attrName := "name" attrName := "name"
return d.updateStringAttribute(attrName, event, &content, field) return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.topic": case "m.room.topic":
var content common.TopicContent var content common.TopicContent
field := &(content.Topic) field := &(content.Topic)
attrName := "topic" attrName := "topic"
return d.updateStringAttribute(attrName, event, &content, field) return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.avatar": case "m.room.avatar":
var content common.AvatarContent var content common.AvatarContent
field := &(content.URL) field := &(content.URL)
attrName := "avatar_url" attrName := "avatar_url"
return d.updateStringAttribute(attrName, event, &content, field) return d.updateStringAttribute(ctx, attrName, event, &content, field)
case "m.room.history_visibility": case "m.room.history_visibility":
var content common.HistoryVisibilityContent var content common.HistoryVisibilityContent
field := &(content.HistoryVisibility) field := &(content.HistoryVisibility)
attrName := "world_readable" attrName := "world_readable"
strForTrue := "world_readable" strForTrue := "world_readable"
return d.updateBooleanAttribute(attrName, event, &content, field, strForTrue) return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
case "m.room.guest_access": case "m.room.guest_access":
var content common.GuestAccessContent var content common.GuestAccessContent
field := &(content.GuestAccess) field := &(content.GuestAccess)
attrName := "guest_can_join" attrName := "guest_can_join"
strForTrue := "can_join" strForTrue := "can_join"
return d.updateBooleanAttribute(attrName, event, &content, field, strForTrue) return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
} }
// If the event type didn't match, return with no error // If the event type didn't match, return with no error
@ -177,7 +188,7 @@ func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(event gomatrixserverlib.
// database, if set to truem decrements it. // database, if set to truem decrements it.
// Returns an error if the update failed. // Returns an error if the update failed.
func (d *PublicRoomsServerDatabase) updateNumJoinedUsers( func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
membershipEvent gomatrixserverlib.Event, remove bool, ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool,
) error { ) error {
membership, err := membershipEvent.Membership() membership, err := membershipEvent.Membership()
if err != nil { if err != nil {
@ -189,9 +200,9 @@ func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
} }
if remove { if remove {
return d.statements.decrementJoinedMembersInRoom(membershipEvent.RoomID()) return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
} }
return d.statements.incrementJoinedMembersInRoom(membershipEvent.RoomID()) return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
} }
// updateStringAttribute updates a given string attribute in the database // updateStringAttribute updates a given string attribute in the database
@ -200,14 +211,14 @@ func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
// Returns an error if decoding the Matrix event's content or updating the attribute // Returns an error if decoding the Matrix event's content or updating the attribute
// failed. // failed.
func (d *PublicRoomsServerDatabase) updateStringAttribute( func (d *PublicRoomsServerDatabase) updateStringAttribute(
attrName string, event gomatrixserverlib.Event, content interface{}, ctx context.Context, attrName string, event gomatrixserverlib.Event,
field *string, content interface{}, field *string,
) error { ) error {
if err := json.Unmarshal(event.Content(), content); err != nil { if err := json.Unmarshal(event.Content(), content); err != nil {
return err return err
} }
return d.statements.updateRoomAttribute(attrName, *field, event.RoomID()) return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID())
} }
// updateBooleanAttribute updates a given boolean attribute in the database // updateBooleanAttribute updates a given boolean attribute in the database
@ -217,8 +228,8 @@ func (d *PublicRoomsServerDatabase) updateStringAttribute(
// Returns an error if decoding the Matrix event's content or updating the attribute // Returns an error if decoding the Matrix event's content or updating the attribute
// failed. // failed.
func (d *PublicRoomsServerDatabase) updateBooleanAttribute( func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
attrName string, event gomatrixserverlib.Event, content interface{}, ctx context.Context, attrName string, event gomatrixserverlib.Event,
field *string, strForTrue string, content interface{}, field *string, strForTrue string,
) error { ) error {
if err := json.Unmarshal(event.Content(), content); err != nil { if err := json.Unmarshal(event.Content(), content); err != nil {
return err return err
@ -231,17 +242,21 @@ func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
attrValue = false attrValue = false
} }
return d.statements.updateRoomAttribute(attrName, attrValue, event.RoomID()) return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID())
} }
// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of // updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of
// a given room with it. // a given room with it.
// Returns an error if decoding the Matrix event or updating the list failed. // Returns an error if decoding the Matrix event or updating the list failed.
func (d *PublicRoomsServerDatabase) updateRoomAliases(aliasesEvent gomatrixserverlib.Event) error { func (d *PublicRoomsServerDatabase) updateRoomAliases(
ctx context.Context, aliasesEvent gomatrixserverlib.Event,
) error {
var content common.AliasesContent var content common.AliasesContent
if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil {
return err return err
} }
return d.statements.updateRoomAttribute("aliases", content.Aliases, aliasesEvent.RoomID()) return d.statements.updateRoomAttribute(
ctx, "aliases", content.Aliases, aliasesEvent.RoomID(),
)
} }