diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 0af40758..d0f36a6f 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -29,6 +29,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/matrix-org/dendrite/common/config" @@ -70,12 +71,17 @@ func init() { } // sessionsDict keeps track of completed auth stages for each session. +// It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { + sync.Mutex sessions map[string][]authtypes.LoginType } // GetCompletedStages returns the completed stages for a session. -func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { +func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { + d.Lock() + defer d.Unlock() + if completedStages, ok := d.sessions[sessionID]; ok { return completedStages } @@ -91,12 +97,15 @@ func newSessionsDict() *sessionsDict { // AddCompletedSessionStage records that a session has completed an auth stage. func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { - for _, completedStage := range sessions.GetCompletedStages(sessionID) { + sessions.Lock() + defer sessions.Unlock() + + for _, completedStage := range sessions.sessions[sessionID] { if completedStage == stage { return } } - sessions.sessions[sessionID] = append(sessions.GetCompletedStages(sessionID), stage) + sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } var (