Skip to content

Commit 0b37b64

Browse files
committed
fix inference logic
1 parent 70c2e1c commit 0b37b64

3 files changed

Lines changed: 162 additions & 1 deletion

File tree

lang/golang/parser/utils.go

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ import (
1818
"bufio"
1919
"bytes"
2020
"fmt"
21+
"github.com/cloudwego/abcoder/lang/log"
2122
"go/ast"
2223
"go/types"
2324
"io"
2425
"os"
26+
"os/exec"
2527
"path"
28+
"path/filepath"
2629
"regexp"
2730
"strings"
31+
"sync"
2832

2933
"github.com/Knetic/govaluate"
3034
. "github.com/cloudwego/abcoder/lang/uniast"
@@ -49,8 +53,97 @@ func (c cache) Visited(val interface{}) bool {
4953
return ok
5054
}
5155

56+
// PackageCache 缓存 importPath 是否是 system package
57+
type PackageCache struct {
58+
lock sync.RWMutex
59+
cache map[string]bool
60+
}
61+
62+
func NewPackageCache() *PackageCache {
63+
return &PackageCache{
64+
cache: make(map[string]bool),
65+
}
66+
}
67+
68+
var (
69+
gorootOnce sync.Once
70+
detectedGoRoot string
71+
gorootErr error
72+
)
73+
74+
// getGoRoot 获取 go root 环境变量。
75+
func getGoRoot() (string, error) {
76+
gorootOnce.Do(func() {
77+
cmd := exec.Command("go", "env", "GOROOT")
78+
var out bytes.Buffer
79+
var stderr bytes.Buffer
80+
cmd.Stdout = &out
81+
cmd.Stderr = &stderr
82+
err := cmd.Run()
83+
if err != nil {
84+
log.Info("'go env GOROOT' failed: %w, stderr: %s; \n `isSysPkg` will downgrade.", err, stderr.String())
85+
gorootErr = fmt.Errorf("'go env GOROOT' failed: %w, stderr: %s", err, stderr.String())
86+
return
87+
}
88+
89+
gorootPath := strings.TrimSpace(out.String())
90+
if gorootPath == "" {
91+
log.Info("'go env GOROOT' returns a empty string \n `isSysPkg` will downgrade.")
92+
gorootErr = fmt.Errorf("'go env GOROOT' returns a empty string")
93+
return
94+
}
95+
detectedGoRoot = gorootPath
96+
})
97+
return detectedGoRoot, gorootErr
98+
}
99+
100+
// IsStandardPackage 检查一个包是否为标准库,并使用内部缓存。
101+
func (pc *PackageCache) IsStandardPackage(path string) bool {
102+
103+
goRoot, err := getGoRoot()
104+
// 当前环境找不到 go root,退化到最简单判断
105+
if err != nil || goRoot == "" {
106+
return !strings.Contains(strings.Split(path, "/")[0], ".")
107+
}
108+
109+
pc.lock.RLock()
110+
isStd, found := pc.cache[path]
111+
pc.lock.RUnlock()
112+
113+
if found {
114+
return isStd
115+
}
116+
117+
pc.lock.Lock()
118+
defer pc.lock.Unlock()
119+
120+
isStd, found = pc.cache[path]
121+
if found {
122+
return isStd
123+
}
124+
125+
pkgPath := filepath.Join(goRoot, "src", path)
126+
stat, err := os.Stat(pkgPath)
127+
if err != nil {
128+
if os.IsNotExist(err) {
129+
isStd = false
130+
} else {
131+
log.Info("IsStandardPackage: failed to get file stat for %s: %v", pkgPath, err)
132+
return false
133+
}
134+
} else {
135+
isStd = stat.IsDir()
136+
}
137+
138+
pc.cache[path] = isStd
139+
140+
return isStd
141+
}
142+
143+
var stdlibCache = NewPackageCache()
144+
52145
func isSysPkg(importPath string) bool {
53-
return !strings.Contains(strings.Split(importPath, "/")[0], ".")
146+
return stdlibCache.IsStandardPackage(importPath)
54147
}
55148

56149
var (

lang/golang/parser/utils_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"go/token"
2222
"go/types"
2323
"slices"
24+
"sync"
2425
"testing"
2526

2627
"github.com/stretchr/testify/require"
@@ -195,3 +196,69 @@ var f func() (*http.Request, error)`,
195196
})
196197
}
197198
}
199+
200+
func resetGlobals() {
201+
// 重置 GOROOT 的 sync.Once 和其缓存的变量
202+
gorootOnce = sync.Once{}
203+
detectedGoRoot = ""
204+
gorootErr = nil
205+
206+
// 重置包缓存
207+
stdlibCache = NewPackageCache()
208+
}
209+
210+
func Test_isSysPkg(t *testing.T) {
211+
// 测试在 `go env GOROOT` 可以成功执行时的行为
212+
t.Run("Group: Happy Path - GOROOT is found", func(t *testing.T) {
213+
resetGlobals()
214+
215+
testCases := []struct {
216+
name string
217+
importPath string
218+
want bool
219+
}{
220+
{"standard library package", "fmt", true},
221+
{"nested standard library package", "net/http", true},
222+
{"third-party package", "github.com/google/uuid", false},
223+
{"extended library package", "golang.org/x/sync/errgroup", false},
224+
{"local-like package name", "myproject/utils", false},
225+
{"non-existent package", "non/existent/package", false},
226+
{"root-level package with dot", "gopkg.in/yaml.v2", false},
227+
}
228+
229+
for _, tc := range testCases {
230+
t.Run(tc.name, func(t *testing.T) {
231+
if got := isSysPkg(tc.importPath); got != tc.want {
232+
t.Errorf("isSysPkg(%q) = %v, want %v", tc.importPath, got, tc.want)
233+
}
234+
})
235+
}
236+
})
237+
238+
// 测试在 `go env GOROOT` 执行失败时的行为
239+
t.Run("Group: Fallback Path - GOROOT is not found", func(t *testing.T) {
240+
resetGlobals()
241+
242+
// 使用 t.Setenv 临时清空 PATH,使得 "go" 命令无法被找到
243+
t.Setenv("PATH", "")
244+
245+
testCases := []struct {
246+
name string
247+
importPath string
248+
want bool
249+
}{
250+
{"standard library package (fallback)", "fmt", true},
251+
{"nested standard library package (fallback)", "os/exec", true},
252+
{"third-party package (fallback)", "github.com/google/uuid", false},
253+
{"local-like package name (fallback)", "myproject/utils", true}, // 在降级模式下,被错误地判断为 true
254+
}
255+
256+
for _, tc := range testCases {
257+
t.Run(tc.name, func(t *testing.T) {
258+
if got := isSysPkg(tc.importPath); got != tc.want {
259+
t.Errorf("isSysPkg(%q) in fallback mode = %v, want %v", tc.importPath, got, tc.want)
260+
}
261+
})
262+
}
263+
})
264+
}

testdata/asts/localsession_g.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)