// Copyright 2017 Vector Creations Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package sqlite3 import ( "context" "database/sql" "strings" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" ) const membershipSchema = ` -- Stores data about users memberships to rooms. CREATE TABLE IF NOT EXISTS account_memberships ( -- The Matrix user ID localpart for the member localpart TEXT NOT NULL, -- The room this user is a member of room_id TEXT NOT NULL, -- The ID of the join membership event event_id TEXT NOT NULL, -- A user can only be member of a room once PRIMARY KEY (localpart, room_id), UNIQUE (event_id) ); ` const insertMembershipSQL = ` INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3) ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id ` const selectMembershipsByLocalpartSQL = "" + "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" const selectMembershipInRoomByLocalpartSQL = "" + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" const selectRoomIDsByLocalPartSQL = "" + "SELECT room_id FROM account_memberships WHERE localpart = $1" const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM account_memberships WHERE event_id IN ($1)" type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt selectRoomIDsByLocalPartStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(membershipSchema) if err != nil { return } if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { return } if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil { return } if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { return } return } func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, ) (err error) { stmt := txn.Stmt(s.insertMembershipStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, eventID) return } func (s *membershipStatements) deleteMembershipsByEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) (err error) { sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) iEventIDs := make([]interface{}, len(eventIDs)) for i, e := range eventIDs { iEventIDs[i] = e } _, err = txn.ExecContext(ctx, sqlStr, iEventIDs...) return } func (s *membershipStatements) selectMembershipInRoomByLocalpart( ctx context.Context, localpart, roomID string, ) (authtypes.Membership, error) { membership := authtypes.Membership{Localpart: localpart, RoomID: roomID} stmt := s.selectMembershipInRoomByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID) return membership, err } func (s *membershipStatements) selectMembershipsByLocalpart( ctx context.Context, localpart string, ) (memberships []authtypes.Membership, err error) { stmt := s.selectMembershipsByLocalpartStmt rows, err := stmt.QueryContext(ctx, localpart) if err != nil { return } memberships = []authtypes.Membership{} defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") for rows.Next() { var m authtypes.Membership m.Localpart = localpart if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { return nil, err } memberships = append(memberships, m) } return } func (s *membershipStatements) selectRoomIDsByLocalPart( ctx context.Context, localPart string, ) ([]string, error) { stmt := s.selectRoomIDsByLocalPartStmt rows, err := stmt.QueryContext(ctx, localPart) if err != nil { return nil, err } roomIDs := []string{} defer rows.Close() // nolint: errcheck for rows.Next() { var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } roomIDs = append(roomIDs, roomID) } return roomIDs, rows.Err() }