@@ -32,6 +32,23 @@ var ErrInvalidVersionQueryChar = errors.New("invalid character in version query"
3232// ErrUnexpectedGoVersionOutput is returned when go version output has unexpected format.
3333var ErrUnexpectedGoVersionOutput = errors .New ("unexpected go version output" )
3434
35+ // commander is the minimal surface of *exec.Cmd that this package uses.
36+ // It exists so tests can swap out commandContext below with a fake.
37+ type commander interface {
38+ CombinedOutput () ([]byte , error )
39+ }
40+
41+ // commandContext builds a commander for the given dir/name/args. Overridable in tests.
42+ // An empty dir leaves cmd.Dir unset (inheriting the parent process working directory),
43+ // matching the behavior of callers that don't pin a working directory.
44+ var commandContext = func (ctx context.Context , dir , name string , args ... string ) commander {
45+ cmd := exec .CommandContext (ctx , name , args ... ) //nolint:gosec
46+ if dir != "" {
47+ cmd .Dir = dir
48+ }
49+ return cmd
50+ }
51+
3552// validateModulePath validates a Go module path to prevent injection attacks.
3653// Uses module.CheckPath() from golang.org/x/mod/module to ensure the path is valid.
3754func validateModulePath (path string ) error {
@@ -72,10 +89,7 @@ func GoModTidy(ctx context.Context, modroot, _ string, compat string) (string, e
7289 if compat != "" {
7390 args = append (args , "-compat" , compat )
7491 }
75-
76- cmd := exec .CommandContext (ctx , "go" , args ... )
77- cmd .Dir = modroot
78- if bytes , err := cmd .CombinedOutput (); err != nil {
92+ if bytes , err := commandContext (ctx , modroot , "go" , args ... ).CombinedOutput (); err != nil {
7993 return strings .TrimSpace (string (bytes )), err
8094 }
8195 return "" , nil
@@ -127,8 +141,7 @@ func UpdateGoWorkVersion(ctx context.Context, modroot string, forceWork bool, go
127141
128142 // Get Go version from environment if not provided
129143 if goVersion == "" {
130- cmd := exec .CommandContext (ctx , "go" , "version" )
131- output , err := cmd .CombinedOutput ()
144+ output , err := commandContext (ctx , "" , "go" , "version" ).CombinedOutput ()
132145 if err != nil {
133146 return fmt .Errorf ("failed to get Go version: %w, output: %s" , err , strings .TrimSpace (string (output )))
134147 }
@@ -174,9 +187,7 @@ func UpdateGoWorkVersion(ctx context.Context, modroot string, forceWork bool, go
174187
175188 dir := filepath .Dir (workPath )
176189 // Safe: goVersion is either auto-detected from runtime.Version() or user-provided version string (e.g., "1.21")
177- cmd := exec .CommandContext (ctx , "go" , "work" , "edit" , "-go" , goVersion ) //nolint:gosec // G204: goVersion is a version string, not user-controlled path
178- cmd .Dir = dir
179- if bytes , err := cmd .CombinedOutput (); err != nil {
190+ if bytes , err := commandContext (ctx , dir , "go" , "work" , "edit" , "-go" , goVersion ).CombinedOutput (); err != nil {
180191 return fmt .Errorf ("failed to update go.work version: %w, output: %s" , err , strings .TrimSpace (string (bytes )))
181192 }
182193
@@ -187,15 +198,11 @@ func UpdateGoWorkVersion(ctx context.Context, modroot string, forceWork bool, go
187198func GoVendor (ctx context.Context , dir string , forceWork bool ) (string , error ) {
188199 workPath := findGoWork (dir )
189200 if forceWork || workPath != "" {
190- cmd := exec .CommandContext (ctx , "go" , "work" , "vendor" )
191- cmd .Dir = dir
192- if bytes , err := cmd .CombinedOutput (); err != nil {
201+ if bytes , err := commandContext (ctx , dir , "go" , "work" , "vendor" ).CombinedOutput (); err != nil {
193202 return strings .TrimSpace (string (bytes )), err
194203 }
195204 } else {
196- cmd := exec .CommandContext (ctx , "go" , "mod" , "vendor" )
197- cmd .Dir = dir
198- if bytes , err := cmd .CombinedOutput (); err != nil {
205+ if bytes , err := commandContext (ctx , dir , "go" , "mod" , "vendor" ).CombinedOutput (); err != nil {
199206 return strings .TrimSpace (string (bytes )), err
200207 }
201208 }
@@ -213,9 +220,7 @@ func GoGetModule(ctx context.Context, name, version, modroot string) (string, er
213220 if err := validateVersionQuery (version ); err != nil {
214221 return "" , err
215222 }
216- cmd := exec .CommandContext (ctx , "go" , "get" , fmt .Sprintf ("%s@%s" , name , version )) //nolint:gosec
217- cmd .Dir = modroot
218- if bytes , err := cmd .CombinedOutput (); err != nil {
223+ if bytes , err := commandContext (ctx , modroot , "go" , "get" , fmt .Sprintf ("%s@%s" , name , version )).CombinedOutput (); err != nil {
219224 return strings .TrimSpace (string (bytes )), err
220225 }
221226 return "" , nil
@@ -235,15 +240,11 @@ func GoModEditReplaceModule(ctx context.Context, nameOld, nameNew, version, modr
235240 return "" , fmt .Errorf ("invalid version: %w" , err )
236241 }
237242
238- cmd := exec .CommandContext (ctx , "go" , "mod" , "edit" , "-dropreplace" , nameOld ) //nolint:gosec
239- cmd .Dir = modroot
240- if bytes , err := cmd .CombinedOutput (); err != nil {
243+ if bytes , err := commandContext (ctx , modroot , "go" , "mod" , "edit" , "-dropreplace" , nameOld ).CombinedOutput (); err != nil {
241244 return strings .TrimSpace (string (bytes )), fmt .Errorf ("error running go command to drop replace modules: %w" , err )
242245 }
243246
244- cmd = exec .CommandContext (ctx , "go" , "mod" , "edit" , "-replace" , fmt .Sprintf ("%s=%s@%s" , nameOld , nameNew , version )) //nolint:gosec
245- cmd .Dir = modroot
246- if bytes , err := cmd .CombinedOutput (); err != nil {
247+ if bytes , err := commandContext (ctx , modroot , "go" , "mod" , "edit" , "-replace" , fmt .Sprintf ("%s=%s@%s" , nameOld , nameNew , version )).CombinedOutput (); err != nil {
247248 return strings .TrimSpace (string (bytes )), fmt .Errorf ("error running go command to replace modules: %w" , err )
248249 }
249250 return "" , nil
@@ -256,9 +257,7 @@ func GoModEditDropRequireModule(ctx context.Context, name, modroot string) (stri
256257 return "" , err
257258 }
258259 // Safe: module path validated above
259- cmd := exec .CommandContext (ctx , "go" , "mod" , "edit" , "-droprequire" , name ) //nolint:gosec
260- cmd .Dir = modroot
261- if bytes , err := cmd .CombinedOutput (); err != nil {
260+ if bytes , err := commandContext (ctx , modroot , "go" , "mod" , "edit" , "-droprequire" , name ).CombinedOutput (); err != nil {
262261 return strings .TrimSpace (string (bytes )), err
263262 }
264263
0 commit comments