@@ -73,6 +73,168 @@ func (a *ctxAnalyzer) doesCreateTimeoutContext(call *ast.CallExpr) bool {
7373 return false
7474}
7575
76+ type timeoutHelperSet map [types.Object ]struct {}
77+
78+ func (a * ctxAnalyzer ) collectTimeoutHelpers (pass * analysis.Pass , testFiles []* ast.File ) timeoutHelperSet {
79+ candidates := map [* ast.FuncDecl ]types.Object {}
80+ for _ , file := range testFiles {
81+ for _ , decl := range file .Decls {
82+ fd , ok := decl .(* ast.FuncDecl )
83+ if ! ok || fd .Body == nil || ! funcReturnsContext (pass , fd ) {
84+ continue
85+ }
86+
87+ obj := pass .TypesInfo .Defs [fd .Name ]
88+ if obj == nil {
89+ continue
90+ }
91+ candidates [fd ] = obj
92+ }
93+ }
94+
95+ helpers := timeoutHelperSet {}
96+ for changed := true ; changed ; {
97+ changed = false
98+ for fd , obj := range candidates {
99+ if _ , ok := helpers [obj ]; ok {
100+ continue
101+ }
102+ if a .functionReturnsTimeoutContext (pass , fd , helpers ) {
103+ helpers [obj ] = struct {}{}
104+ changed = true
105+ }
106+ }
107+ }
108+ return helpers
109+ }
110+
111+ func funcReturnsContext (pass * analysis.Pass , fd * ast.FuncDecl ) bool {
112+ if fd .Type .Results == nil || len (fd .Type .Results .List ) == 0 {
113+ return false
114+ }
115+
116+ return isContextType (pass , pass .TypesInfo .TypeOf (fd .Type .Results .List [0 ].Type ))
117+ }
118+
119+ func (a * ctxAnalyzer ) functionReturnsTimeoutContext (pass * analysis.Pass , fd * ast.FuncDecl , helpers timeoutHelperSet ) bool {
120+ found := false
121+ ast .Inspect (fd .Body , func (n ast.Node ) bool {
122+ if found {
123+ return false
124+ }
125+
126+ switch n := n .(type ) {
127+ case * ast.FuncLit :
128+ return false
129+ case * ast.ReturnStmt :
130+ if len (n .Results ) == 0 {
131+ return true
132+ }
133+ found = a .exprContainsTimeoutContext (pass , n .Results [0 ], helpers )
134+ return ! found
135+ default :
136+ return true
137+ }
138+ })
139+ return found
140+ }
141+
142+ func (a * ctxAnalyzer ) exprContainsTimeoutContext (pass * analysis.Pass , expr ast.Expr , helpers timeoutHelperSet ) bool {
143+ found := false
144+ ast .Inspect (expr , func (n ast.Node ) bool {
145+ if found {
146+ return false
147+ }
148+
149+ call , ok := n .(* ast.CallExpr )
150+ if ! ok {
151+ return true
152+ }
153+
154+ found = a .doesCreateTimeoutContext (call ) || callsTimeoutHelper (pass , call , helpers )
155+ return ! found
156+ })
157+ return found
158+ }
159+
160+ func callsTimeoutHelper (pass * analysis.Pass , call * ast.CallExpr , helpers timeoutHelperSet ) bool {
161+ if len (helpers ) == 0 {
162+ return false
163+ }
164+
165+ obj := callObject (pass , call )
166+ if obj == nil {
167+ return false
168+ }
169+ _ , ok := helpers [obj ]
170+ return ok
171+ }
172+
173+ func callObject (pass * analysis.Pass , call * ast.CallExpr ) types.Object {
174+ switch fun := call .Fun .(type ) {
175+ case * ast.Ident :
176+ return pass .TypesInfo .Uses [fun ]
177+ case * ast.SelectorExpr :
178+ return pass .TypesInfo .Uses [fun .Sel ]
179+ default :
180+ return nil
181+ }
182+ }
183+
184+ func isContextType (pass * analysis.Pass , typ types.Type ) bool {
185+ if typ == nil {
186+ return false
187+ }
188+
189+ contextIface := contextInterface (pass )
190+ if contextIface == nil {
191+ return isStdlibContextType (typ )
192+ }
193+
194+ return types .Implements (types .Unalias (typ ), contextIface )
195+ }
196+
197+ func contextInterface (pass * analysis.Pass ) * types.Interface {
198+ if pass .Pkg == nil {
199+ return nil
200+ }
201+
202+ for _ , pkg := range pass .Pkg .Imports () {
203+ if pkg .Path () != "context" {
204+ continue
205+ }
206+
207+ obj := pkg .Scope ().Lookup ("Context" )
208+ if obj == nil {
209+ return nil
210+ }
211+
212+ named , ok := obj .Type ().(* types.Named )
213+ if ! ok {
214+ return nil
215+ }
216+
217+ iface , ok := named .Underlying ().(* types.Interface )
218+ if ! ok {
219+ return nil
220+ }
221+
222+ return iface
223+ }
224+
225+ return nil
226+ }
227+
228+ func isStdlibContextType (typ types.Type ) bool {
229+ named , ok := types .Unalias (typ ).(* types.Named )
230+ if ! ok {
231+ return false
232+ }
233+
234+ obj := named .Obj ()
235+ return obj .Name () == "Context" && obj .Pkg () != nil && obj .Pkg ().Path () == "context"
236+ }
237+
76238func (a * ctxAnalyzer ) getTimeoutFuncs () []timeoutFunc {
77239 if a .timeoutFuncs != nil {
78240 return a .timeoutFuncs
@@ -119,6 +281,8 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) {
119281 return nil , nil
120282 }
121283
284+ timeoutHelpers := a .collectTimeoutHelpers (pass , testFiles )
285+
122286 insp := inspector .New (testFiles )
123287 nodeFilter := []ast.Node {(* ast .FuncDecl )(nil )}
124288 insp .Preorder (nodeFilter , func (n ast.Node ) {
@@ -127,20 +291,20 @@ func (a *ctxAnalyzer) run(pass *analysis.Pass) (any, error) {
127291 return
128292 }
129293
130- a .analyzeTestFunction (pass , fd )
294+ a .analyzeTestFunction (pass , fd , timeoutHelpers )
131295 })
132296 return nil , nil
133297}
134298
135299// analyzeTestFunction analyzes a test function to find timeout context usage after t.Parallel calls
136- func (a * ctxAnalyzer ) analyzeTestFunction (pass * analysis.Pass , fd * ast.FuncDecl ) {
300+ func (a * ctxAnalyzer ) analyzeTestFunction (pass * analysis.Pass , fd * ast.FuncDecl , helpers timeoutHelperSet ) {
137301 testVarName := testParamName (fd )
138302 if testVarName == "" {
139303 return
140304 }
141305
142306 // Collect all timeout contexts and their positions
143- timeoutCtxs := a .collectTimeoutContexts (pass , fd )
307+ timeoutCtxs := a .collectTimeoutContexts (pass , fd , helpers )
144308 if len (timeoutCtxs ) == 0 {
145309 return
146310 }
@@ -187,7 +351,7 @@ type timeoutContext struct {
187351}
188352
189353// collectTimeoutContexts finds all context variables created with a timeout/deadline functions
190- func (a * ctxAnalyzer ) collectTimeoutContexts (pass * analysis.Pass , fd * ast.FuncDecl ) []timeoutContext {
354+ func (a * ctxAnalyzer ) collectTimeoutContexts (pass * analysis.Pass , fd * ast.FuncDecl , helpers timeoutHelperSet ) []timeoutContext {
191355 var contexts []timeoutContext
192356 contextMap := make (map [types.Object ]* timeoutContext )
193357
@@ -209,8 +373,11 @@ func (a *ctxAnalyzer) collectTimeoutContexts(pass *analysis.Pass, fd *ast.FuncDe
209373 return true
210374 }
211375
212- call , ok := as .Rhs [0 ].(* ast.CallExpr )
213- if ok && a .doesCreateTimeoutContext (call ) {
376+ if ! isContextType (pass , obj .Type ()) {
377+ return true
378+ }
379+
380+ if a .exprContainsTimeoutContext (pass , as .Rhs [0 ], helpers ) {
214381 // This is a timeout context creation
215382 ctx := timeoutContext {
216383 obj : obj ,
0 commit comments