Skip to content

Commit 21a3c87

Browse files
authored
Public type name (#1535)
1 parent afe0f95 commit 21a3c87

1 file changed

Lines changed: 31 additions & 18 deletions

File tree

pkg/capabilities/v2/protoc/pkg/template_generator.go

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type TemplateGenerator struct {
2525
StringLblValue func(name string, label *generator.Label) (string, error)
2626
PbLabelTLangLabels func(labels map[string]*generator.Label) ([]Label, error)
2727
ExtraFns template.FuncMap
28+
importToPkg map[protogen.GoImportPath]protogen.GoPackageName
2829
}
2930

3031
func (t *TemplateGenerator) GenerateFile(
@@ -35,7 +36,9 @@ func (t *TemplateGenerator) GenerateFile(
3536
localPrefix string) error {
3637

3738
seen := map[string]int{}
38-
importToPkg := map[protogen.GoImportPath]protogen.GoPackageName{}
39+
if t.importToPkg == nil {
40+
t.importToPkg = map[protogen.GoImportPath]protogen.GoPackageName{}
41+
}
3942
for _, f := range plugin.Files {
4043
base := string(f.GoPackageName)
4144
alias := base
@@ -46,10 +49,10 @@ func (t *TemplateGenerator) GenerateFile(
4649
} else {
4750
seen[base] = 0
4851
}
49-
importToPkg[f.GoImportPath] = protogen.GoPackageName(alias)
52+
t.importToPkg[f.GoImportPath] = protogen.GoPackageName(alias)
5053
}
5154

52-
fileName, content, err := t.Generate(path.Base(file.GeneratedFilenamePrefix), args, importToPkg, toolName, localPrefix)
55+
fileName, content, err := t.Generate(path.Base(file.GeneratedFilenamePrefix), args, t.importToPkg, toolName, localPrefix)
5356
if err != nil {
5457
return err
5558
}
@@ -163,21 +166,7 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial
163166
copy(allImports, orderedImports)
164167
return allImports
165168
},
166-
"name": func(ident protogen.GoIdent, ignore string) string {
167-
importPath := ident.GoImportPath.String()
168-
if ignore == importPath {
169-
return ident.GoName
170-
}
171-
172-
packageName := path.Base(strings.Trim(importPath, `"`))
173-
174-
// use package name alias if package is mismatched with the package name
175-
if !isDirNamePackageName(ident.GoImportPath, importToPkg) {
176-
packageName = string(importToPkg[ident.GoImportPath])
177-
}
178-
179-
return fmt.Sprintf("%s.%s", packageName, ident.GoName)
180-
},
169+
"name": t.TypeName,
181170
"CapabilityId": func(s *protogen.Service) (string, error) {
182171
md, err := getCapabilityMetadata(s)
183172
if err != nil {
@@ -326,3 +315,27 @@ type namedLabel struct {
326315
name string
327316
label *generator.Label
328317
}
318+
319+
func (t *TemplateGenerator) TypeName(ident protogen.GoIdent, ignore string) string {
320+
importPath := ident.GoImportPath.String()
321+
if ignore == importPath {
322+
return ident.GoName
323+
}
324+
325+
packageName := path.Base(strings.Trim(importPath, `"`))
326+
327+
// use package name alias if package is mismatched with the package name
328+
if !isDirNamePackageName(ident.GoImportPath, t.importToPkg) {
329+
packageName = string(t.importToPkg[ident.GoImportPath])
330+
}
331+
332+
return fmt.Sprintf("%s.%s", packageName, ident.GoName)
333+
}
334+
335+
func (t *TemplateGenerator) AddImport(name protogen.GoImportPath, importPath protogen.GoPackageName) {
336+
if t.importToPkg == nil {
337+
t.importToPkg = map[protogen.GoImportPath]protogen.GoPackageName{}
338+
}
339+
340+
t.importToPkg[name] = importPath
341+
}

0 commit comments

Comments
 (0)