Skip to content

Commit e53d937

Browse files
committed
fix: associate methods (cloudwego#14)
1 parent 676e546 commit e53d937

4 files changed

Lines changed: 136 additions & 45 deletions

File tree

src/compress/golang/plugin/file.go

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,19 @@ func (p *goParser) inspectFile(ctx *fileContext, f *ast.File) (map[string]*Funct
7878
fileFuncs[f.Name] = f
7979
cont = ct
8080
} else if typDecl, ok := node.(*ast.TypeSpec); ok {
81-
// parse structs
8281
struName := typDecl.Name.Name
83-
struDecl, ok := typDecl.Type.(*ast.StructType)
84-
if ok {
85-
st, ct := p.parseStruct(ctx, struName, struDecl)
86-
fileStructs[struName] = st
87-
cont = ct
82+
var st *Struct
83+
var ct bool
84+
// parse structs
85+
if struDecl, ok := typDecl.Type.(*ast.StructType); ok {
86+
st, ct = p.parseStruct(ctx, struName, struDecl)
87+
// parse interface
88+
} else if interDecl, ok := typDecl.Type.(*ast.InterfaceType); ok {
89+
struName := typDecl.Name.Name
90+
st, ct = p.parseInterface(ctx, struName, interDecl)
8891
}
92+
fileStructs[struName] = st
93+
cont = ct
8994
}
9095
return cont
9196
})
@@ -175,6 +180,9 @@ func (p *goParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio
175180
var thirdPartyMethodCalls, thirdPartyFunctionCalls = map[string]*ThirdPartyIdentity{}, map[string]*ThirdPartyIdentity{}
176181
var functionCalls, methodCalls = map[string]*Function{}, map[string]*Function{}
177182

183+
if funcDecl.Body == nil {
184+
goto set_func
185+
}
178186
ast.Inspect(funcDecl.Body, func(node ast.Node) bool {
179187
// scope := ctx.pkgTypeInfo.Scopes[node]
180188
call, ok := node.(*ast.CallExpr)
@@ -230,12 +238,9 @@ func (p *goParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio
230238
mpkg := m.Pkg().Path()
231239
//NOTICE: use {structName.methodName} as method key
232240
mname := obj.Name() + "." + m.Name()
233-
f := p.getOrSetFunc(mpkg, mname)
234-
f.AssociatedStruct = p.getOrSetStruct(mpkg, obj.Name())
235-
f.IsMethod = true
236-
237241
if strings.HasPrefix(mpkg, p.modName) {
238242
// internal pkg
243+
f := p.getOrSetFunc(mpkg, mname)
239244
methodCalls[funcName] = f
240245
} else {
241246
// external pkg
@@ -253,6 +258,8 @@ func (p *goParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio
253258
}
254259
return true
255260
})
261+
262+
set_func:
256263
name := funcDecl.Name.Name
257264
if isMethod {
258265
name = associatedStruct.Name + "." + name
@@ -292,10 +299,11 @@ func (ctx *fileContext) IsSysImport(alias string) bool {
292299

293300
// Struct holds the information about a struct
294301
type Struct struct {
295-
Name string // Name of the struct
296-
PkgPath // Path to the package where the struct is defined
297-
FilePath string // File where the struct is defined
298-
Content string // struct declaration content
302+
IsInterface bool //maybe a interface type decl
303+
Name string // Name of the struct
304+
PkgPath // Path to the package where the struct is defined
305+
FilePath string // File where the struct is defined
306+
Content string // struct declaration content
299307

300308
// related local structs in fields, key is {{pkgName.typName}} or {{typeName}}, val is declaration of the struct
301309
InternalStructs map[string]*Struct
@@ -325,8 +333,7 @@ type fileContext struct {
325333

326334
// parse a ast.StructType node and renturn allocated *Struct
327335
func (p *goParser) parseStruct(ctx *fileContext, struName string, struDecl *ast.StructType) (*Struct, bool) {
328-
pkgPath := p.pkgPathFromABS(filepath.Dir(ctx.filePath))
329-
st := p.getOrSetStruct(pkgPath, struName)
336+
st := p.getOrSetStruct(ctx.pkgPath, struName)
330337
st.FilePath = ctx.filePath
331338

332339
pos := ctx.fset.PositionFor(struDecl.Pos(), false).Offset
@@ -347,7 +354,7 @@ func (p *goParser) parseStruct(ctx *fileContext, struName string, struDecl *ast.
347354
// TODO: combine all names
348355
name = fieldDecl.Names[0].String()
349356
} else {
350-
name = string(ctx.bs[fieldDecl.Type.Pos():fieldDecl.Type.End()])
357+
name = string(ctx.GetRawContent(fieldDecl))
351358
}
352359

353360
types := []ThirdPartyIdentity{}
@@ -374,7 +381,7 @@ func (p *goParser) parseStruct(ctx *fileContext, struName string, struDecl *ast.
374381
}
375382
} else {
376383
// local package
377-
sub := p.getOrSetStruct(pkgPath, ty.Identity)
384+
sub := p.getOrSetStruct(ctx.pkgPath, ty.Identity)
378385
inStructs[name] = sub
379386
}
380387

@@ -390,6 +397,52 @@ func (p *goParser) parseStruct(ctx *fileContext, struName string, struDecl *ast.
390397
return st, true
391398
}
392399

400+
func (ctx *fileContext) GetRawContent(node ast.Node) []byte {
401+
return ctx.bs[ctx.fset.Position(node.Pos()).Offset:ctx.fset.Position(node.End()).Offset]
402+
}
403+
404+
func (p *goParser) parseInterface(ctx *fileContext, name string, decl *ast.InterfaceType) (*Struct, bool) {
405+
if decl == nil || decl.Incomplete {
406+
return nil, true
407+
}
408+
409+
st := p.getOrSetStruct(ctx.pkgPath, name)
410+
st.FilePath = ctx.filePath
411+
st.IsInterface = true
412+
st.Content = string(ctx.GetRawContent(decl))
413+
414+
methods := map[string]*Function{}
415+
ast.Inspect(decl.Methods, func(n ast.Node) bool {
416+
fieldDecl, ok := n.(*ast.Field)
417+
if !ok {
418+
return true
419+
}
420+
fname := ""
421+
if len(fieldDecl.Names) > 0 {
422+
// TODO: combine all names
423+
fname = fieldDecl.Names[0].String()
424+
} else {
425+
fname = string(ctx.GetRawContent(fieldDecl.Type))
426+
}
427+
428+
types := []ThirdPartyIdentity{}
429+
isFunc := getTypeName(ctx.fset, ctx.bs, fieldDecl.Type, &types)
430+
if !isFunc {
431+
return true
432+
}
433+
434+
f := p.getOrSetFunc(ctx.pkgPath, name+"."+fname)
435+
f.IsMethod = true
436+
f.AssociatedStruct = st
437+
f.FilePath = ctx.filePath
438+
methods[fname] = f
439+
return true
440+
})
441+
442+
st.Methods = methods
443+
return st, true
444+
}
445+
393446
// handle typ expr and return not-builtin type identity and return if the type if a func signature.
394447
// ret is used to store results.
395448
func getTypeName(fset *token.FileSet, file []byte, typ ast.Expr, ret *[]ThirdPartyIdentity) bool {

src/compress/golang/plugin/go_ast.go

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,22 @@ func newGoParser(modName, homePageDir string) *goParser {
2828
if err != nil {
2929
panic(fmt.Sprintf("cannot get absolute path form homePageDir:%v", err))
3030
}
31-
return &goParser{
31+
32+
p := &goParser{
3233
modName: modName,
3334
homePageDir: abs,
3435
processedPkgFunctions: map[PkgPath]map[string]*Function{},
3536
processedPkgStruct: map[PkgPath]map[string]*Struct{},
3637
visited: map[string]bool{},
3738
}
39+
if p.modName == "" {
40+
var err error
41+
p.modName, err = getModuleName(p.homePageDir + "/go.mod")
42+
if err != nil {
43+
panic(err.Error())
44+
}
45+
}
46+
return p
3847
}
3948

4049
// ToABS converts a local package path to absolute path
@@ -92,21 +101,26 @@ func (p *goParser) associateStructWithMethods() {
92101
// ParseTilTheEnd parse the all go files from the starDir,
93102
// and their related go files in the project recursively
94103
func (p *goParser) ParseTilTheEnd(startDir string) error {
95-
if p.modName == "" {
96-
var err error
97-
p.modName, err = getModuleName(p.homePageDir + "/go.mod")
98-
if err != nil {
99-
return err
100-
}
101-
}
102104
if err := p.ParseDir(startDir); err != nil {
103105
return err
104106
}
105-
for _, pkg := range p.processedPkgFunctions {
107+
for path, pkg := range p.processedPkgFunctions {
108+
// ignore third-party packages
109+
if !strings.Contains(path, p.modName) {
110+
continue
111+
}
106112
for _, f := range pkg {
107113
// Notice: local funcs has been parsed in ParseDir
108114
for _, fc := range f.InternalFunctionCalls {
109-
if fc.FilePath != "" {
115+
if p.visited[fc.PkgPath] {
116+
continue
117+
}
118+
if err := p.ParseTilTheEnd(p.pkgPathToABS(fc.PkgPath)); err != nil {
119+
return err
120+
}
121+
}
122+
for _, fc := range f.InternalMethodCalls {
123+
if p.visited[fc.PkgPath] {
110124
continue
111125
}
112126
if err := p.ParseTilTheEnd(p.pkgPathToABS(fc.PkgPath)); err != nil {
@@ -164,7 +178,7 @@ type SingleStruct struct {
164178
Content string
165179
}
166180

167-
func (p *goParser) getMain() (*MainStream, *Function) {
181+
func (p *goParser) getMain(depth int) (*MainStream, *Function) {
168182
m := &MainStream{
169183
RelatedFunctions: make([]SingleFunction, 0),
170184
}
@@ -180,36 +194,61 @@ func (p *goParser) getMain() (*MainStream, *Function) {
180194
}
181195
}
182196
}
183-
p.fillRelatedContent(mainFunc, &m.RelatedFunctions, &m.RelatedStruct)
197+
visited := map[string]map[string]bool{}
198+
p.fillRelatedContent(depth, mainFunc, &m.RelatedFunctions, &m.RelatedStruct, visited)
184199
return m, mainFunc
185200
}
186201

187-
func (p *goParser) fillRelatedContent(f *Function, fl *[]SingleFunction, sl *[]SingleStruct) {
202+
func (p *goParser) fillRelatedContent(depth int, f *Function, fl *[]SingleFunction, sl *[]SingleStruct, visited map[string]map[string]bool) {
203+
if depth == 0 {
204+
return
205+
}
206+
if f == nil || (visited[f.PkgPath] != nil && visited[f.PkgPath][f.Name]) {
207+
return
208+
} else {
209+
if visited[f.PkgPath] == nil {
210+
visited[f.PkgPath] = map[string]bool{}
211+
}
212+
visited[f.PkgPath][f.Name] = true
213+
}
188214
for call, ff := range f.InternalFunctionCalls {
189215
s := SingleFunction{
190216
CallName: call,
191217
Content: ff.Content,
192218
}
193219
*fl = append(*fl, s)
194-
p.fillRelatedContent(ff, fl, sl)
220+
p.fillRelatedContent(depth-1, ff, fl, sl, visited)
195221
}
196222

197223
for call, ff := range f.InternalMethodCalls {
224+
content := ff.Content
225+
if ff.AssociatedStruct != nil && ff.AssociatedStruct.IsInterface {
226+
content = ff.AssociatedStruct.Content
227+
}
198228
s := SingleFunction{
199229
CallName: call,
200-
Content: ff.Content,
230+
Content: content,
201231
}
202232
*fl = append(*fl, s)
203-
p.fillRelatedContent(ff, fl, sl)
233+
p.fillRelatedContent(depth-1, ff, fl, sl, visited)
234+
204235
// for method which has been associated with struct, push the struct
205236
if ff.AssociatedStruct != nil && ff.AssociatedStruct.Content != "" {
237+
st := ff.AssociatedStruct
238+
if visited[st.PkgPath] != nil && visited[st.PkgPath][st.Name] {
239+
continue
240+
} else if visited[st.PkgPath] == nil {
241+
visited[st.PkgPath] = map[string]bool{}
242+
}
243+
visited[st.PkgPath][st.Name] = true
206244
ss := SingleStruct{
207-
Name: ff.PkgPath + "." + ff.AssociatedStruct.Name,
208-
Content: ff.AssociatedStruct.Content,
245+
Name: ff.PkgPath + "." + st.Name,
246+
Content: st.Content,
209247
}
210248
*sl = append(*sl, ss)
211249
}
212-
p.fillRelatedContent(ff, fl, sl)
250+
251+
p.fillRelatedContent(depth-1, ff, fl, sl, visited)
213252
}
214253
}
215254

@@ -255,7 +294,7 @@ func main() {
255294
}
256295

257296
// p.generateStruct()
258-
m, _ := p.getMain()
297+
m, _ := p.getMain(100)
259298
m.Dedup()
260299

261300
out := bytes.NewBuffer(nil)

src/compress/golang/plugin/go_ast_test.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package main
33
import (
44
"encoding/json"
55
"testing"
6-
7-
"github.com/davecgh/go-spew/spew"
86
)
97

108
func Test_goParser_ParseTilTheEnd(t *testing.T) {
@@ -23,10 +21,10 @@ func Test_goParser_ParseTilTheEnd(t *testing.T) {
2321
{
2422
name: "test",
2523
fields: fields{
26-
homePageDir: "../../../../testdata/golang",
24+
homePageDir: "/Users/bytedance/GOPATH/work/hertz",
2725
},
2826
args: args{
29-
startDir: "./cmd",
27+
startDir: "cmd/hz",
3028
},
3129
},
3230
}
@@ -37,16 +35,16 @@ func Test_goParser_ParseTilTheEnd(t *testing.T) {
3735
if err != nil {
3836
t.Fatalf("goParser.ParseTilTheEnd() error = %v", err)
3937
}
40-
spew.Dump(p)
41-
out, fun := p.getMain()
38+
// spew.Dump(p)
39+
out, fun := p.getMain(2)
4240
if fun.Name != "main" {
4341
t.Fail()
4442
}
4543
out.Dedup()
4644
if out, err := json.MarshalIndent(out, "", " "); err != nil {
4745
t.Fatalf("json.Marshal() error = %v", err)
4846
} else {
49-
println(string(out))
47+
println("size:", len(out))
5048
}
5149
})
5250
}

src/compress/golang/plugin/pkg.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func (p *goParser) ParseDir(dir string) (err error) {
3434
if p.visited[pkgPath] {
3535
return nil
3636
}
37+
p.visited[pkgPath] = true
3738

3839
// slow-path: load packages in the dir, including sub pakcages
3940
fset := token.NewFileSet()

0 commit comments

Comments
 (0)