Membership viewing API (#174)

* Basic memberships retrieval

* Change the way the memberships are saved in the client API database

* Retrieve single membership

* Get memberships only if the user is or has been in the room

* Check server name on room ID instead of user ID

* Save the join membership event and updates it when necessary

* Membership events retrieval + update on leave

* Implement the API on the roomserver and client API server

* Fix comments

* Remove the functions and attributes used before the new query API

* Explicitely state what we return in query

* Remove tab
main
Brendan Abolivier 2017-08-21 16:34:26 +01:00 committed by GitHub
parent 81179a0595
commit 5950293e79
11 changed files with 383 additions and 48 deletions

View File

@ -44,9 +44,6 @@ const insertMembershipSQL = `
ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
` `
const selectMembershipSQL = "" +
"SELECT * from account_memberships WHERE localpart = $1 AND room_id = $2"
const selectMembershipsByLocalpartSQL = "" + const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"

View File

@ -121,7 +121,8 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
} }
// SaveMembership saves the user matching a given localpart as a member of a given // SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the `join` membership event. // room. It also stores the ID of the membership event and a flag on whether the user
// is still in the room.
// If a membership already exists between the user and the room, or of the // If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error // insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error { func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
@ -156,23 +157,19 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT
}) })
} }
// GetMembershipsByLocalpart returns an array containing the IDs of all the rooms // GetMembershipsByLocalpart returns an array containing the memberships for all
// a user matching a given localpart is a member of // the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array // If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error // If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) {
return d.memberships.selectMembershipsByLocalpart(localpart) return d.memberships.selectMembershipsByLocalpart(localpart)
} }
// UpdateMembership update the "join" membership event ID of a membership. // newMembership will save a new membership in the database, with a flag on whether
// This is useful in case of membership upgrade (e.g. profile update) // the user is still in the room. This flag is set to true if the given state
// If there was an issue during the update, returns the SQL error // event is a "join" membership event and false if the event is a "leave" or "ban"
func (d *Database) UpdateMembership(oldEventID string, newEventID string) error { // membership. If the event isn't a m.room.member event with one of these three
return d.memberships.updateMembershipByEventID(oldEventID, newEventID) // values, does nothing.
}
// newMembership will save a new membership in the database if the given state
// event is a "join" membership event
// If the event isn't a "join" membership event, does nothing // If the event isn't a "join" membership event, does nothing
// If an error occurred, returns it // If an error occurred, returns it
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {

View File

@ -0,0 +1,55 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package readers
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
)
// GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string,
accountDB *accounts.Database, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
queryReq := api.QueryMembershipsForRoomRequest{
RoomID: roomID,
Sender: device.UserID,
}
var queryRes api.QueryMembershipsForRoomResponse
if err := queryAPI.QueryMembershipsForRoom(&queryReq, &queryRes); err != nil {
return httputil.LogThenError(req, err)
}
if !queryRes.HasBeenInRoom {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
}
}
return util.JSONResponse{
Code: 200,
JSON: queryRes.JoinEvents,
}
}

View File

@ -313,6 +313,13 @@ func Setup(
}), }),
) )
r0mux.Handle("/rooms/{roomID}/members",
common.MakeAuthAPI("rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return readers.GetMemberships(req, device, vars["roomID"], accountDB, cfg, queryAPI)
}),
)
r0mux.Handle("/rooms/{roomID}/read_markers", r0mux.Handle("/rooms/{roomID}/read_markers",
common.MakeAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { common.MakeAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
// TODO: return the read_markers. // TODO: return the read_markers.

View File

@ -100,6 +100,23 @@ type QueryEventsByIDResponse struct {
Events []gomatrixserverlib.Event `json:"events"` Events []gomatrixserverlib.Event `json:"events"`
} }
// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
type QueryMembershipsForRoomRequest struct {
// ID of the room to fetch memberships from
RoomID string `json:"room_id"`
// ID of the user sending the request
Sender string `json:"sender"`
}
// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
type QueryMembershipsForRoomResponse struct {
// The "m.room.member" events (of "join" membership) in the client format
JoinEvents []gomatrixserverlib.ClientEvent `json:"join_events"`
// True if the user has been in room before and has either stayed in it or
// left it.
HasBeenInRoom bool `json:"has_been_in_room"`
}
// RoomserverQueryAPI is used to query information from the room server. // RoomserverQueryAPI is used to query information from the room server.
type RoomserverQueryAPI interface { type RoomserverQueryAPI interface {
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -119,6 +136,12 @@ type RoomserverQueryAPI interface {
request *QueryEventsByIDRequest, request *QueryEventsByIDRequest,
response *QueryEventsByIDResponse, response *QueryEventsByIDResponse,
) error ) error
// Query a list of membership events for a room
QueryMembershipsForRoom(
request *QueryMembershipsForRoomRequest,
response *QueryMembershipsForRoomResponse,
) error
} }
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
@ -130,6 +153,9 @@ const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEven
// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. // RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API.
const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID" const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID"
// RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API
const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom"
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
// If httpClient is nil then it uses the http.DefaultClient // If httpClient is nil then it uses the http.DefaultClient
func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
@ -171,6 +197,15 @@ func (h *httpRoomserverQueryAPI) QueryEventsByID(
return postJSON(h.httpClient, apiURL, request, response) return postJSON(h.httpClient, apiURL, request, response)
} }
// QueryMembershipsForRoom implements RoomserverQueryAPI
func (h *httpRoomserverQueryAPI) QueryMembershipsForRoom(
request *QueryMembershipsForRoomRequest,
response *QueryMembershipsForRoomResponse,
) error {
apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath
return postJSON(h.httpClient, apiURL, request, response)
}
func postJSON(httpClient *http.Client, apiURL string, request, response interface{}) error { func postJSON(httpClient *http.Client, apiURL string, request, response interface{}) error {
jsonBytes, err := json.Marshal(request) jsonBytes, err := json.Marshal(request)
if err != nil { if err != nil {

View File

@ -95,10 +95,9 @@ func updateMembership(
return nil, err return nil, err
} }
} }
if old == new { if old == new && new != "join" {
// If the membership is the same then nothing changed and we can return // If the membership is the same then nothing changed and we can return
// immediately. This should help speed up processing for display name // immediately, unless it's a "join" update (e.g. profile update).
// changes where the membership is "join" both before and after.
return updates, nil return updates, nil
} }
@ -152,16 +151,21 @@ func updateToInviteMembership(
func updateToJoinMembership( func updateToJoinMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If the user is already marked as being joined then we can return immediately. // If the user is already marked as being joined, we call SetToJoin to update
// TODO: Is this code reachable given the "old != new" guard in updateMembership? // the event ID then we can return immediately. Retired is ignored as there
// is no invite event to retire.
if mu.IsJoin() { if mu.IsJoin() {
_, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
if err != nil {
return nil, err
}
return updates, nil return updates, nil
} }
// When we mark a user as being joined we will invalidate any invites that // When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToJoin(add.Sender()) retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -194,7 +198,7 @@ func updateToLeaveMembership(
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToLeave(add.Sender()) retired, err := mu.SetToLeave(add.Sender(), add.EventID())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -52,6 +52,13 @@ type RoomserverQueryAPIDatabase interface {
// Remove a given room alias. // Remove a given room alias.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
RemoveRoomAlias(alias string) error RemoveRoomAlias(alias string) error
// Lookup the join events for all members in a room as requested by a given
// user. If the user is currently in the room, returns the room's current
// members, if not returns an empty array (TODO: Fix it)
// If the user requesting the list of members has never been in the room,
// returns nil.
// If there was an issue retrieving the events, returns an error.
GetMembershipEvents(roomNID types.RoomNID, requestSenderUserID string) (events []types.Event, err error)
} }
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -182,6 +189,37 @@ func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixs
return result, nil return result, nil
} }
// QueryMembershipsForRoom implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
roomNID, err := r.DB.RoomNID(request.RoomID)
if err != nil {
return err
}
events, err := r.DB.GetMembershipEvents(roomNID, request.Sender)
if err != nil {
return nil
}
if events == nil {
response.HasBeenInRoom = false
response.JoinEvents = nil
return nil
}
response.HasBeenInRoom = true
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
for _, event := range events {
clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
}
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
servMux.Handle( servMux.Handle(
@ -226,4 +264,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: 200, JSON: &response} return util.JSONResponse{Code: 200, JSON: &response}
}), }),
) )
servMux.Handle(
api.RoomserverQueryMembershipsForRoomPath,
common.MakeAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipsForRoom(&request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: 200, JSON: &response}
}),
)
} }

View File

@ -58,10 +58,22 @@ const bulkSelectEventStateKeyNIDSQL = "" +
"SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" + "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" +
" WHERE event_state_key = ANY($1)" " WHERE event_state_key = ANY($1)"
const selectEventStateKeySQL = "" +
"SELECT event_state_key FROM roomserver_event_state_keys" +
" WHERE event_state_key_nid = $1"
// Bulk lookup from numeric ID to string state key for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeySQL = "" +
"SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" +
" WHERE event_state_key_nid = ANY($1)"
type eventStateKeyStatements struct { type eventStateKeyStatements struct {
insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
} }
func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
@ -72,7 +84,9 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
return statementList{ return statementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.selectEventStateKeyStmt, selectEventStateKeySQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
}.prepare(db) }.prepare(db)
} }
@ -114,3 +128,36 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
} }
return result, nil return result, nil
} }
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
var eventStateKey string
stmt := s.selectEventStateKeyStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
var nIDs pq.Int64Array
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs)
if err != nil {
return nil, err
}
defer rows.Close()
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
}
return result, nil
}

View File

@ -33,7 +33,7 @@ const membershipSchema = `
-- and the room state tables. -- and the room state tables.
-- This table is updated in one of 3 ways: -- This table is updated in one of 3 ways:
-- 1) The membership of a user changes within the current state of the room. -- 1) The membership of a user changes within the current state of the room.
-- 2) An invite is received outside of a room over federation. -- 2) An invite is received outside of a room over federation.
-- 3) An invite is rejected outside of a room over federation. -- 3) An invite is rejected outside of a room over federation.
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid BIGINT NOT NULL, room_nid BIGINT NOT NULL,
@ -46,6 +46,16 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- The state the user is in within this room. -- The state the user is in within this room.
-- Default value is "membershipStateLeaveOrBan" -- Default value is "membershipStateLeaveOrBan"
membership_nid BIGINT NOT NULL DEFAULT 1, membership_nid BIGINT NOT NULL DEFAULT 1,
-- The numeric ID of the membership event.
-- It refers to the join membership event if the membership_nid is join (3),
-- and to the leave/ban membership event if the membership_nid is leave or
-- ban (1).
-- If the membership_nid is invite (2) and the user has been in the room
-- before, it will refer to the previous leave/ban membership event, and will
-- be equals to 0 (its default) if the user never joined the room before.
-- This NID is updated if the join event gets updated (e.g. profile update),
-- or if the user leaves/joins the room.
event_nid BIGINT NOT NULL DEFAULT 0,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
@ -57,18 +67,33 @@ const insertMembershipSQL = "" +
" VALUES ($1, $2)" + " VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
const selectMembershipsFromRoomSQL = "" +
"SELECT membership_nid, event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
const updateMembershipSQL = "" + const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
updateMembershipStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -80,6 +105,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
return statementList{ return statementList{
{&s.insertMembershipStmt, insertMembershipSQL}, {&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db) }.prepare(db)
} }
@ -100,12 +128,59 @@ func (s *membershipStatements) selectMembershipForUpdate(
return return
} }
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRow(
roomNID, targetUserNID,
).Scan(&membership, &eventNID)
return
}
func (s *membershipStatements) selectMembershipsFromRoom(
roomNID types.RoomNID,
) (eventNIDs map[types.EventNID]membershipState, err error) {
rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID)
if err != nil {
return
}
eventNIDs = make(map[types.EventNID]membershipState)
for rows.Next() {
var eNID types.EventNID
var membership membershipState
if err = rows.Scan(&membership, &eNID); err != nil {
return
}
eventNIDs[eNID] = membership
}
return
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
roomNID types.RoomNID, membership membershipState,
) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership)
if err != nil {
return
}
for rows.Next() {
var eNID types.EventNID
if err = rows.Scan(&eNID); err != nil {
return
}
eventNIDs = append(eventNIDs, eNID)
}
return
}
func (s *membershipStatements) updateMembership( func (s *membershipStatements) updateMembership(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState, senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID,
) error { ) error {
_, err := txn.Stmt(s.updateMembershipStmt).Exec( _, err := txn.Stmt(s.updateMembershipStmt).Exec(
roomNID, targetUserNID, senderUserNID, membership, roomNID, targetUserNID, senderUserNID, membership, eventNID,
) )
return err return err
} }

