Make txn *sql.Tx arguments optional everywhere using a utility function (#191)
* Make txn *sql.Tx arguments optional everywhere using a utility function * Clarify that if the txn is nil the stmt will run outside a transactionmain
parent
57b7097368
commit
808c2e09f6
|
@ -55,3 +55,14 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
|
||||||
succeeded = true
|
succeeded = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TxStmt wraps an SQL stmt inside an optional transaction.
|
||||||
|
// If the transaction is nil then it returns the original statement that will
|
||||||
|
// run outside of a transaction.
|
||||||
|
// Otherwise returns a copy of the statement that will run inside the transaction.
|
||||||
|
func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
|
||||||
|
if transaction != nil {
|
||||||
|
statement = transaction.Stmt(statement)
|
||||||
|
}
|
||||||
|
return statement
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/federationsender/types"
|
"github.com/matrix-org/dendrite/federationsender/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -79,18 +80,18 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *joinedHostsStatements) insertJoinedHosts(
|
func (s *joinedHostsStatements) insertJoinedHosts(
|
||||||
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
|
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
|
_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
|
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
|
||||||
_, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
|
_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
|
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
|
||||||
) ([]types.JoinedHost, error) {
|
) ([]types.JoinedHost, error) {
|
||||||
rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID)
|
rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,8 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
const roomSchema = `
|
const roomSchema = `
|
||||||
|
@ -65,7 +67,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
// insertRoom inserts the room if it didn't already exist.
|
// insertRoom inserts the room if it didn't already exist.
|
||||||
// If the room didn't exist then last_event_id is set to the empty string.
|
// If the room didn't exist then last_event_id is set to the empty string.
|
||||||
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
|
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
|
||||||
_, err := txn.Stmt(s.insertRoomStmt).Exec(roomID)
|
_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +76,7 @@ func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
|
||||||
// exists by calling insertRoom first.
|
// exists by calling insertRoom first.
|
||||||
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
|
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
|
||||||
var lastEventID string
|
var lastEventID string
|
||||||
err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
|
err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -84,6 +86,6 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
|
||||||
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
|
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
|
||||||
// have already been called earlier within the transaction.
|
// have already been called earlier within the transaction.
|
||||||
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
|
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
|
||||||
_, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID)
|
_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -92,21 +93,13 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
stmt := s.insertEventStateKeyNIDStmt
|
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
stmt := s.selectEventStateKeyNIDStmt
|
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,11 +124,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
|
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
|
||||||
var eventStateKey string
|
var eventStateKey string
|
||||||
stmt := s.selectEventStateKeyStmt
|
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
err := stmt.QueryRow(eventStateKeyNID).Scan(&eventStateKey)
|
|
||||||
return eventStateKey, err
|
return eventStateKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -253,22 +254,22 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
|
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
|
||||||
err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
|
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
|
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
|
||||||
_, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID))
|
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
|
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
|
||||||
err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
|
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
||||||
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -94,7 +95,7 @@ func (s *inviteStatements) insertInviteEvent(
|
||||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||||
inviteEventJSON []byte,
|
inviteEventJSON []byte,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
result, err := txn.Stmt(s.insertInviteEventStmt).Exec(
|
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
|
||||||
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -110,7 +111,7 @@ func (s *inviteStatements) insertInviteEvent(
|
||||||
func (s *inviteStatements) updateInviteRetired(
|
func (s *inviteStatements) updateInviteRetired(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
|
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -115,14 +116,14 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *membershipStatements) insertMembership(
|
func (s *membershipStatements) insertMembership(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.insertMembershipStmt).Exec(roomNID, targetUserNID)
|
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipForUpdate(
|
func (s *membershipStatements) selectMembershipForUpdate(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (membership membershipState, err error) {
|
) (membership membershipState, err error) {
|
||||||
err = txn.Stmt(s.selectMembershipForUpdateStmt).QueryRow(
|
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
|
||||||
roomNID, targetUserNID,
|
roomNID, targetUserNID,
|
||||||
).Scan(&membership)
|
).Scan(&membership)
|
||||||
return
|
return
|
||||||
|
@ -179,7 +180,7 @@ func (s *membershipStatements) updateMembership(
|
||||||
senderUserNID types.EventStateKeyNID, membership membershipState,
|
senderUserNID types.EventStateKeyNID, membership membershipState,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.updateMembershipStmt).Exec(
|
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
|
||||||
roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,7 +74,7 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
||||||
_, err := txn.Stmt(s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
|
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,5 +82,5 @@ func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEvent
|
||||||
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
||||||
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
|
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
|
||||||
var ok int64
|
var ok int64
|
||||||
return txn.Stmt(s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
|
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,21 +83,13 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
||||||
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
||||||
var roomNID int64
|
var roomNID int64
|
||||||
stmt := s.insertRoomNIDStmt
|
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
err := stmt.QueryRow(roomID).Scan(&roomNID)
|
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
||||||
var roomNID int64
|
var roomNID int64
|
||||||
stmt := s.selectRoomNIDStmt
|
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
err := stmt.QueryRow(roomID).Scan(&roomNID)
|
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +113,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
var lastEventSentNID int64
|
var lastEventSentNID int64
|
||||||
var stateSnapshotNID int64
|
var stateSnapshotNID int64
|
||||||
err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
@ -135,7 +128,7 @@ func (s *roomStatements) updateLatestEventNIDs(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
|
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
|
||||||
stateSnapshotNID types.StateSnapshotNID,
|
stateSnapshotNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(
|
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
|
||||||
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
|
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -136,7 +137,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e
|
||||||
|
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
|
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
|
||||||
rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
|
rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -155,7 +156,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us
|
||||||
|
|
||||||
// CurrentState returns all the current state events for the given room.
|
// CurrentState returns all the current state events for the given room.
|
||||||
func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
|
func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
|
||||||
rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID)
|
rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -165,21 +166,21 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error {
|
func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error {
|
||||||
_, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID)
|
_, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) upsertRoomState(
|
func (s *currentRoomStateStatements) upsertRoomState(
|
||||||
txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64,
|
txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64,
|
||||||
) error {
|
) error {
|
||||||
_, err := txn.Stmt(s.upsertRoomStateStmt).Exec(
|
_, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec(
|
||||||
event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt,
|
event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
||||||
rows, err := txn.Stmt(s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
|
rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
@ -105,7 +106,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *outputRoomEventsStatements) selectStateInRange(
|
func (s *outputRoomEventsStatements) selectStateInRange(
|
||||||
txn *sql.Tx, oldPos, newPos types.StreamPosition,
|
txn *sql.Tx, oldPos, newPos types.StreamPosition,
|
||||||
) (map[string]map[string]bool, map[string]streamEvent, error) {
|
) (map[string]map[string]bool, map[string]streamEvent, error) {
|
||||||
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos)
|
rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -167,12 +168,8 @@ func (s *outputRoomEventsStatements) selectStateInRange(
|
||||||
// then this function should only ever be used at startup, as it will race with inserting events if it is
|
// then this function should only ever be used at startup, as it will race with inserting events if it is
|
||||||
// done afterwards. If there are no inserted events, 0 is returned.
|
// done afterwards. If there are no inserted events, 0 is returned.
|
||||||
func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) {
|
func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) {
|
||||||
stmt := s.selectMaxIDStmt
|
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
var nullableID sql.NullInt64
|
var nullableID sql.NullInt64
|
||||||
err = stmt.QueryRow().Scan(&nullableID)
|
err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID)
|
||||||
if nullableID.Valid {
|
if nullableID.Valid {
|
||||||
id = nullableID.Int64
|
id = nullableID.Int64
|
||||||
}
|
}
|
||||||
|
@ -182,7 +179,7 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err
|
||||||
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
|
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
|
||||||
// of the inserted event.
|
// of the inserted event.
|
||||||
func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) {
|
func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) {
|
||||||
err = txn.Stmt(s.insertEventStmt).QueryRow(
|
err = common.TxStmt(txn, s.insertEventStmt).QueryRow(
|
||||||
event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState),
|
event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState),
|
||||||
).Scan(&streamPos)
|
).Scan(&streamPos)
|
||||||
return
|
return
|
||||||
|
@ -209,11 +206,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
|
||||||
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
|
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
|
||||||
// from the database.
|
// from the database.
|
||||||
func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
|
||||||
stmt := s.selectEventsStmt
|
rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs))
|
||||||
if txn != nil {
|
|
||||||
stmt = txn.Stmt(stmt)
|
|
||||||
}
|
|
||||||
rows, err := stmt.Query(pq.StringArray(eventIDs))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue