diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index fbd983be..248dbe38 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -22,7 +22,10 @@ import ( "io" "os" "regexp" + "runtime" + "strconv" "strings" + "sync" "time" "github.com/matrix-org/dendrite/internal/config" @@ -31,6 +34,7 @@ import ( ) var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1" +var goidToWriter sync.Map type traceInterceptor struct { sqlmw.NullInterceptor @@ -40,6 +44,8 @@ func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.St startedAt := time.Now() rows, err := stmt.QueryContext(ctx, args) + trackGoID(query) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return rows, err @@ -49,6 +55,8 @@ func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.Stm startedAt := time.Now() result, err := stmt.ExecContext(ctx, args) + trackGoID(query) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return result, err @@ -75,6 +83,19 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ return err } +func trackGoID(query string) { + thisGoID := goid() + if _, ok := goidToWriter.Load(thisGoID); ok { + return // we're on a writer goroutine + } + + q := strings.TrimSpace(query) + if strings.HasPrefix(q, "SELECT") { + return // SELECTs can go on other goroutines + } + logrus.Warnf("unsafe goid: SQL executed not on an ExclusiveWriter: %s", q) +} + // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 @@ -119,3 +140,14 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { func init() { registerDrivers() } + +func goid() int { + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id +} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index 002bc32c..91dd77e4 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -60,6 +60,12 @@ func (w *ExclusiveWriter) run() { if !w.running.CAS(false, true) { return } + if tracingEnabled { + gid := goid() + goidToWriter.Store(gid, w) + defer goidToWriter.Delete(gid) + } + defer w.running.Store(false) for task := range w.todo { if task.db != nil && task.txn != nil {