View File

@ -435,7 +435,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
} }
if u.membership != membershipStateInvite { if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership( if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil { ); err != nil {
return false, err return false, err
} }
@ -444,7 +444,43 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
} }
// SetToJoin implements types.MembershipUpdater // SetToJoin implements types.MembershipUpdater
func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) { func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
if err != nil {
return nil, err
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired(
u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
}
}
// Lookup the NID of the new join event
nIDs, err := u.d.EventNIDs([]string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID],
); err != nil {
return nil, err
}
}
return inviteEventIDs, nil
}
// SetToLeave implements types.MembershipUpdater
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -455,9 +491,16 @@ func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.membership != membershipStateJoin {
// Lookup the NID of the new leave event
nIDs, err := u.d.EventNIDs([]string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership( if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID],
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -465,26 +508,49 @@ func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) {
return inviteEventIDs, nil return inviteEventIDs, nil
} }
// SetToLeave implements types.MembershipUpdater // GetMembershipEvents implements query.RoomserverQueryAPIDB
func (u *membershipUpdater) SetToLeave(senderUserID string) ([]string, error) { func (d *Database) GetMembershipEvents(roomNID types.RoomNID, requestSenderUserID string) (events []types.Event, err error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) txn, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err return
} }
inviteEventIDs, err := u.d.statements.updateInviteRetired( defer txn.Commit()
u.txn, u.roomNID, u.targetUserNID,
) requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID)
if err != nil { if err != nil {
return nil, err return
} }
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership( _, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID)
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, if err == sql.ErrNoRows {
); err != nil { // The user has never been a member of that room
return nil, nil
} else if err != nil {
return
}
if senderMembership == membershipStateJoin {
// The user is still in the room: Send the current list of joined members
var joinEventNIDs []types.EventNID
joinEventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin)
if err != nil {
return nil, err return nil, err
} }
events, err = d.Events(joinEventNIDs)
} else {
// The user isn't in the room anymore
// TODO: Send the list of joined member as it was when the user left
// We cannot do this using only the memberships database, as it
// only stores the latest join event NID for a given target user.
// The solution would be to build the state of a room after before
// the leave event and extract a members list from it.
// For now, we return an empty slice so we know the user has been
// in the room before.
events = []types.Event{}
} }
return inviteEventIDs, nil
return
} }
type transaction struct { type transaction struct {

View File

@ -193,12 +193,12 @@ type MembershipUpdater interface {
// Set the state to invite. // Set the state to invite.
// Returns whether this invite needs to be sent // Returns whether this invite needs to be sent
SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error) SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error)
// Set the state to join. // Set the state to join or updates the event ID in the database.
// Returns a list of invite event IDs that this state change retired. // Returns a list of invite event IDs that this state change retired.
SetToJoin(senderUserID string) (inviteEventIDs []string, err error) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error)
// Set the state to leave. // Set the state to leave.
// Returns a list of invite event IDs that this state change retired. // Returns a list of invite event IDs that this state change retired.
SetToLeave(senderUserID string) (inviteEventIDs []string, err error) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error)
// Implements Transaction so it can be committed or rolledback. // Implements Transaction so it can be committed or rolledback.
Transaction Transaction
} }