diff --git a/pkg/paralleltestctx/analyzer.go b/pkg/paralleltestctx/analyzer.go index d83aa51..3bc4909 100644 --- a/pkg/paralleltestctx/analyzer.go +++ b/pkg/paralleltestctx/analyzer.go @@ -73,6 +73,168 @@ func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool { return false } +type timeoutHelperSet map[types.Object]struct{} + +func (a *ctxAnalyzer) collectTimeoutHelpers(pass *analysis.Pass, testFiles []*ast.File) timeoutHelperSet { + candidates := map[*ast.FuncDecl]types.Object{} + for _, file := range testFiles { + for _, decl := range file.Decls { + fd, ok := decl.(*ast.FuncDecl) + if !ok || fd.Body == nil || !funcReturnsContext(pass, fd) { + continue + } + + obj := pass.TypesInfo.Defs[fd.Name] + if obj == nil { + continue + } + candidates[fd] = obj + } + } + + helpers := timeoutHelperSet{} + for changed := true; changed; { + changed = false + for fd, obj := range candidates { + if _, ok := helpers[obj]; ok { + continue + } + if a.functionReturnsTimeoutContext(pass, fd, helpers) { + helpers[obj] = struct{}{} + changed = true + } + } + } + return helpers +} + +func funcReturnsContext(pass *analysis.Pass, fd *ast.FuncDecl) bool { + if fd.Type.Results == nil || len(fd.Type.Results.List) == 0 { + return false + } + + return isContextType(pass, pass.TypesInfo.TypeOf(fd.Type.Results.List[0].Type)) +} + +func (a *ctxAnalyzer) functionReturnsTimeoutContext(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) bool { + found := false + ast.Inspect(fd.Body, func(n ast.Node) bool { + if found { + return false + } + + switch n := n.(type) { + case *ast.FuncLit: + return false + case *ast.ReturnStmt: + if len(n.Results) == 0 { + return true + } + found = a.exprContainsTimeoutContext(pass, n.Results[0], helpers) + return !found + default: + return true + } + }) + return found +} + +func (a *ctxAnalyzer) exprContainsTimeoutContext(pass *analysis.Pass, expr ast.Expr, helpers timeoutHelperSet) bool { + found := false + ast.Inspect(expr, func(n ast.Node) bool { + if found { + return false + } + + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + found = a.doesCreateTimeoutContext(call) || callsTimeoutHelper(pass, call, helpers) + return !found + }) + return found +} + +func callsTimeoutHelper(pass *analysis.Pass, call *ast.CallExpr, helpers timeoutHelperSet) bool { + if len(helpers) == 0 { + return false + } + + obj := callObject(pass, call) + if obj == nil { + return false + } + _, ok := helpers[obj] + return ok +} + +func callObject(pass *analysis.Pass, call *ast.CallExpr) types.Object { + switch fun := call.Fun.(type) { + case *ast.Ident: + return pass.TypesInfo.Uses[fun] + case *ast.SelectorExpr: + return pass.TypesInfo.Uses[fun.Sel] + default: + return nil + } +} + +func isContextType(pass *analysis.Pass, typ types.Type) bool { + if typ == nil { + return false + } + + contextIface := contextInterface(pass) + if contextIface == nil { + return isStdlibContextType(typ) + } + + return types.Implements(types.Unalias(typ), contextIface) +} + +func contextInterface(pass *analysis.Pass) *types.Interface { + if pass.Pkg == nil { + return nil + } + + for _, pkg := range pass.Pkg.Imports() { + if pkg.Path() != "context" { + continue + } + + obj := pkg.Scope().Lookup("Context") + if obj == nil { + return nil + } + + named, ok := obj.Type().(*types.Named) + if !ok { + return nil + } + + iface, ok := named.Underlying().(*types.Interface) + if !ok { + return nil + } + + return iface + } + + return nil +} + +func isStdlibContextType(typ types.Type) bool { + named, ok := types.Unalias(typ).(*types.Named) + if !ok { + return false + } + + obj := named.Obj() + return obj.Name() == "Context" && obj.Pkg() != nil && obj.Pkg().Path() == "context" +} + func (a *ctxAnalyzer) getTimeoutFuncs() []timeoutFunc { if a.timeoutFuncs != nil { return a.timeoutFuncs @@ -119,6 +281,8 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) { return nil, nil } + timeoutHelpers := a.collectTimeoutHelpers(pass, testFiles) + insp := inspector.New(testFiles) nodeFilter := []ast.Node{(*ast.FuncDecl)(nil)} insp.Preorder(nodeFilter, func(n ast.Node) { @@ -127,20 +291,20 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) { return } - a.analyzeTestFunction(pass, fd) + a.analyzeTestFunction(pass, fd, timeoutHelpers) }) return nil, nil } // analyzeTestFunction analyzes a test function to find timeout context usage after t.Parallel calls -func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl) { +func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) { testVarName := testParamName(fd) if testVarName == "" { return } // Collect all timeout contexts and their positions - timeoutCtxs := a.collectTimeoutContexts(pass, fd) + timeoutCtxs := a.collectTimeoutContexts(pass, fd, helpers) if len(timeoutCtxs) == 0 { return } @@ -187,7 +351,7 @@ type timeoutContext struct { } // collectTimeoutContexts finds all context variables created with a timeout/deadline functions -func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl) []timeoutContext { +func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) []timeoutContext { var contexts []timeoutContext contextMap := make(map[types.Object]*timeoutContext) @@ -209,8 +373,11 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe return true } - call, ok := as.Rhs[0].(*ast.CallExpr) - if ok && a.doesCreateTimeoutContext(call) { + if !isContextType(pass, obj.Type()) { + return true + } + + if a.exprContainsTimeoutContext(pass, as.Rhs[0], helpers) { // This is a timeout context creation ctx := timeoutContext{ obj: obj, diff --git a/pkg/paralleltestctx/testdata/src/basic/basic_test.go b/pkg/paralleltestctx/testdata/src/basic/basic_test.go index 0618092..2520ef9 100644 --- a/pkg/paralleltestctx/testdata/src/basic/basic_test.go +++ b/pkg/paralleltestctx/testdata/src/basic/basic_test.go @@ -55,6 +55,39 @@ func TestDeadlineWarn(t *testing.T) { }) } +func TestNestedTimeoutWarn(t *testing.T) { + ctx, cancel := wrapContext(context.WithTimeout(context.Background(), time.Second)) + defer cancel() + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" + }) +} + +func TestNestedTimeoutNonContextLHSOK(t *testing.T) { + ok := contextWasCreated(context.WithTimeout(context.Background(), time.Second)) + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ok + }) +} + +func TestSamePackageTimeoutHelperWarn(t *testing.T) { + ctx := helperTimeoutContext(t) + t.Run("sub", func(t *testing.T) { + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" + }) +} + +func TestSamePackageTimeoutHelperChainWarn(t *testing.T) { + ctx := chainedHelperTimeoutContext(t) + t.Run("sub", func(t *testing.T) { + t.Parallel() + doThing(ctx) // want "timeout context ctx used after a t.Parallel call" + }) +} + func TestNoParallelOK(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -118,6 +151,28 @@ func TestOverwriteEarly(t *testing.T) { }) } +func wrapContext(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) { + return ctx, cancel +} + +func contextWasCreated(ctx context.Context, cancel context.CancelFunc) bool { + cancel() + return ctx != nil +} + +func helperTimeoutContext(t *testing.T) context.Context { + return wrapOnlyContext(context.WithTimeout(context.Background(), time.Second)) +} + +func chainedHelperTimeoutContext(t *testing.T) context.Context { + return helperTimeoutContext(t) +} + +func wrapOnlyContext(ctx context.Context, cancel context.CancelFunc) context.Context { + cancel() + return ctx +} + func doThing(ctx context.Context) { _ = ctx } diff --git a/pkg/paralleltestctx/testdata/src/custom/custom_test.go b/pkg/paralleltestctx/testdata/src/custom/custom_test.go index 883cc42..36ff352 100644 --- a/pkg/paralleltestctx/testdata/src/custom/custom_test.go +++ b/pkg/paralleltestctx/testdata/src/custom/custom_test.go @@ -46,17 +46,18 @@ func TestStandardTimeoutWarn(t *testing.T) { }) } -// Test function that should NOT be detected by default -func NotATimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { +// Same-package helpers that return timeout contexts are detected even when +// they are not listed as custom funcs. +func SamePackageTimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(parent, timeout) } -func TestNotDetectedByDefault(t *testing.T) { - ctx, cancel := NotATimeoutFunc(context.Background(), time.Second) +func TestSamePackageHelperWarn(t *testing.T) { + ctx, cancel := SamePackageTimeoutFunc(context.Background(), time.Second) defer cancel() t.Run("sub", func(t *testing.T) { - t.Parallel() // should not warn when using default config - _ = ctx + t.Parallel() + _ = ctx // want "timeout context ctx used after a t.Parallel call" }) }