Skip to content

Commit bee6408

Browse files
committed
implement isSysPkg by package
1 parent 1fd40e0 commit bee6408

2 files changed

Lines changed: 9 additions & 77 deletions

File tree

lang/golang/parser/utils.go

Lines changed: 7 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ import (
1919
"bytes"
2020
"container/list"
2121
"fmt"
22-
"github.com/cloudwego/abcoder/lang/log"
2322
"go/ast"
23+
"go/build"
2424
"go/types"
2525
"io"
2626
"os"
27-
"os/exec"
2827
"path"
29-
"path/filepath"
3028
"regexp"
3129
"strings"
3230
"sync"
@@ -109,55 +107,20 @@ func (pc *PackageCache) Set(key string, value bool) {
109107
pc.cache[key] = elem
110108
}
111109

112-
var (
113-
gorootOnce sync.Once
114-
detectedGoRoot string
115-
gorootErr error
116-
)
117-
118-
// getGoRoot 获取 go root 环境变量。
119-
func getGoRoot() (string, error) {
120-
gorootOnce.Do(func() {
121-
cmd := exec.Command("go", "env", "GOROOT")
122-
var out bytes.Buffer
123-
var stderr bytes.Buffer
124-
cmd.Stdout = &out
125-
cmd.Stderr = &stderr
126-
err := cmd.Run()
127-
if err != nil {
128-
log.Info("'go env GOROOT' failed: %v, stderr: %s; \n `isSysPkg` will downgrade.", err, stderr.String())
129-
gorootErr = fmt.Errorf("'go env GOROOT' failed: %w, stderr: %s", err, stderr.String())
130-
return
131-
}
132-
133-
gorootPath := strings.TrimSpace(out.String())
134-
if gorootPath == "" {
135-
log.Info("'go env GOROOT' returns a empty string \n `isSysPkg` will downgrade.")
136-
gorootErr = fmt.Errorf("'go env GOROOT' returns a empty string")
137-
return
138-
}
139-
detectedGoRoot = gorootPath
140-
})
141-
return detectedGoRoot, gorootErr
142-
}
143-
144110
// IsStandardPackage 检查一个包是否为标准库,并使用内部缓存。
145111
func (pc *PackageCache) IsStandardPackage(path string) bool {
146112
if isStd, found := pc.Get(path); found {
147113
return isStd
148114
}
149115

150-
goRoot, err := getGoRoot()
151-
// 当前环境找不到 go root,退化到最简单判断
152-
var isStd bool
153-
if err != nil || goRoot == "" {
154-
isStd = !strings.Contains(strings.Split(path, "/")[0], ".")
155-
} else {
156-
pkgPath := filepath.Join(goRoot, "src", path)
157-
_, err = os.Stat(pkgPath)
158-
isStd = !os.IsNotExist(err)
116+
pkg, err := build.Import(path, "", build.FindOnly)
117+
if err != nil {
118+
// Cannot find the package, assume it's not a standard package
119+
pc.Set(path, false)
120+
return false
159121
}
160122

123+
isStd := pkg.Goroot
161124
pc.Set(path, isStd)
162125
return isStd
163126
}

lang/golang/parser/utils_test.go

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package parser
1616

1717
import (
18-
"github.com/stretchr/testify/assert"
1918
"go/ast"
2019
"go/importer"
2120
"go/parser"
@@ -25,6 +24,8 @@ import (
2524
"sync"
2625
"testing"
2726

27+
"github.com/stretchr/testify/assert"
28+
2829
"github.com/stretchr/testify/require"
2930
)
3031

@@ -199,11 +200,6 @@ var f func() (*http.Request, error)`,
199200
}
200201

201202
func resetGlobals() {
202-
// 重置 GOROOT 的 sync.Once 和其缓存的变量
203-
gorootOnce = sync.Once{}
204-
detectedGoRoot = ""
205-
gorootErr = nil
206-
207203
// 重置包缓存
208204
stdlibCache = NewPackageCache(10000)
209205
}
@@ -236,33 +232,6 @@ func Test_isSysPkg(t *testing.T) {
236232
}
237233
})
238234

239-
// 测试在 `go env GOROOT` 执行失败时的行为
240-
t.Run("Group: Fallback Path - GOROOT is not found", func(t *testing.T) {
241-
resetGlobals()
242-
243-
// 使用 t.Setenv 临时清空 PATH,使得 "go" 命令无法被找到
244-
t.Setenv("PATH", "")
245-
246-
testCases := []struct {
247-
name string
248-
importPath string
249-
want bool
250-
}{
251-
{"standard library package (fallback)", "fmt", true},
252-
{"nested standard library package (fallback)", "os/exec", true},
253-
{"third-party package (fallback)", "github.com/google/uuid", false},
254-
{"local-like package name (fallback)", "myproject/utils", true}, // 在降级模式下,被错误地判断为 true
255-
}
256-
257-
for _, tc := range testCases {
258-
t.Run(tc.name, func(t *testing.T) {
259-
if got := isSysPkg(tc.importPath); got != tc.want {
260-
t.Errorf("isSysPkg(%q) in fallback mode = %v, want %v", tc.importPath, got, tc.want)
261-
}
262-
})
263-
}
264-
})
265-
266235
// 测试并发调用时的行为
267236
t.Run("Group: Concurrency Test", func(t *testing.T) {
268237
resetGlobals()

0 commit comments

Comments
 (0)