Skip to content

Commit ddd7c57

Browse files
committed
feat: initial commit
0 parents  commit ddd7c57

7 files changed

Lines changed: 636 additions & 0 deletions

File tree

go.mod

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module github.com/coder/paralleltestctx
2+
3+
go 1.24.6
4+
5+
require golang.org/x/tools v0.36.0
6+
7+
require (
8+
golang.org/x/mod v0.27.0 // indirect
9+
golang.org/x/sync v0.16.0 // indirect
10+
)

go.sum

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
2+
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
3+
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
4+
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
5+
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
6+
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
7+
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
8+
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=

main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package main
2+
3+
import (
4+
"github.com/coder/paralleltestctx/pkg/paralleltestctx"
5+
"golang.org/x/tools/go/analysis/singlechecker"
6+
)
7+
8+
func main() { singlechecker.Main(paralleltestctx.Analyzer()) }

pkg/paralleltestctx/analyzer.go

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
package paralleltestctx
2+
3+
import (
4+
"flag"
5+
"go/ast"
6+
"go/token"
7+
"go/types"
8+
"strings"
9+
10+
"golang.org/x/tools/go/analysis"
11+
"golang.org/x/tools/go/ast/inspector"
12+
)
13+
14+
const Doc = `warn when a context with a timeout/deadline is used after a t.Parallel call`
15+
16+
func Analyzer() *analysis.Analyzer {
17+
return newCtxAnalyzer().analyzer
18+
}
19+
20+
type timeoutFunc struct {
21+
receiver string // empty string means no receiver specified
22+
name string
23+
}
24+
25+
type ctxAnalyzer struct {
26+
analyzer *analysis.Analyzer
27+
timeoutFuncFlag string
28+
timeoutFuncs []timeoutFunc // cache for flag parsed w/ defaults
29+
}
30+
31+
func newCtxAnalyzer() *ctxAnalyzer {
32+
a := &ctxAnalyzer{}
33+
var flags flag.FlagSet
34+
flags.StringVar(&a.timeoutFuncFlag, "timeout-funcs", "", "comma-separated list of additional function names that create timeout/deadline contexts")
35+
a.analyzer = &analysis.Analyzer{
36+
Name: "paralleltestctx",
37+
Doc: Doc,
38+
Run: a.run,
39+
Flags: flags,
40+
}
41+
return a
42+
}
43+
44+
// doesCreateTimeoutContext checks if the given call expression creates a
45+
// context with a deadline or timeout
46+
func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool {
47+
timeoutFuncs := a.getTimeoutFuncs()
48+
49+
// Check for receiver calls (e.g. context.WithTimeout, testutil.Context)
50+
if se, ok := call.Fun.(*ast.SelectorExpr); ok {
51+
funcName := se.Sel.Name
52+
receiverName := extractReceiverName(se.X)
53+
54+
for _, fn := range timeoutFuncs {
55+
if fn.name == funcName {
56+
if fn.receiver == "" || fn.receiver == receiverName {
57+
return true
58+
}
59+
}
60+
}
61+
}
62+
63+
// Check for direct function calls (e.g., Context when no receiver specified)
64+
if id, ok := call.Fun.(*ast.Ident); ok {
65+
funcName := id.Name
66+
for _, fn := range timeoutFuncs {
67+
if fn.receiver == "" && fn.name == funcName {
68+
return true
69+
}
70+
}
71+
}
72+
73+
return false
74+
}
75+
76+
func (a *ctxAnalyzer) getTimeoutFuncs() []timeoutFunc {
77+
if a.timeoutFuncs != nil {
78+
return a.timeoutFuncs
79+
}
80+
81+
// Always include standard timeout functions (context.WithTimeout, context.WithDeadline)
82+
result := []timeoutFunc{
83+
{receiver: "context", name: "WithTimeout"},
84+
{receiver: "context", name: "WithDeadline"},
85+
}
86+
87+
if a.timeoutFuncFlag != "" {
88+
for f := range strings.SplitSeq(a.timeoutFuncFlag, ",") {
89+
if trimmed := strings.TrimSpace(f); trimmed != "" {
90+
if parts := strings.Split(trimmed, "."); len(parts) == 2 {
91+
// receiver.function format (e.g., "testutil.Context")
92+
result = append(result, timeoutFunc{receiver: parts[0], name: parts[1]})
93+
} else {
94+
// bare function name (e.g., "Context")
95+
result = append(result, timeoutFunc{receiver: "", name: trimmed})
96+
}
97+
}
98+
}
99+
}
100+
101+
a.timeoutFuncs = result // cache the result
102+
return result
103+
}
104+
105+
// filterTestFiles returns only the test files from the pass
106+
func filterTestFiles(pass *analysis.Pass) []*ast.File {
107+
var testFiles []*ast.File
108+
for _, file := range pass.Files {
109+
if strings.HasSuffix(pass.Fset.Position(file.Pos()).Filename, "_test.go") {
110+
testFiles = append(testFiles, file)
111+
}
112+
}
113+
return testFiles
114+
}
115+
116+
func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) {
117+
testFiles := filterTestFiles(pass)
118+
if len(testFiles) == 0 {
119+
return nil, nil
120+
}
121+
122+
insp := inspector.New(testFiles)
123+
nodeFilter := []ast.Node{(*ast.FuncDecl)(nil)}
124+
insp.Preorder(nodeFilter, func(n ast.Node) {
125+
fd := n.(*ast.FuncDecl)
126+
if !isTestFunction(fd) {
127+
return
128+
}
129+
130+
a.analyzeTestFunction(pass, fd)
131+
})
132+
return nil, nil
133+
}
134+
135+
// analyzeTestFunction analyzes a test function to find timeout context usage after t.Parallel calls
136+
func (a *ctxAnalyzer) analyzeTestFunction(pass *analysis.Pass, fd *ast.FuncDecl) {
137+
testVarName := testParamName(fd)
138+
if testVarName == "" {
139+
return
140+
}
141+
142+
// Collect all timeout contexts and their positions
143+
timeoutCtxs := a.collectTimeoutContexts(pass, fd)
144+
if len(timeoutCtxs) == 0 {
145+
return
146+
}
147+
148+
// Find all t.Parallel() calls and check for context usage after them
149+
a.checkContextUsageAfterParallel(pass, fd, testVarName, timeoutCtxs)
150+
}
151+
152+
func isTestFunction(fd *ast.FuncDecl) bool {
153+
if !strings.HasPrefix(fd.Name.Name, "Test") {
154+
return false
155+
}
156+
if fd.Type.Params == nil || len(fd.Type.Params.List) != 1 {
157+
return false
158+
}
159+
p := fd.Type.Params.List[0]
160+
se, ok := p.Type.(*ast.StarExpr)
161+
if !ok {
162+
return false
163+
}
164+
sel, ok := se.X.(*ast.SelectorExpr)
165+
if !ok || sel.Sel.Name != "T" {
166+
return false
167+
}
168+
if id, ok := sel.X.(*ast.Ident); !ok || id.Name != "testing" {
169+
return false
170+
}
171+
return true
172+
}
173+
174+
func testParamName(fd *ast.FuncDecl) string {
175+
p := fd.Type.Params.List[0]
176+
if len(p.Names) == 0 {
177+
return ""
178+
}
179+
return p.Names[0].Name
180+
}
181+
182+
// timeoutContext represents a context variable created with a timeout/deadline
183+
type timeoutContext struct {
184+
obj types.Object
185+
pos token.Pos // position where it was created
186+
}
187+
188+
// collectTimeoutContexts finds all context variables created with a timeout/deadline functions
189+
func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl) []timeoutContext {
190+
var contexts []timeoutContext
191+
192+
ast.Inspect(fd.Body, func(n ast.Node) bool {
193+
as, ok := n.(*ast.AssignStmt)
194+
if !ok {
195+
return true
196+
}
197+
if len(as.Lhs) == 0 || len(as.Rhs) == 0 {
198+
return true
199+
}
200+
call, ok := as.Rhs[0].(*ast.CallExpr)
201+
if !ok {
202+
return true
203+
}
204+
205+
if !a.doesCreateTimeoutContext(call) {
206+
return true
207+
}
208+
209+
id, ok := as.Lhs[0].(*ast.Ident)
210+
if !ok {
211+
return true
212+
}
213+
obj := pass.TypesInfo.ObjectOf(id)
214+
if obj == nil {
215+
return true
216+
}
217+
218+
contexts = append(contexts, timeoutContext{
219+
obj: obj,
220+
pos: as.Pos(),
221+
})
222+
return true
223+
})
224+
return contexts
225+
}
226+
227+
// checkContextUsageAfterParallel walks through the AST and reports timeout context usage after t.Parallel() calls
228+
func (a *ctxAnalyzer) checkContextUsageAfterParallel(pass *analysis.Pass, fd *ast.FuncDecl, testVarName string, timeoutCtxs []timeoutContext) {
229+
// Analyze each function scope (main test and subtests) separately
230+
a.analyzeScope(pass, fd.Body, testVarName, timeoutCtxs)
231+
}
232+
233+
// analyzeScope analyzes a specific scope (test function or subtest) for the pattern
234+
func (a *ctxAnalyzer) analyzeScope(pass *analysis.Pass, scope ast.Node, testVarName string, timeoutCtxs []timeoutContext) {
235+
var parallelCalls []ast.Node
236+
reportedNodes := make(map[ast.Node]bool) // Track nodes we've already reported
237+
238+
ast.Inspect(scope, func(n ast.Node) bool {
239+
// Don't descend into nested function literals (they have their own scopes)
240+
if fl, ok := n.(*ast.FuncLit); ok && fl != scope {
241+
// Analyze the subtest scope separately with its own test variable
242+
subtestVarName := a.getSubtestParamName(fl)
243+
if subtestVarName != "" {
244+
a.analyzeScope(pass, fl, subtestVarName, timeoutCtxs)
245+
}
246+
return false // Don't continue into this scope
247+
}
248+
249+
// Find t.Parallel() calls in this scope
250+
if isTestMethodCall(n, testVarName, "Parallel") {
251+
parallelCalls = append(parallelCalls, n)
252+
}
253+
254+
// Check for context reassignment (overwriting)
255+
if as, ok := n.(*ast.AssignStmt); ok && as.Tok == token.ASSIGN {
256+
if len(as.Lhs) > 0 {
257+
if id, ok := as.Lhs[0].(*ast.Ident); ok {
258+
if a.checkContextViolation(pass, id, n, parallelCalls, timeoutCtxs, true) {
259+
reportedNodes[id] = true
260+
return true
261+
}
262+
}
263+
}
264+
}
265+
266+
// Check if any timeout context is used in this scope (regular usage, not assignment)
267+
if id, ok := n.(*ast.Ident); ok {
268+
// Skip if we've already reported this identifier
269+
if reportedNodes[id] {
270+
return true
271+
}
272+
273+
if a.checkContextViolation(pass, id, n, parallelCalls, timeoutCtxs, false) {
274+
return true
275+
}
276+
}
277+
278+
return true
279+
})
280+
}
281+
282+
// checkContextViolation checks if a context identifier violates the t.Parallel usage rules
283+
func (a *ctxAnalyzer) checkContextViolation(pass *analysis.Pass, id *ast.Ident, node ast.Node, parallelCalls []ast.Node, timeoutCtxs []timeoutContext, isAssignment bool) bool {
284+
obj := pass.TypesInfo.ObjectOf(id)
285+
if obj == nil {
286+
return false
287+
}
288+
289+
// Check if this identifier references a timeout context
290+
for _, ctx := range timeoutCtxs {
291+
if ctx.obj == obj {
292+
// Check if this usage/assignment is after any parallel call in the same scope
293+
for _, parallelCall := range parallelCalls {
294+
parallelPos := pass.Fset.Position(parallelCall.Pos()).Offset
295+
nodePos := pass.Fset.Position(node.Pos()).Offset
296+
contextPos := pass.Fset.Position(ctx.pos).Offset
297+
298+
// Context created before parallel call, but used/assigned after parallel call
299+
if contextPos < parallelPos && nodePos > parallelPos {
300+
if isAssignment {
301+
pass.Reportf(node.Pos(), "timeout context %s overwritten after a t.Parallel call; did you mean to shadow the variable?", id.Name)
302+
} else {
303+
pass.Reportf(node.Pos(), "timeout context %s used after a t.Parallel call", id.Name)
304+
}
305+
return true
306+
}
307+
}
308+
}
309+
}
310+
return false
311+
}
312+
313+
// getSubtestParamName extracts the test parameter name from a function literal
314+
func (a *ctxAnalyzer) getSubtestParamName(fl *ast.FuncLit) string {
315+
if fl.Type.Params == nil || len(fl.Type.Params.List) == 0 {
316+
return ""
317+
}
318+
p := fl.Type.Params.List[0]
319+
if len(p.Names) == 0 {
320+
return ""
321+
}
322+
return p.Names[0].Name
323+
}
324+
325+
// extractReceiverName extracts the receiver name from different AST expressions
326+
func extractReceiverName(expr ast.Expr) string {
327+
switch x := expr.(type) {
328+
case *ast.Ident:
329+
return x.Name
330+
case *ast.CompositeLit:
331+
// Handle foo{}.Method() pattern - get the type name
332+
if id, ok := x.Type.(*ast.Ident); ok {
333+
return id.Name
334+
}
335+
case *ast.ParenExpr:
336+
// Handle (foo{}).Method() pattern - unwrap parentheses
337+
if cl, ok := x.X.(*ast.CompositeLit); ok {
338+
if id, ok := cl.Type.(*ast.Ident); ok {
339+
return id.Name
340+
}
341+
}
342+
}
343+
return ""
344+
}
345+
346+
// isTestMethodCall checks if node is a method call on testVar with the given methodName
347+
func isTestMethodCall(node ast.Node, testVar, methodName string) bool {
348+
ce, ok := node.(*ast.CallExpr)
349+
if !ok {
350+
return false
351+
}
352+
fun, ok := ce.Fun.(*ast.SelectorExpr)
353+
if !ok {
354+
return false
355+
}
356+
recv, ok := fun.X.(*ast.Ident)
357+
if !ok {
358+
return false
359+
}
360+
return recv.Name == testVar && fun.Sel.Name == methodName
361+
}

0 commit comments

Comments
 (0)