@@ -25,6 +25,7 @@ type TemplateGenerator struct {
2525 StringLblValue func (name string , label * generator.Label ) (string , error )
2626 PbLabelTLangLabels func (labels map [string ]* generator.Label ) ([]Label , error )
2727 ExtraFns template.FuncMap
28+ importToPkg map [protogen.GoImportPath ]protogen.GoPackageName
2829}
2930
3031func (t * TemplateGenerator ) GenerateFile (
@@ -35,7 +36,9 @@ func (t *TemplateGenerator) GenerateFile(
3536 localPrefix string ) error {
3637
3738 seen := map [string ]int {}
38- importToPkg := map [protogen.GoImportPath ]protogen.GoPackageName {}
39+ if t .importToPkg == nil {
40+ t .importToPkg = map [protogen.GoImportPath ]protogen.GoPackageName {}
41+ }
3942 for _ , f := range plugin .Files {
4043 base := string (f .GoPackageName )
4144 alias := base
@@ -46,10 +49,10 @@ func (t *TemplateGenerator) GenerateFile(
4649 } else {
4750 seen [base ] = 0
4851 }
49- importToPkg [f .GoImportPath ] = protogen .GoPackageName (alias )
52+ t . importToPkg [f .GoImportPath ] = protogen .GoPackageName (alias )
5053 }
5154
52- fileName , content , err := t .Generate (path .Base (file .GeneratedFilenamePrefix ), args , importToPkg , toolName , localPrefix )
55+ fileName , content , err := t .Generate (path .Base (file .GeneratedFilenamePrefix ), args , t . importToPkg , toolName , localPrefix )
5356 if err != nil {
5457 return err
5558 }
@@ -163,21 +166,7 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial
163166 copy (allImports , orderedImports )
164167 return allImports
165168 },
166- "name" : func (ident protogen.GoIdent , ignore string ) string {
167- importPath := ident .GoImportPath .String ()
168- if ignore == importPath {
169- return ident .GoName
170- }
171-
172- packageName := path .Base (strings .Trim (importPath , `"` ))
173-
174- // use package name alias if package is mismatched with the package name
175- if ! isDirNamePackageName (ident .GoImportPath , importToPkg ) {
176- packageName = string (importToPkg [ident .GoImportPath ])
177- }
178-
179- return fmt .Sprintf ("%s.%s" , packageName , ident .GoName )
180- },
169+ "name" : t .TypeName ,
181170 "CapabilityId" : func (s * protogen.Service ) (string , error ) {
182171 md , err := getCapabilityMetadata (s )
183172 if err != nil {
@@ -326,3 +315,27 @@ type namedLabel struct {
326315 name string
327316 label * generator.Label
328317}
318+
319+ func (t * TemplateGenerator ) TypeName (ident protogen.GoIdent , ignore string ) string {
320+ importPath := ident .GoImportPath .String ()
321+ if ignore == importPath {
322+ return ident .GoName
323+ }
324+
325+ packageName := path .Base (strings .Trim (importPath , `"` ))
326+
327+ // use package name alias if package is mismatched with the package name
328+ if ! isDirNamePackageName (ident .GoImportPath , t .importToPkg ) {
329+ packageName = string (t .importToPkg [ident .GoImportPath ])
330+ }
331+
332+ return fmt .Sprintf ("%s.%s" , packageName , ident .GoName )
333+ }
334+
335+ func (t * TemplateGenerator ) AddImport (name protogen.GoImportPath , importPath protogen.GoPackageName ) {
336+ if t .importToPkg == nil {
337+ t .importToPkg = map [protogen.GoImportPath ]protogen.GoPackageName {}
338+ }
339+
340+ t .importToPkg [name ] = importPath
341+ }
0 commit comments