diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 978eafd4..dde07701 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -205,7 +205,7 @@ func Send( util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(context.Background()) + resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -253,11 +253,8 @@ type txnFederationClient interface { func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) - //var resultsMutex sync.Mutex - var wg sync.WaitGroup var tasks []*inputTask - wg.Add(1) // for processEDUs for _, pdu := range t.PDUs { pduCountTotal.WithLabelValues("total").Inc() @@ -313,9 +310,6 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res input: newSendFIFOQueue(), }) worker := v.(*inputWorker) - if !worker.running.Load() { - go worker.run() - } wg.Add(1) task := &inputTask{ ctx: ctx, @@ -325,13 +319,12 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } tasks = append(tasks, task) worker.input.push(task) + if worker.running.CAS(false, true) { + go worker.run() + } } - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - + t.processEDUs(ctx) wg.Wait() for _, task := range tasks { @@ -351,9 +344,6 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } func (t *inputWorker) run() { - if !t.running.CAS(false, true) { - return - } defer t.running.Store(false) for { task, ok := t.input.pop() @@ -371,7 +361,10 @@ func (t *inputWorker) run() { return default: evStart := time.Now() - task.err = task.t.processEvent(task.ctx, task.event) + // TODO: Is 5 minutes too long? + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + task.err = task.t.processEvent(ctx, task.event) + cancel() task.duration = time.Since(evStart) if err := task.err; err != nil { switch err.(type) {