1+ /*
2+ The MIT License (MIT)
3+
4+ Copyright (c) 2015-2026 Ernesto Jiménez and contributors.
5+
6+ Permission is hereby granted, free of charge, to any person obtaining a copy
7+ of this software and associated documentation files (the "Software"), to deal
8+ in the Software without restriction, including without limitation the rights
9+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+ copies of the Software, and to permit persons to whom the Software is
11+ furnished to do so, subject to the following conditions:
12+
13+ The above copyright notice and this permission notice shall be included in all
14+ copies or substantial portions of the Software.
15+
16+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+ SOFTWARE.
23+ */
24+
125// This program reads all assertion functions from the assert package and
226// automatically generates the corresponding requires and forwarded assertions
327
@@ -22,8 +46,6 @@ import (
2246 "regexp"
2347 "strings"
2448 "text/template"
25-
26- "github.com/stretchr/testify/_codegen/internal/imports"
2749)
2850
2951var (
@@ -42,17 +64,17 @@ func main() {
4264 log .Fatal (err )
4365 }
4466
45- importer , funcs , err := analyzeCode (scope , docs )
67+ imports , funcs , err := analyzeCode (scope , docs )
4668 if err != nil {
4769 log .Fatal (err )
4870 }
4971
50- if err := generateCode (importer , funcs ); err != nil {
72+ if err := generateCode (imports , funcs ); err != nil {
5173 log .Fatal (err )
5274 }
5375}
5476
55- func generateCode (importer imports. Importer , funcs []testFunc ) error {
77+ func generateCode (imports * imports , funcs []testFunc ) error {
5678 buff := bytes .NewBuffer (nil )
5779
5880 tmplHead , tmplFunc , err := parseTemplates ()
@@ -66,7 +88,7 @@ func generateCode(importer imports.Importer, funcs []testFunc) error {
6688 Imports map [string ]string
6789 }{
6890 * outputPkg ,
69- importer . Imports () ,
91+ imports . imports ,
7092 }); err != nil {
7193 return err
7294 }
@@ -126,10 +148,13 @@ func outputFile() (*os.File, error) {
126148
127149// analyzeCode takes the types scope and the docs and returns the import
128150// information and information about all the assertion functions.
129- func analyzeCode (scope * types.Scope , docs * doc.Package ) (imports. Importer , []testFunc , error ) {
151+ func analyzeCode (scope * types.Scope , docs * doc.Package ) (* imports , []testFunc , error ) {
130152 testingT := scope .Lookup ("TestingT" ).Type ().Underlying ().(* types.Interface )
131153
132- importer := imports .New (* outputPkg )
154+ importer := & imports {
155+ currentPkg : * outputPkg ,
156+ imports : map [string ]string {},
157+ }
133158 var funcs []testFunc
134159 // Go through all the top level functions
135160 for _ , fdocs := range docs .Funcs {
@@ -164,11 +189,43 @@ func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []tes
164189 }
165190
166191 funcs = append (funcs , testFunc {* outputPkg , fdocs , fn })
167- importer .AddImportsFrom (sig .Params ())
192+ importer .addImportsFrom (sig .Params ())
168193 }
169194 return importer , funcs , nil
170195}
171196
197+ // imports collects a map of imported packages for a source file.
198+ //
199+ // This code has been copied from package github.com/ernesto-jimenez/gogen/imports
200+ type imports struct {
201+ currentPkg string
202+ imports map [string ]string
203+ }
204+
205+ func (imp * imports ) addImportsFrom (t types.Type ) {
206+ switch el := t .(type ) {
207+ case * types.Basic :
208+ case * types.Slice :
209+ imp .addImportsFrom (el .Elem ())
210+ case * types.Pointer :
211+ imp .addImportsFrom (el .Elem ())
212+ case * types.Named :
213+ pkg := el .Obj ().Pkg ()
214+ if pkg == nil {
215+ return
216+ }
217+ if pkg .Name () == imp .currentPkg {
218+ return
219+ }
220+ imp .imports [pkg .Path ()] = pkg .Name ()
221+ case * types.Tuple :
222+ for i := 0 ; i < el .Len (); i ++ {
223+ imp .addImportsFrom (el .At (i ).Type ())
224+ }
225+ default :
226+ }
227+ }
228+
172229// parsePackageSource returns the types scope and the package documentation from the package
173230func parsePackageSource (pkg string ) (* types.Scope , * doc.Package , error ) {
174231 pd , err := build .Import (pkg , "." , 0 )
0 commit comments