From b758757958f64b7215ad62b6515d2bbcb75b84bc Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 11 May 2026 05:05:28 +0000 Subject: [PATCH 1/2] fix(pkg/paralleltestctx): detect nested timeout context calls --- pkg/paralleltestctx/analyzer.go | 79 ++++++++++++++++++- .../testdata/src/basic/basic_test.go | 26 ++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/pkg/paralleltestctx/analyzer.go b/pkg/paralleltestctx/analyzer.go index d83aa51..0bb286f 100644 --- a/pkg/paralleltestctx/analyzer.go +++ b/pkg/paralleltestctx/analyzer.go @@ -73,6 +73,78 @@ func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool { return false } +func (a *ctxAnalyzer) exprContainsTimeoutContext(expr ast.Expr) bool { + found := false + ast.Inspect(expr, func(n ast.Node) bool { + if found { + return false + } + + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + found = a.doesCreateTimeoutContext(call) + return !found + }) + return found +} + +func isContextType(pass *analysis.Pass, typ types.Type) bool { + if typ == nil { + return false + } + + contextIface := contextInterface(pass) + if contextIface == nil { + return isStdlibContextType(typ) + } + + return types.Implements(types.Unalias(typ), contextIface) +} + +func contextInterface(pass *analysis.Pass) *types.Interface { + if pass.Pkg == nil { + return nil + } + + for _, pkg := range pass.Pkg.Imports() { + if pkg.Path() != "context" { + continue + } + + obj := pkg.Scope().Lookup("Context") + if obj == nil { + return nil + } + + named, ok := obj.Type().(*types.Named) + if !ok { + return nil + } + + iface, ok := named.Underlying().(*types.Interface) + if !ok { + return nil + } + + return iface + } + + return nil +} + +func isStdlibContextType(typ types.Type) bool { + named, ok := types.Unalias(typ).(*types.Named) + if !ok { + return false + } + + obj := named.Obj() + return obj.Name() == "Context" && obj.Pkg() != nil && obj.Pkg().Path() == "context" +} + func (a *ctxAnalyzer) getTimeoutFuncs() []timeoutFunc { if a.timeoutFuncs != nil { return a.timeoutFuncs @@ -209,8 +281,11 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe return true } - call, ok := as.Rhs[0].(*ast.CallExpr) - if ok && a.doesCreateTimeoutContext(call) { + if !isContextType(pass, obj.Type()) { + return true + } + + if a.exprContainsTimeoutContext(as.Rhs[0]) { // This is a timeout context creation ctx := timeoutContext{ obj: obj, diff --git a/pkg/paralleltestctx/testdata/src/basic/basic_test.go b/pkg/paralleltestctx/testdata/src/basic/basic_test.go index 0618092..866d0c1 100644 --- a/pkg/paralleltestctx/testdata/src/basic/basic_test.go +++ b/pkg/paralleltestctx/testdata/src/basic/basic_test.go @@ -55,6 +55,23 @@ func TestDeadlineWarn(t *testing.T) { }) } +func TestNestedTimeoutWarn(t *testing.T) { + ctx, cancel := wrapContext(context.WithTimeout(context.Background(), time.Second)) + defer cancel() + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" + }) +} + +func TestNestedTimeoutNonContextLHSOK(t *testing.T) { + ok := contextWasCreated(context.WithTimeout(context.Background(), time.Second)) + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ok + }) +} + func TestNoParallelOK(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -118,6 +135,15 @@ func TestOverwriteEarly(t *testing.T) { }) } +func wrapContext(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) { + return ctx, cancel +} + +func contextWasCreated(ctx context.Context, cancel context.CancelFunc) bool { + cancel() + return ctx != nil +} + func doThing(ctx context.Context) { _ = ctx } From fa4fd5a179dffcf77082833cacf0d518b513da36 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 11 May 2026 05:44:13 +0000 Subject: [PATCH 2/2] fix(pkg/paralleltestctx): summarize timeout context helpers --- pkg/paralleltestctx/analyzer.go | 106 ++++++++++++++++-- .../testdata/src/basic/basic_test.go | 29 +++++ .../testdata/src/custom/custom_test.go | 13 ++- 3 files changed, 135 insertions(+), 13 deletions(-) diff --git a/pkg/paralleltestctx/analyzer.go b/pkg/paralleltestctx/analyzer.go index 0bb286f..3bc4909 100644 --- a/pkg/paralleltestctx/analyzer.go +++ b/pkg/paralleltestctx/analyzer.go @@ -73,7 +73,73 @@ func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool { return false } -func (a *ctxAnalyzer) exprContainsTimeoutContext(expr ast.Expr) bool { +type timeoutHelperSet map[types.Object]struct{} + +func (a *ctxAnalyzer) collectTimeoutHelpers(pass *analysis.Pass, testFiles []*ast.File) timeoutHelperSet { + candidates := map[*ast.FuncDecl]types.Object{} + for _, file := range testFiles { + for _, decl := range file.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok || fd.Body == nil || !funcReturnsContext(pass, fd) { + continue + } + + obj := pass.TypesInfo.Defs[fd.Name] + if obj == nil { + continue + } + candidates[fd] = obj + } + } + + helpers := timeoutHelperSet{} + for changed := true; changed; { + changed = false + for fd, obj := range candidates { + if _, ok := helpers[obj]; ok { + continue + } + if a.functionReturnsTimeoutContext(pass, fd, helpers) { + helpers[obj] = struct{}{} + changed = true + } + } + } + return helpers +} + +func funcReturnsContext(pass *analysis.Pass, fd *ast.FuncDecl) bool { + if fd.Type.Results == nil || len(fd.Type.Results.List) == 0 { + return false + } + + return isContextType(pass, pass.TypesInfo.TypeOf(fd.Type.Results.List[0].Type)) +} + +func (a *ctxAnalyzer) functionReturnsTimeoutContext(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) bool { + found := false + ast.Inspect(fd.Body, func(n ast.Node) bool { + if found { + return false + } + + switch n := n.(type) { + case *ast.FuncLit: + return false + case *ast.ReturnStmt: + if len(n.Results) == 0 { + return true + } + found = a.exprContainsTimeoutContext(pass, n.Results[0], helpers) + return !found + default: + return true + } + }) + return found +} + +func (a *ctxAnalyzer) exprContainsTimeoutContext(pass *analysis.Pass, expr ast.Expr, helpers timeoutHelperSet) bool { found := false ast.Inspect(expr, func(n ast.Node) bool { if found { @@ -85,12 +151,36 @@ func (a *ctxAnalyzer) exprContainsTimeoutContext(expr ast.Expr) bool { return true } - found = a.doesCreateTimeoutContext(call) + found = a.doesCreateTimeoutContext(call) || callsTimeoutHelper(pass, call, helpers) return !found }) return found } +func callsTimeoutHelper(pass *analysis.Pass, call *ast.CallExpr, helpers timeoutHelperSet) bool { + if len(helpers) == 0 { + return false + } + + obj := callObject(pass, call) + if obj == nil { + return false + } + _, ok := helpers[obj] + return ok +} + +func callObject(pass *analysis.Pass, call *ast.CallExpr) types.Object { + switch fun := call.Fun.(type) { + case *ast.Ident: + return pass.TypesInfo.Uses[fun] + case *ast.SelectorExpr: + return pass.TypesInfo.Uses[fun.Sel] + default: + return nil + } +} + func isContextType(pass *analysis.Pass, typ types.Type) bool { if typ == nil { return false @@ -191,6 +281,8 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) { return nil, nil } + timeoutHelpers := a.collectTimeoutHelpers(pass, testFiles) + insp := inspector.New(testFiles) nodeFilter := []ast.Node{(*ast.FuncDecl)(nil)} insp.Preorder(nodeFilter, func(n ast.Node) { @@ -199,20 +291,20 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) { return } - a.analyzeTestFunction(pass, fd) + a.analyzeTestFunction(pass, fd, timeoutHelpers) }) return nil, nil } // analyzeTestFunction analyzes a test function to find timeout context usage after t.Parallel calls -func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl) { +func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) { testVarName := testParamName(fd) if testVarName == "" { return } // Collect all timeout contexts and their positions - timeoutCtxs := a.collectTimeoutContexts(pass, fd) + timeoutCtxs := a.collectTimeoutContexts(pass, fd, helpers) if len(timeoutCtxs) == 0 { return } @@ -259,7 +351,7 @@ type timeoutContext struct { } // collectTimeoutContexts finds all context variables created with a timeout/deadline functions -func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl) []timeoutContext { +func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) []timeoutContext { var contexts []timeoutContext contextMap := make(map[types.Object]*timeoutContext) @@ -285,7 +377,7 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe return true } - if a.exprContainsTimeoutContext(as.Rhs[0]) { + if a.exprContainsTimeoutContext(pass, as.Rhs[0], helpers) { // This is a timeout context creation ctx := timeoutContext{ obj: obj, diff --git a/pkg/paralleltestctx/testdata/src/basic/basic_test.go b/pkg/paralleltestctx/testdata/src/basic/basic_test.go index 866d0c1..2520ef9 100644 --- a/pkg/paralleltestctx/testdata/src/basic/basic_test.go +++ b/pkg/paralleltestctx/testdata/src/basic/basic_test.go @@ -72,6 +72,22 @@ func TestNestedTimeoutNonContextLHSOK(t *testing.T) { }) } +func TestSamePackageTimeoutHelperWarn(t *testing.T) { + ctx := helperTimeoutContext(t) + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" + }) +} + +func TestSamePackageTimeoutHelperChainWarn(t *testing.T) { + ctx := chainedHelperTimeoutContext(t) + t.Run("sub", func(t *testing.T) { + t.Parallel() + doThing(ctx) // want "timeout context ctx used after a t.Parallel call" + }) +} + func TestNoParallelOK(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -144,6 +160,19 @@ func contextWasCreated(ctx context.Context, cancel context.CancelFunc) bool { return ctx != nil } +func helperTimeoutContext(t *testing.T) context.Context { + return wrapOnlyContext(context.WithTimeout(context.Background(), time.Second)) +} + +func chainedHelperTimeoutContext(t *testing.T) context.Context { + return helperTimeoutContext(t) +} + +func wrapOnlyContext(ctx context.Context, cancel context.CancelFunc) context.Context { + cancel() + return ctx +} + func doThing(ctx context.Context) { _ = ctx } diff --git a/pkg/paralleltestctx/testdata/src/custom/custom_test.go b/pkg/paralleltestctx/testdata/src/custom/custom_test.go index 883cc42..36ff352 100644 --- a/pkg/paralleltestctx/testdata/src/custom/custom_test.go +++ b/pkg/paralleltestctx/testdata/src/custom/custom_test.go @@ -46,17 +46,18 @@ func TestStandardTimeoutWarn(t *testing.T) { }) } -// Test function that should NOT be detected by default -func NotATimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { +// Same-package helpers that return timeout contexts are detected even when +// they are not listed as custom funcs. +func SamePackageTimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(parent, timeout) } -func TestNotDetectedByDefault(t *testing.T) { - ctx, cancel := NotATimeoutFunc(context.Background(), time.Second) +func TestSamePackageHelperWarn(t *testing.T) { + ctx, cancel := SamePackageTimeoutFunc(context.Background(), time.Second) defer cancel() t.Run("sub", func(t *testing.T) { - t.Parallel() // should not warn when using default config - _ = ctx + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" }) }