diff --git a/go.mod b/go.mod index b04c503b..30e3c97b 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20210722084509-b9eb787c3967 + github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876 github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 diff --git a/go.sum b/go.sum index 72986d3d..a090dbc3 100644 --- a/go.sum +++ b/go.sum @@ -1027,8 +1027,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20210722084509-b9eb787c3967 h1:9NvlktHDcqMxbytdLd92rjfFtn59QwxqfV7i2BQSyHA= -github.com/matrix-org/gomatrixserverlib v0.0.0-20210722084509-b9eb787c3967/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876 h1:6ypwCtgRLK0v/hGWvnd847+KTo9BSkP9N0A4qSniP4E= +github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 h1:HZCzy4oVzz55e+cOMiX/JtSF2UOY1evBl2raaE7ACcU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b h1:5X5vdWQ13xrNkJVqaJHPsrt7rKkMJH5iac0EtfOuxSg= diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 44435bfd..2511097d 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -136,6 +136,8 @@ func (r *Inputer) updateMembership( return updateToJoinMembership(mu, add, updates) case gomatrixserverlib.Leave, gomatrixserverlib.Ban: return updateToLeaveMembership(mu, add, newMembership, updates) + case gomatrixserverlib.Knock: + return updateToKnockMembership(mu, add, updates) default: panic(fmt.Errorf( "input: membership %q is not one of the allowed values", newMembership, @@ -220,6 +222,18 @@ func updateToLeaveMembership( return updates, nil } +func updateToKnockMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + if mu.IsLeave() { + _, err := mu.SetToKnock(add) + if err != nil { + return nil, err + } + } + return updates, nil +} + // membershipChanges pairs up the membership state changes. func membershipChanges(removed, added []types.StateEntry) []stateChange { changes := pairUpChanges(removed, added) diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 57f3a520..f1f589a3 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -86,6 +86,11 @@ func (u *MembershipUpdater) IsLeave() bool { return u.membership == tables.MembershipStateLeaveOrBan } +// IsKnock implements types.MembershipUpdater +func (u *MembershipUpdater) IsKnock() bool { + return u.membership == tables.MembershipStateKnock +} + // SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { var inserted bool @@ -180,3 +185,27 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s }) return inviteEventIDs, err } + +// SetToKnock implements types.MembershipUpdater +func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) { + var inserted bool + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) + } + if u.membership != tables.MembershipStateKnock { + // Look up the NID of the new knock event + nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()}) + if err != nil { + return fmt.Errorf("u.d.EventNIDs: %w", err) + } + + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + return nil + }) + return inserted, err +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index f762cb71..8720d400 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -120,6 +120,7 @@ const ( MembershipStateLeaveOrBan MembershipState = 1 MembershipStateInvite MembershipState = 2 MembershipStateJoin MembershipState = 3 + MembershipStateKnock MembershipState = 4 ) type Membership interface {