Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 173 additions & 6 deletions pkg/paralleltestctx/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,168 @@ func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool {
return false
}

type timeoutHelperSet map[types.Object]struct{}

func (a *ctxAnalyzer) collectTimeoutHelpers(pass *analysis.Pass, testFiles []*ast.File) timeoutHelperSet {
candidates := map[*ast.FuncDecl]types.Object{}
for _, file := range testFiles {
for _, decl := range file.Decls {
fd, ok := decl.(*ast.FuncDecl)
if !ok || fd.Body == nil || !funcReturnsContext(pass, fd) {
continue
}

obj := pass.TypesInfo.Defs[fd.Name]
if obj == nil {
continue
}
candidates[fd] = obj
}
}

helpers := timeoutHelperSet{}
for changed := true; changed; {
changed = false
for fd, obj := range candidates {
if _, ok := helpers[obj]; ok {
continue
}
if a.functionReturnsTimeoutContext(pass, fd, helpers) {
helpers[obj] = struct{}{}
changed = true
}
}
}
return helpers
}

func funcReturnsContext(pass *analysis.Pass, fd *ast.FuncDecl) bool {
if fd.Type.Results == nil || len(fd.Type.Results.List) == 0 {
return false
}

return isContextType(pass, pass.TypesInfo.TypeOf(fd.Type.Results.List[0].Type))
}

func (a *ctxAnalyzer) functionReturnsTimeoutContext(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) bool {
found := false
ast.Inspect(fd.Body, func(n ast.Node) bool {
if found {
return false
}

switch n := n.(type) {
case *ast.FuncLit:
return false
case *ast.ReturnStmt:
if len(n.Results) == 0 {
return true
}
found = a.exprContainsTimeoutContext(pass, n.Results[0], helpers)
return !found
default:
return true
}
})
return found
}

func (a *ctxAnalyzer) exprContainsTimeoutContext(pass *analysis.Pass, expr ast.Expr, helpers timeoutHelperSet) bool {
found := false
ast.Inspect(expr, func(n ast.Node) bool {
if found {
return false
}

call, ok := n.(*ast.CallExpr)
if !ok {
return true
}

found = a.doesCreateTimeoutContext(call) || callsTimeoutHelper(pass, call, helpers)
return !found
})
return found
}

func callsTimeoutHelper(pass *analysis.Pass, call *ast.CallExpr, helpers timeoutHelperSet) bool {
if len(helpers) == 0 {
return false
}

obj := callObject(pass, call)
if obj == nil {
return false
}
_, ok := helpers[obj]
return ok
}

func callObject(pass *analysis.Pass, call *ast.CallExpr) types.Object {
switch fun := call.Fun.(type) {
case *ast.Ident:
return pass.TypesInfo.Uses[fun]
case *ast.SelectorExpr:
return pass.TypesInfo.Uses[fun.Sel]
default:
return nil
}
}

func isContextType(pass *analysis.Pass, typ types.Type) bool {
if typ == nil {
return false
}

contextIface := contextInterface(pass)
if contextIface == nil {
return isStdlibContextType(typ)
}

return types.Implements(types.Unalias(typ), contextIface)
}

func contextInterface(pass *analysis.Pass) *types.Interface {
if pass.Pkg == nil {
return nil
}

for _, pkg := range pass.Pkg.Imports() {
if pkg.Path() != "context" {
continue
}

obj := pkg.Scope().Lookup("Context")
if obj == nil {
return nil
}

named, ok := obj.Type().(*types.Named)
if !ok {
return nil
}

iface, ok := named.Underlying().(*types.Interface)
if !ok {
return nil
}

return iface
}

return nil
}

func isStdlibContextType(typ types.Type) bool {
named, ok := types.Unalias(typ).(*types.Named)
if !ok {
return false
}

obj := named.Obj()
return obj.Name() == "Context" && obj.Pkg() != nil && obj.Pkg().Path() == "context"
}

func (a *ctxAnalyzer) getTimeoutFuncs() []timeoutFunc {
if a.timeoutFuncs != nil {
return a.timeoutFuncs
Expand Down Expand Up @@ -119,6 +281,8 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) {
return nil, nil
}

timeoutHelpers := a.collectTimeoutHelpers(pass, testFiles)

insp := inspector.New(testFiles)
nodeFilter := []ast.Node{(*ast.FuncDecl)(nil)}
insp.Preorder(nodeFilter, func(n ast.Node) {
Expand All @@ -127,20 +291,20 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) {
return
}

a.analyzeTestFunction(pass, fd)
a.analyzeTestFunction(pass, fd, timeoutHelpers)
})
return nil, nil
}

