Skip to content

Commit 2c72902

Browse files
committed
fix: handle overwritten contexts
1 parent 68df738 commit 2c72902

2 files changed

Lines changed: 39 additions & 15 deletions

File tree

pkg/paralleltestctx/analyzer.go

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,15 @@ func testParamName(fd *ast.FuncDecl) string {
181181

182182
// timeoutContext represents a context variable created with a timeout/deadline
183183
type timeoutContext struct {
184-
obj types.Object
185-
pos token.Pos // position where it was created
184+
obj types.Object
185+
pos token.Pos // position where it was created
186+
invalidPos token.Pos // position where it was overwritten with non-timeout context (0 if still valid)
186187
}
187188

188189
// collectTimeoutContexts finds all context variables created with a timeout/deadline functions
189190
func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDecl) []timeoutContext {
190191
var contexts []timeoutContext
192+
contextMap := make(map[types.Object]*timeoutContext)
191193

192194
ast.Inspect(fd.Body, func(n ast.Node) bool {
193195
as, ok := n.(*ast.AssignStmt)
@@ -197,14 +199,6 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe
197199
if len(as.Lhs) == 0 || len(as.Rhs) == 0 {
198200
return true
199201
}
200-
call, ok := as.Rhs[0].(*ast.CallExpr)
201-
if !ok {
202-
return true
203-
}
204-
205-
if !a.doesCreateTimeoutContext(call) {
206-
return true
207-
}
208202

209203
id, ok := as.Lhs[0].(*ast.Ident)
210204
if !ok {
@@ -215,10 +209,20 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe
215209
return true
216210
}
217211

218-
contexts = append(contexts, timeoutContext{
219-
obj: obj,
220-
pos: as.Pos(),
221-
})
212+
call, ok := as.Rhs[0].(*ast.CallExpr)
213+
if ok && a.doesCreateTimeoutContext(call) {
214+
// This is a timeout context creation
215+
ctx := timeoutContext{
216+
obj: obj,
217+
pos: as.Pos(),
218+
}
219+
contexts = append(contexts, ctx)
220+
contextMap[obj] = &contexts[len(contexts)-1]
221+
} else if existingCtx, exists := contextMap[obj]; exists {
222+
// This is an overwrite of an existing timeout context with non-timeout
223+
existingCtx.invalidPos = as.Pos()
224+
}
225+
222226
return true
223227
})
224228
return contexts
@@ -297,6 +301,15 @@ func (a *ctxAnalyzer) checkContextViolation(pass *analysis.Pass, id *ast.Ident,
297301

298302
// Context created before parallel call, but used/assigned after parallel call
299303
if contextPos < parallelPos && nodePos > parallelPos {
304+
// If context was invalidated (overwritten with non-timeout), check if that happened before this usage
305+
if ctx.invalidPos != 0 {
306+
invalidOffset := pass.Fset.Position(ctx.invalidPos).Offset
307+
if invalidOffset < nodePos {
308+
// Context was invalidated before this usage, so don't warn
309+
return false
310+
}
311+
}
312+
300313
if isAssignment {
301314
pass.Reportf(node.Pos(), "timeout context %s overwritten after a t.Parallel call; did you mean to shadow the variable?", id.Name)
302315
} else {

pkg/paralleltestctx/testdata/src/basic/basic_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func TestContextOverwrittenWarn(t *testing.T) {
7777
t.Run("sub", func(t *testing.T) {
7878
t.Parallel()
7979
ctx = context.Background() // want "timeout context ctx overwritten after a t.Parallel call; did you mean to shadow the variable\\?"
80-
_ = ctx // want "timeout context ctx used after a t.Parallel call"
80+
_ = ctx
8181
})
8282
}
8383

@@ -107,6 +107,17 @@ func TestDifferentScopes(t *testing.T) {
107107
_ = ctx // used in main test - should not warn (different scope)
108108
}
109109

110+
func TestOverwriteEarly(t *testing.T) {
111+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
112+
t.Cleanup(cancel)
113+
_ = ctx
114+
ctx = context.Background()
115+
t.Run("sub", func(t *testing.T) {
116+
t.Parallel()
117+
_ = ctx
118+
})
119+
}
120+
110121
func doThing(ctx context.Context) {
111122
_ = ctx
112123
}

0 commit comments

Comments
 (0)