Skip to content

Commit 00ab0c4

Browse files
committed
feat: Update GCP Batch support
1 parent b433b6a commit 00ab0c4

2 files changed

Lines changed: 156 additions & 23 deletions

File tree

compute/gcp_batch/backend.go

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func (b *Backend) Submit(task *tes.Task) error {
198198
}
199199
}
200200

201-
// Mount all buckets to `/mnt/share/<BUCKET>` as volumes in the GCP Job Request
201+
// Mount all buckets to `/mnt/disks/<BUCKET>` as volumes in the GCP Job Request
202202
var volumes []*batchpb.Volume
203203
for bucketName := range buckets {
204204
volumes = append(volumes, &batchpb.Volume{
@@ -211,23 +211,83 @@ func (b *Backend) Submit(task *tes.Task) error {
211211
})
212212
}
213213

214+
// Build a path map: user-specified path → /mnt/disks/<bucket>/<object> so
215+
// that executor commands referencing those paths are rewritten before
216+
// submission. This avoids symlinks, which are unreliable across containers
217+
// on COS (Container-Optimized OS) VMs where each container has an isolated
218+
// filesystem.
219+
if err := detectPathCollisions(task.Inputs, task.Outputs); err != nil {
220+
return fmt.Errorf("GCP Batch path collision: %w", err)
221+
}
222+
223+
pathMap := make(map[string]string) // userPath → mountedPath
224+
for _, input := range task.Inputs {
225+
if input.Path == "" || input.Url == "" {
226+
continue
227+
}
228+
if err := validatePath(input.Path); err != nil {
229+
return fmt.Errorf("invalid input path: %w", err)
230+
}
231+
bucket, objectPath := extractGCSPath(input.Url)
232+
if bucket == "" {
233+
continue
234+
}
235+
pathMap[input.Path] = fmt.Sprintf("/mnt/disks/%s/%s", bucket, objectPath)
236+
}
237+
for _, output := range task.Outputs {
238+
if output.Path == "" || output.Url == "" {
239+
continue
240+
}
241+
if err := validatePath(output.Path); err != nil {
242+
return fmt.Errorf("invalid output path: %w", err)
243+
}
244+
bucket, objectPath := extractGCSPath(output.Url)
245+
if bucket == "" {
246+
continue
247+
}
248+
pathMap[output.Path] = fmt.Sprintf("/mnt/disks/%s/%s", bucket, objectPath)
249+
}
250+
251+
// rewriteArg replaces all occurrences of known user paths within a string
252+
// with their /mnt/disks/... equivalents. This handles both standalone path
253+
// arguments and paths embedded inside shell script strings.
254+
rewriteArg := func(s string) string {
255+
for userPath, mountedPath := range pathMap {
256+
s = strings.ReplaceAll(s, userPath, mountedPath)
257+
}
258+
return s
259+
}
260+
214261
// Runnables
215262
var runnables []*batchpb.Runnable
216263

217264
for _, executor := range task.Executors {
218-
cmd := strings.Join(executor.Command, " ")
265+
var commands []string
266+
for _, arg := range executor.Command {
267+
commands = append(commands, rewriteArg(arg))
268+
}
219269

220-
if executor.Stdout != "" {
221-
// Redirect command output to the specified file path
222-
cmd = fmt.Sprintf("%s | tee %s", cmd, executor.Stdout)
270+
// Wrap in a shell only when stdout/stdin/stderr redirection is needed.
271+
if executor.Stdout != "" || executor.Stdin != "" || executor.Stderr != "" {
272+
cmd := strings.Join(commands, " ")
273+
if executor.Stdout != "" {
274+
cmd = fmt.Sprintf("%s | tee %s", cmd, rewriteArg(executor.Stdout))
275+
}
276+
commands = []string{"sh", "-c", cmd}
277+
}
278+
279+
container := &batchpb.Runnable_Container{
280+
ImageUri: executor.Image,
281+
Commands: commands,
282+
}
283+
284+
if executor.Workdir != "" {
285+
container.Options = fmt.Sprintf("--workdir %s", executor.Workdir)
223286
}
224287

225288
runnable := &batchpb.Runnable{
226289
Executable: &batchpb.Runnable_Container_{
227-
Container: &batchpb.Runnable_Container{
228-
ImageUri: executor.Image,
229-
Commands: []string{"sh", "-c", cmd},
230-
},
290+
Container: container,
231291
},
232292
}
233293

compute/gcp_batch/backend_test.go

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,15 @@ func TestSubmit_MultipleInputsOutputs(t *testing.T) {
256256
t.Errorf("Expected 3 volumes, got %d", len(volumes))
257257
}
258258

259-
// Verify symlink commands are present in the generated command
259+
// Paths are rewritten directly in executor commands — no separate setup runnable.
260260
runnables := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables
261261
if len(runnables) != 1 {
262-
t.Fatalf("Expected 1 runnable, got %d", len(runnables))
262+
t.Fatalf("Expected 1 runnable (executor only), got %d", len(runnables))
263+
}
264+
// The executor command ["echo", "test"] doesn't reference any I/O paths, so
265+
// verify correctness via the volumes instead — all 3 buckets must be mounted.
266+
if len(volumes) != 3 {
267+
t.Errorf("Expected 3 volumes for 3 unique buckets, got %d", len(volumes))
263268
}
264269
}
265270

@@ -311,10 +316,24 @@ func TestSubmit_MultipleExecutors(t *testing.T) {
311316
t.Fatalf("Submit() error = %v", err)
312317
}
313318

314-
// Should create 2 runnables, one per executor
319+
// Should create 1 runnable per executor; paths are rewritten inline, no setup runnable.
315320
runnables := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables
316321
if len(runnables) != 2 {
317-
t.Fatalf("Expected 2 runnables, got %d", len(runnables))
322+
t.Fatalf("Expected 2 runnables (1 per executor), got %d", len(runnables))
323+
}
324+
// Both executors reference /data/input.txt which maps to gs://bucket/input.txt.
325+
for i, r := range runnables {
326+
cmds := r.GetContainer().Commands
327+
found := false
328+
for _, c := range cmds {
329+
if strings.Contains(c, "/mnt/disks/bucket/input.txt") {
330+
found = true
331+
break
332+
}
333+
}
334+
if !found {
335+
t.Errorf("Runnable %d: expected rewritten path /mnt/disks/bucket/input.txt in commands %v", i, cmds)
336+
}
318337
}
319338
}
320339

@@ -407,10 +426,63 @@ func TestSubmit_NoInputsOutputs(t *testing.T) {
407426
t.Errorf("Expected 0 volumes, got %d", len(volumes))
408427
}
409428

410-
// Command should still work, just no symlinks
411-
cmd := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables[0].GetContainer().Commands[2]
412-
if !strings.Contains(cmd, "echo hello") {
413-
t.Error("Executor command not present")
429+
// Command should be passed directly (no sh -c wrapping when no redirection needed)
430+
cmds := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables[0].GetContainer().Commands
431+
if len(cmds) != 2 || cmds[0] != "echo" || cmds[1] != "hello" {
432+
t.Errorf("Expected direct command [echo hello], got %v", cmds)
433+
}
434+
}
435+
436+
// Test Submit with executor Workdir sets --workdir docker option
437+
func TestSubmit_ExecutorWorkdir(t *testing.T) {
438+
log := logger.NewLogger("test", logger.DefaultConfig())
439+
conf := &config.GCPBatch{
440+
Project: "test-project",
441+
Location: "us-west1",
442+
}
443+
444+
var capturedReq *batchpb.CreateJobRequest
445+
mockClient := &mockClient{
446+
CreateJobFunc: func(req *batchpb.CreateJobRequest) (*batchpb.Job, error) {
447+
capturedReq = req
448+
return &batchpb.Job{Name: "test-job", Uid: "test-uid"}, nil
449+
},
450+
}
451+
452+
backend := &Backend{
453+
client: mockClient,
454+
conf: conf,
455+
log: log,
456+
event: &noopEventWriter{},
457+
}
458+
459+
task := &tes.Task{
460+
Id: "task1",
461+
Executors: []*tes.Executor{
462+
{Image: "alpine", Command: []string{"echo", "test"}, Workdir: "/work"},
463+
{Image: "alpine", Command: []string{"echo", "no-workdir"}},
464+
},
465+
}
466+
467+
err := backend.Submit(task)
468+
if err != nil {
469+
t.Fatalf("Submit() error = %v", err)
470+
}
471+
472+
runnables := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables
473+
// No inputs/outputs so no setup runnable; just the 2 executor runnables.
474+
if len(runnables) != 2 {
475+
t.Fatalf("Expected 2 runnables, got %d", len(runnables))
476+
}
477+
478+
opts0 := runnables[0].GetContainer().Options
479+
if !strings.Contains(opts0, "--workdir") || !strings.Contains(opts0, "/work") {
480+
t.Errorf("Expected --workdir /work in Options, got %q", opts0)
481+
}
482+
483+
opts1 := runnables[1].GetContainer().Options
484+
if opts1 != "" {
485+
t.Errorf("Expected empty Options for executor without Workdir, got %q", opts1)
414486
}
415487
}
416488

@@ -482,11 +554,12 @@ func TestSubmit_CommandConstruction(t *testing.T) {
482554
t.Fatalf("Submit failed: %v", err)
483555
}
484556

485-
// Verify the command was properly quoted
486-
cmd := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables[0].GetContainer().Commands[2]
487-
488-
// Should contain properly escaped quotes, not broken by spaces
489-
if !strings.Contains(cmd, "python -c") {
490-
t.Errorf("Command should contain 'python -c', got: %s", cmd)
557+
// Commands are passed directly without shell wrapping, so the original args are preserved.
558+
cmds := capturedReq.Job.TaskGroups[0].TaskSpec.Runnables[0].GetContainer().Commands
559+
if len(cmds) != 3 || cmds[0] != "python" || cmds[1] != "-c" {
560+
t.Errorf("Expected direct command [python -c <script>], got %v", cmds)
561+
}
562+
if !strings.Contains(cmds[2], "Hello World") {
563+
t.Errorf("Expected script body in Commands[2], got: %s", cmds[2])
491564
}
492565
}

0 commit comments

Comments
 (0)