// analyzeTestFunction analyzes a test function to find timeout context usage after t.Parallel calls
func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl) {
func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) {
testVarName := testParamName(fd)
if testVarName == "" {
return
}

// Collect all timeout contexts and their positions
timeoutCtxs := a.collectTimeoutContexts(pass, fd)
timeoutCtxs := a.collectTimeoutContexts(pass, fd, helpers)
if len(timeoutCtxs) == 0 {
return
}
Expand Down Expand Up @@ -187,7 +351,7 @@ type timeoutContext struct {
}

// collectTimeoutContexts finds all context variables created with a timeout/deadline functions
func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl) []timeoutContext {
func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl, helpers timeoutHelperSet) []timeoutContext {
var contexts []timeoutContext
contextMap := make(map[types.Object]*timeoutContext)

Expand All @@ -209,8 +373,11 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe
return true
}

call, ok := as.Rhs[0].(*ast.CallExpr)
if ok && a.doesCreateTimeoutContext(call) {
if !isContextType(pass, obj.Type()) {
return true
}

if a.exprContainsTimeoutContext(pass, as.Rhs[0], helpers) {
// This is a timeout context creation
ctx := timeoutContext{
obj: obj,
Expand Down
55 changes: 55 additions & 0 deletions pkg/paralleltestctx/testdata/src/basic/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,39 @@ func TestDeadlineWarn(t *testing.T) {
})
}

func TestNestedTimeoutWarn(t *testing.T) {
ctx, cancel := wrapContext(context.WithTimeout(context.Background(), time.Second))
defer cancel()
t.Run("sub", func(t *testing.T) {
t.Parallel()
_ = ctx // want "timeout context ctx used after a t.Parallel call"
})
}

func TestNestedTimeoutNonContextLHSOK(t *testing.T) {
ok := contextWasCreated(context.WithTimeout(context.Background(), time.Second))
t.Run("sub", func(t *testing.T) {
t.Parallel()
_ = ok
})
}

func TestSamePackageTimeoutHelperWarn(t *testing.T) {
ctx := helperTimeoutContext(t)
t.Run("sub", func(t *testing.T) {
t.Parallel()
_ = ctx // want "timeout context ctx used after a t.Parallel call"
})
}

func TestSamePackageTimeoutHelperChainWarn(t *testing.T) {
ctx := chainedHelperTimeoutContext(t)
t.Run("sub", func(t *testing.T) {
t.Parallel()
doThing(ctx) // want "timeout context ctx used after a t.Parallel call"
})
}

func TestNoParallelOK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
Expand Down Expand Up @@ -118,6 +151,28 @@ func TestOverwriteEarly(t *testing.T) {
})
}

func wrapContext(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) {
return ctx, cancel
}

func contextWasCreated(ctx context.Context, cancel context.CancelFunc) bool {
cancel()
return ctx != nil
}

func helperTimeoutContext(t *testing.T) context.Context {
return wrapOnlyContext(context.WithTimeout(context.Background(), time.Second))
}

func chainedHelperTimeoutContext(t *testing.T) context.Context {
return helperTimeoutContext(t)
}

func wrapOnlyContext(ctx context.Context, cancel context.CancelFunc) context.Context {
cancel()
return ctx
}

func doThing(ctx context.Context) {
_ = ctx
}
13 changes: 7 additions & 6 deletions pkg/paralleltestctx/testdata/src/custom/custom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,18 @@ func TestStandardTimeoutWarn(t *testing.T) {
})
}

// Test function that should NOT be detected by default
func NotATimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
// Same-package helpers that return timeout contexts are detected even when
// they are not listed as custom funcs.
func SamePackageTimeoutFunc(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, timeout)
}

func TestNotDetectedByDefault(t *testing.T) {
ctx, cancel := NotATimeoutFunc(context.Background(), time.Second)
func TestSamePackageHelperWarn(t *testing.T) {
ctx, cancel := SamePackageTimeoutFunc(context.Background(), time.Second)
defer cancel()
t.Run("sub", func(t *testing.T) {
t.Parallel() // should not warn when using default config
_ = ctx
t.Parallel()
_ = ctx // want "timeout context ctx used after a t.Parallel call"
})
}

Expand Down
Loading