Skip to content

Commit 74ef680

Browse files
authored
feat: Generate accessors for all fields (#4105)
1 parent 247ddc3 commit 74ef680

File tree

3 files changed

+19480
-61
lines changed

3 files changed

+19480
-61
lines changed

github/gen-accessors.go

Lines changed: 142 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@
55

66
//go:build ignore
77

8-
// gen-accessors generates accessor methods for structs with pointer fields.
8+
// gen-accessors generates accessor methods for all struct fields.
9+
// This is so that interfaces can be easily crafted by users of this repo
10+
// within their own code bases.
11+
// See https://github.com/google/go-github/issues/4059 for details.
912
//
1013
// It is meant to be used by go-github contributors in conjunction with the
1114
// go generate tool before sending a PR to GitHub.
1215
// Please see the CONTRIBUTING.md file for more information.
16+
//
17+
// Usage:
18+
//
19+
// go run gen-accessors.go [-v [file1.go file2.go ...]]
1320
package main
1421

1522
import (
@@ -39,14 +46,15 @@ var (
3946

4047
// skipStructMethods lists "struct.method" combos to skip.
4148
skipStructMethods = map[string]bool{
42-
"RepositoryContent.GetContent": true,
49+
"AbuseRateLimitError.GetResponse": true,
4350
"Client.GetBaseURL": true,
4451
"Client.GetUploadURL": true,
4552
"ErrorResponse.GetResponse": true,
46-
"RateLimitError.GetResponse": true,
47-
"AbuseRateLimitError.GetResponse": true,
53+
"MarketplaceService.GetStubbed": true,
4854
"PackageVersion.GetBody": true,
4955
"PackageVersion.GetMetadata": true,
56+
"RateLimitError.GetResponse": true,
57+
"RepositoryContent.GetContent": true,
5058
}
5159
// skipStructs lists structs to skip.
5260
skipStructs = map[string]bool{
@@ -67,6 +75,18 @@ func logf(fmt string, args ...any) {
6775

6876
func main() {
6977
flag.Parse()
78+
79+
// For debugging purposes, processing just a single or a few files is helpful:
80+
var processOnly map[string]bool
81+
if *verbose { // Only create the map if args are provided.
82+
for _, arg := range flag.Args() {
83+
if processOnly == nil {
84+
processOnly = map[string]bool{}
85+
}
86+
processOnly[arg] = true
87+
}
88+
}
89+
7090
fset := token.NewFileSet()
7191

7292
pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
@@ -83,6 +103,10 @@ func main() {
83103
Imports: map[string]string{},
84104
}
85105
for filename, f := range pkg.Files {
106+
if *verbose && processOnly != nil && !processOnly[filename] {
107+
continue
108+
}
109+
86110
logf("Processing %v...", filename)
87111
if err := t.processAST(f); err != nil {
88112
log.Fatal(err)
@@ -116,8 +140,12 @@ func (t *templateData) processAST(f *ast.File) error {
116140
logf("Struct %v is in skip list; skipping.", ts.Name)
117141
continue
118142
}
143+
if _, ok := ts.Type.(*ast.Ident); ok { // e.g. type SomeService service
144+
continue
145+
}
119146
st, ok := ts.Type.(*ast.StructType)
120147
if !ok {
148+
logf("Skipping TypeSpec of type %T", ts.Type)
121149
continue
122150
}
123151
for _, field := range st.Fields.List {
@@ -141,14 +169,21 @@ func (t *templateData) processAST(f *ast.File) error {
141169
if !ok {
142170
switch x := field.Type.(type) {
143171
case *ast.MapType:
172+
logf("processAST: addMapType(x, %q, %q)", ts.Name.String(), fieldName.String())
144173
t.addMapType(x, ts.Name.String(), fieldName.String(), false)
145174
continue
146175
case *ast.ArrayType:
147-
if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] {
148-
logf("Method %v is whitelist; adding getter method.", key)
149-
t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
150-
continue
151-
}
176+
logf("processAST: addArrayType(x, %q, %q)", ts.Name.String(), fieldName.String())
177+
t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
178+
continue
179+
case *ast.Ident:
180+
logf("processAST: addSimpleValueIdent(x, %q, %q)", ts.Name.String(), fieldName.String())
181+
t.addSimpleValueIdent(x, ts.Name.String(), fieldName.String())
182+
continue
183+
case *ast.SelectorExpr:
184+
logf("processAST: addSimpleValueSelectorExpr(x, %q, %q)", ts.Name.String(), fieldName.String())
185+
t.addSimpleValueSelectorExpr(x, ts.Name.String(), fieldName.String())
186+
continue
152187
}
153188

154189
logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
@@ -254,24 +289,62 @@ func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName st
254289
t.Getters = append(t.Getters, ng)
255290
}
256291

292+
func (t *templateData) addSimpleValueIdent(x *ast.Ident, receiverType, fieldName string) {
293+
getter := genIdentGetter(x, receiverType, fieldName)
294+
getter.IsSimpleValue = true
295+
logf("addSimpleValueIdent: Processing %q - fieldName=%q, getter.ZeroValue=%q, x.Obj=%#v", x.String(), fieldName, getter.ZeroValue, x.Obj)
296+
if getter.ZeroValue == "nil" {
297+
if x.Obj == nil {
298+
switch x.String() {
299+
case "any": // NOOP - leave as `nil`
300+
default:
301+
getter.ZeroValue = x.String() + "{}"
302+
}
303+
} else {
304+
if ts, ok := x.Obj.Decl.(*ast.TypeSpec); ok {
305+
logf("addSimpleValueIdent: Processing %q of type %T", x.String(), ts.Type)
306+
switch xX := ts.Type.(type) {
307+
case *ast.Ident:
308+
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
309+
getter.ZeroValue = zeroValueOfIdent(xX)
310+
case *ast.StructType:
311+
getter.ZeroValue = x.String() + "{}"
312+
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
313+
case *ast.InterfaceType, *ast.ArrayType: // NOOP - leave as `nil`
314+
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
315+
default:
316+
log.Fatalf("addSimpleValueIdent: unhandled case %T", xX)
317+
}
318+
}
319+
}
320+
}
321+
t.Getters = append(t.Getters, getter)
322+
}
323+
257324
func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
258-
var zeroValue string
259-
var namedStruct bool
325+
getter := genIdentGetter(x, receiverType, fieldName)
326+
t.Getters = append(t.Getters, getter)
327+
}
328+
329+
func zeroValueOfIdent(x *ast.Ident) string {
260330
switch x.String() {
261-
case "int", "int64":
262-
zeroValue = "0"
331+
case "int", "int64", "float64", "uint8", "uint16":
332+
return "0"
263333
case "string":
264-
zeroValue = `""`
334+
return `""`
265335
case "bool":
266-
zeroValue = "false"
336+
return "false"
267337
case "Timestamp":
268-
zeroValue = "Timestamp{}"
338+
return "Timestamp{}"
269339
default:
270-
zeroValue = "nil"
271-
namedStruct = true
340+
return "nil"
272341
}
342+
}
273343

274-
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
344+
func genIdentGetter(x *ast.Ident, receiverType, fieldName string) *getter {
345+
zeroValue := zeroValueOfIdent(x)
346+
namedStruct := zeroValue == "nil"
347+
return newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)
275348
}
276349

277350
func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
@@ -300,10 +373,28 @@ func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string
300373
t.Getters = append(t.Getters, ng)
301374
}
302375

376+
func (t *templateData) addSimpleValueSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
377+
getter := t.genSelectorExprGetter(x, receiverType, fieldName)
378+
if getter == nil {
379+
return
380+
}
381+
getter.IsSimpleValue = true
382+
logf("addSimpleValueSelectorExpr: Processing field name %q - %#v - zero value is %q", fieldName, x, getter.ZeroValue)
383+
t.Getters = append(t.Getters, getter)
384+
}
385+
303386
func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
304-
if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
387+
getter := t.genSelectorExprGetter(x, receiverType, fieldName)
388+
if getter == nil {
305389
return
306390
}
391+
t.Getters = append(t.Getters, getter)
392+
}
393+
394+
func (t *templateData) genSelectorExprGetter(x *ast.SelectorExpr, receiverType, fieldName string) *getter {
395+
if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
396+
return nil
397+
}
307398

308399
var xX string
309400
if xx, ok := x.X.(*ast.Ident); ok {
@@ -322,10 +413,12 @@ func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldN
322413
if xX == "time" && x.Sel.Name == "Duration" {
323414
zeroValue = "0"
324415
}
325-
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
416+
return newGetter(receiverType, fieldName, fieldType, zeroValue, false)
326417
default:
327418
logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
328419
}
420+
421+
return nil
329422
}
330423

331424
type templateData struct {
@@ -337,15 +430,16 @@ type templateData struct {
337430
}
338431

339432
type getter struct {
340-
sortVal string // Lower-case version of "ReceiverType.FieldName".
341-
ReceiverVar string // The one-letter variable name to match the ReceiverType.
342-
ReceiverType string
343-
FieldName string
344-
FieldType string
345-
ZeroValue string
346-
NamedStruct bool // Getter for named struct.
347-
MapType bool
348-
ArrayType bool
433+
sortVal string // Lower-case version of "ReceiverType.FieldName".
434+
ReceiverVar string // The one-letter variable name to match the ReceiverType.
435+
ReceiverType string
436+
FieldName string
437+
FieldType string
438+
ZeroValue string
439+
NamedStruct bool // Getter for named struct.
440+
MapType bool
441+
ArrayType bool
442+
IsSimpleValue bool
349443
}
350444

351445
const source = `// Code generated by gen-accessors; DO NOT EDIT.
@@ -366,7 +460,15 @@ import (
366460
)
367461
{{end}}
368462
{{range .Getters}}
369-
{{if .NamedStruct}}
463+
{{if .IsSimpleValue}}
464+
// Get{{.FieldName}} returns the {{.FieldName}} field.
465+
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
466+
if {{.ReceiverVar}} == nil {
467+
return {{.ZeroValue}}
468+
}
469+
return {{.ReceiverVar}}.{{.FieldName}}
470+
}
471+
{{else if .NamedStruct}}
370472
// Get{{.FieldName}} returns the {{.FieldName}} field.
371473
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
372474
if {{.ReceiverVar}} == nil {
@@ -413,7 +515,15 @@ import (
413515
)
414516
{{end}}
415517
{{range .Getters}}
416-
{{if .NamedStruct}}
518+
{{if .IsSimpleValue}}
519+
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
520+
tt.Parallel()
521+
{{.ReceiverVar}} := &{{.ReceiverType}}{}
522+
{{.ReceiverVar}}.Get{{.FieldName}}()
523+
{{.ReceiverVar}} = nil
524+
{{.ReceiverVar}}.Get{{.FieldName}}()
525+
}
526+
{{else if .NamedStruct}}
417527
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
418528
tt.Parallel()
419529
{{.ReceiverVar}} := &{{.ReceiverType}}{}

0 commit comments

Comments
 (0)