55
66//go:build ignore
77
8- // gen-accessors generates accessor methods for structs with pointer fields.
8+ // gen-accessors generates accessor methods for all struct fields.
9+ // This is so that interfaces can be easily crafted by users of this repo
10+ // within their own code bases.
11+ // See https://github.com/google/go-github/issues/4059 for details.
912//
1013// It is meant to be used by go-github contributors in conjunction with the
1114// go generate tool before sending a PR to GitHub.
1215// Please see the CONTRIBUTING.md file for more information.
16+ //
17+ // Usage:
18+ //
19+ // go run gen-accessors.go [-v [file1.go file2.go ...]]
1320package main
1421
1522import (
@@ -39,14 +46,15 @@ var (
3946
4047 // skipStructMethods lists "struct.method" combos to skip.
4148 skipStructMethods = map [string ]bool {
42- "RepositoryContent.GetContent" : true ,
49+ "AbuseRateLimitError.GetResponse" : true ,
4350 "Client.GetBaseURL" : true ,
4451 "Client.GetUploadURL" : true ,
4552 "ErrorResponse.GetResponse" : true ,
46- "RateLimitError.GetResponse" : true ,
47- "AbuseRateLimitError.GetResponse" : true ,
53+ "MarketplaceService.GetStubbed" : true ,
4854 "PackageVersion.GetBody" : true ,
4955 "PackageVersion.GetMetadata" : true ,
56+ "RateLimitError.GetResponse" : true ,
57+ "RepositoryContent.GetContent" : true ,
5058 }
5159 // skipStructs lists structs to skip.
5260 skipStructs = map [string ]bool {
@@ -67,6 +75,18 @@ func logf(fmt string, args ...any) {
6775
6876func main () {
6977 flag .Parse ()
78+
79+ // For debugging purposes, processing just a single or a few files is helpful:
80+ var processOnly map [string ]bool
81+ if * verbose { // Only create the map if args are provided.
82+ for _ , arg := range flag .Args () {
83+ if processOnly == nil {
84+ processOnly = map [string ]bool {}
85+ }
86+ processOnly [arg ] = true
87+ }
88+ }
89+
7090 fset := token .NewFileSet ()
7191
7292 pkgs , err := parser .ParseDir (fset , "." , sourceFilter , 0 )
@@ -83,6 +103,10 @@ func main() {
83103 Imports : map [string ]string {},
84104 }
85105 for filename , f := range pkg .Files {
106+ if * verbose && processOnly != nil && ! processOnly [filename ] {
107+ continue
108+ }
109+
86110 logf ("Processing %v..." , filename )
87111 if err := t .processAST (f ); err != nil {
88112 log .Fatal (err )
@@ -116,8 +140,12 @@ func (t *templateData) processAST(f *ast.File) error {
116140 logf ("Struct %v is in skip list; skipping." , ts .Name )
117141 continue
118142 }
143+ if _ , ok := ts .Type .(* ast.Ident ); ok { // e.g. type SomeService service
144+ continue
145+ }
119146 st , ok := ts .Type .(* ast.StructType )
120147 if ! ok {
148+ logf ("Skipping TypeSpec of type %T" , ts .Type )
121149 continue
122150 }
123151 for _ , field := range st .Fields .List {
@@ -141,14 +169,21 @@ func (t *templateData) processAST(f *ast.File) error {
141169 if ! ok {
142170 switch x := field .Type .(type ) {
143171 case * ast.MapType :
172+ logf ("processAST: addMapType(x, %q, %q)" , ts .Name .String (), fieldName .String ())
144173 t .addMapType (x , ts .Name .String (), fieldName .String (), false )
145174 continue
146175 case * ast.ArrayType :
147- if key := fmt .Sprintf ("%v.%v" , ts .Name , fieldName ); whitelistSliceGetters [key ] {
148- logf ("Method %v is whitelist; adding getter method." , key )
149- t .addArrayType (x , ts .Name .String (), fieldName .String (), false )
150- continue
151- }
176+ logf ("processAST: addArrayType(x, %q, %q)" , ts .Name .String (), fieldName .String ())
177+ t .addArrayType (x , ts .Name .String (), fieldName .String (), false )
178+ continue
179+ case * ast.Ident :
180+ logf ("processAST: addSimpleValueIdent(x, %q, %q)" , ts .Name .String (), fieldName .String ())
181+ t .addSimpleValueIdent (x , ts .Name .String (), fieldName .String ())
182+ continue
183+ case * ast.SelectorExpr :
184+ logf ("processAST: addSimpleValueSelectorExpr(x, %q, %q)" , ts .Name .String (), fieldName .String ())
185+ t .addSimpleValueSelectorExpr (x , ts .Name .String (), fieldName .String ())
186+ continue
152187 }
153188
154189 logf ("Skipping field type %T, fieldName=%v" , field .Type , fieldName )
@@ -254,24 +289,62 @@ func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName st
254289 t .Getters = append (t .Getters , ng )
255290}
256291
292+ func (t * templateData ) addSimpleValueIdent (x * ast.Ident , receiverType , fieldName string ) {
293+ getter := genIdentGetter (x , receiverType , fieldName )
294+ getter .IsSimpleValue = true
295+ logf ("addSimpleValueIdent: Processing %q - fieldName=%q, getter.ZeroValue=%q, x.Obj=%#v" , x .String (), fieldName , getter .ZeroValue , x .Obj )
296+ if getter .ZeroValue == "nil" {
297+ if x .Obj == nil {
298+ switch x .String () {
299+ case "any" : // NOOP - leave as `nil`
300+ default :
301+ getter .ZeroValue = x .String () + "{}"
302+ }
303+ } else {
304+ if ts , ok := x .Obj .Decl .(* ast.TypeSpec ); ok {
305+ logf ("addSimpleValueIdent: Processing %q of type %T" , x .String (), ts .Type )
306+ switch xX := ts .Type .(type ) {
307+ case * ast.Ident :
308+ logf ("addSimpleValueIdent: Processing %q of type %T - zero value is %q" , x .String (), ts .Type , getter .ZeroValue )
309+ getter .ZeroValue = zeroValueOfIdent (xX )
310+ case * ast.StructType :
311+ getter .ZeroValue = x .String () + "{}"
312+ logf ("addSimpleValueIdent: Processing %q of type %T - zero value is %q" , x .String (), ts .Type , getter .ZeroValue )
313+ case * ast.InterfaceType , * ast.ArrayType : // NOOP - leave as `nil`
314+ logf ("addSimpleValueIdent: Processing %q of type %T - zero value is %q" , x .String (), ts .Type , getter .ZeroValue )
315+ default :
316+ log .Fatalf ("addSimpleValueIdent: unhandled case %T" , xX )
317+ }
318+ }
319+ }
320+ }
321+ t .Getters = append (t .Getters , getter )
322+ }
323+
257324func (t * templateData ) addIdent (x * ast.Ident , receiverType , fieldName string ) {
258- var zeroValue string
259- var namedStruct bool
325+ getter := genIdentGetter (x , receiverType , fieldName )
326+ t .Getters = append (t .Getters , getter )
327+ }
328+
329+ func zeroValueOfIdent (x * ast.Ident ) string {
260330 switch x .String () {
261- case "int" , "int64" :
262- zeroValue = "0"
331+ case "int" , "int64" , "float64" , "uint8" , "uint16" :
332+ return "0"
263333 case "string" :
264- zeroValue = `""`
334+ return `""`
265335 case "bool" :
266- zeroValue = "false"
336+ return "false"
267337 case "Timestamp" :
268- zeroValue = "Timestamp{}"
338+ return "Timestamp{}"
269339 default :
270- zeroValue = "nil"
271- namedStruct = true
340+ return "nil"
272341 }
342+ }
273343
274- t .Getters = append (t .Getters , newGetter (receiverType , fieldName , x .String (), zeroValue , namedStruct ))
344+ func genIdentGetter (x * ast.Ident , receiverType , fieldName string ) * getter {
345+ zeroValue := zeroValueOfIdent (x )
346+ namedStruct := zeroValue == "nil"
347+ return newGetter (receiverType , fieldName , x .String (), zeroValue , namedStruct )
275348}
276349
277350func (t * templateData ) addMapType (x * ast.MapType , receiverType , fieldName string , isAPointer bool ) {
@@ -300,10 +373,28 @@ func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string
300373 t .Getters = append (t .Getters , ng )
301374}
302375
376+ func (t * templateData ) addSimpleValueSelectorExpr (x * ast.SelectorExpr , receiverType , fieldName string ) {
377+ getter := t .genSelectorExprGetter (x , receiverType , fieldName )
378+ if getter == nil {
379+ return
380+ }
381+ getter .IsSimpleValue = true
382+ logf ("addSimpleValueSelectorExpr: Processing field name %q - %#v - zero value is %q" , fieldName , x , getter .ZeroValue )
383+ t .Getters = append (t .Getters , getter )
384+ }
385+
303386func (t * templateData ) addSelectorExpr (x * ast.SelectorExpr , receiverType , fieldName string ) {
304- if strings .ToLower (fieldName [:1 ]) == fieldName [:1 ] { // Non-exported field.
387+ getter := t .genSelectorExprGetter (x , receiverType , fieldName )
388+ if getter == nil {
305389 return
306390 }
391+ t .Getters = append (t .Getters , getter )
392+ }
393+
394+ func (t * templateData ) genSelectorExprGetter (x * ast.SelectorExpr , receiverType , fieldName string ) * getter {
395+ if strings .ToLower (fieldName [:1 ]) == fieldName [:1 ] { // Non-exported field.
396+ return nil
397+ }
307398
308399 var xX string
309400 if xx , ok := x .X .(* ast.Ident ); ok {
@@ -322,10 +413,12 @@ func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldN
322413 if xX == "time" && x .Sel .Name == "Duration" {
323414 zeroValue = "0"
324415 }
325- t . Getters = append ( t . Getters , newGetter (receiverType , fieldName , fieldType , zeroValue , false ) )
416+ return newGetter (receiverType , fieldName , fieldType , zeroValue , false )
326417 default :
327418 logf ("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping." , xX , receiverType , fieldName , x )
328419 }
420+
421+ return nil
329422}
330423
331424type templateData struct {
@@ -337,15 +430,16 @@ type templateData struct {
337430}
338431
339432type getter struct {
340- sortVal string // Lower-case version of "ReceiverType.FieldName".
341- ReceiverVar string // The one-letter variable name to match the ReceiverType.
342- ReceiverType string
343- FieldName string
344- FieldType string
345- ZeroValue string
346- NamedStruct bool // Getter for named struct.
347- MapType bool
348- ArrayType bool
433+ sortVal string // Lower-case version of "ReceiverType.FieldName".
434+ ReceiverVar string // The one-letter variable name to match the ReceiverType.
435+ ReceiverType string
436+ FieldName string
437+ FieldType string
438+ ZeroValue string
439+ NamedStruct bool // Getter for named struct.
440+ MapType bool
441+ ArrayType bool
442+ IsSimpleValue bool
349443}
350444
351445const source = `// Code generated by gen-accessors; DO NOT EDIT.
@@ -366,7 +460,15 @@ import (
366460)
367461{{end}}
368462{{range .Getters}}
369- {{if .NamedStruct}}
463+ {{if .IsSimpleValue}}
464+ // Get{{.FieldName}} returns the {{.FieldName}} field.
465+ func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
466+ if {{.ReceiverVar}} == nil {
467+ return {{.ZeroValue}}
468+ }
469+ return {{.ReceiverVar}}.{{.FieldName}}
470+ }
471+ {{else if .NamedStruct}}
370472// Get{{.FieldName}} returns the {{.FieldName}} field.
371473func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
372474 if {{.ReceiverVar}} == nil {
@@ -413,7 +515,15 @@ import (
413515)
414516{{end}}
415517{{range .Getters}}
416- {{if .NamedStruct}}
518+ {{if .IsSimpleValue}}
519+ func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
520+ tt.Parallel()
521+ {{.ReceiverVar}} := &{{.ReceiverType}}{}
522+ {{.ReceiverVar}}.Get{{.FieldName}}()
523+ {{.ReceiverVar}} = nil
524+ {{.ReceiverVar}}.Get{{.FieldName}}()
525+ }
526+ {{else if .NamedStruct}}
417527func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
418528 tt.Parallel()
419529 {{.ReceiverVar}} := &{{.ReceiverType}}{}
0 commit comments