Skip to content

Commit a595f27

Browse files
Add additional type checks
1 parent f5a092c commit a595f27

4 files changed

Lines changed: 54 additions & 15 deletions

File tree

tilebox-workflows/tests/test_task.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def test_merge_future_tasks_to_submissions() -> None:
428428
tasks_3 = context.submit_subtasks([TaskA(6, "six"), TaskB(8.12)], cluster="other")
429429

430430
submissions = merge_future_tasks_to_submissions(tasks_1 + tasks_2 + tasks_3, fallback_cluster="test")
431+
assert submissions is not None
431432
assert len(submissions.task_groups) == 1
432433
group = submissions.task_groups[0]
433434
assert group.dependencies_on_other_groups == []
@@ -465,6 +466,7 @@ def test_merge_future_tasks_to_submissions_dependencies() -> None:
465466
# tasks_1, tasks_2 and tasks_4 should not be merged, because they have different dependants
466467
# tasks_3 and tasks_5 should not be merged, because they have different dependencies
467468

469+
assert submissions is not None
468470
assert len(submissions.task_groups) == 5
469471
assert submissions.task_groups[0].dependencies_on_other_groups == []
470472
assert submissions.task_groups[0].inputs == [serialize_task(TaskA(2, "two")), serialize_task(TaskA(3, "three"))]
@@ -485,6 +487,7 @@ def test_merge_future_tasks_to_submissions_many_tasks() -> None:
485487
tasks_2 = context.submit_subtasks([TaskB(i / 3) for i in range(n)], depends_on=tasks_1)
486488

487489
submissions = merge_future_tasks_to_submissions(tasks_1 + tasks_2, fallback_cluster="test")
490+
assert submissions is not None
488491
assert len(submissions.task_groups) == 2
489492
assert submissions.task_groups[0].dependencies_on_other_groups == []
490493
assert submissions.task_groups[0].identifier_pointers == [0] * n
@@ -500,4 +503,26 @@ def test_merge_future_tasks_to_submissions_many_non_mergeable_dependency_groups(
500503
context.submit_subtasks([TaskB(i / 3)], depends_on=task_1)
501504

502505
submissions = merge_future_tasks_to_submissions(context._sub_tasks, fallback_cluster="test")
506+
assert submissions is not None
503507
assert len(submissions.task_groups) == 2 * n
508+
509+
510+
def test_merge_future_tasks_two_separate_branches() -> None:
511+
context = RunnerExecutionContext(None, None, job_cache=InMemoryCache()) # type: ignore[arg-type]
512+
task_a = context.submit_subtasks([TaskA(0, "Task 0")])
513+
# left branch
514+
task_b_left = context.submit_subtasks([TaskB(0.0)], depends_on=task_a)
515+
context.submit_subtasks([TaskB(1.0)], depends_on=task_b_left)
516+
517+
# right branch
518+
task_b_right = context.submit_subtasks([TaskB(2.0)], depends_on=task_a)
519+
context.submit_subtasks([TaskB(3.0)], depends_on=task_b_right)
520+
521+
submissions = merge_future_tasks_to_submissions(context._sub_tasks, fallback_cluster="test")
522+
assert submissions is not None
523+
assert len(submissions.task_groups) == 5
524+
assert submissions.task_groups[0].dependencies_on_other_groups == []
525+
assert submissions.task_groups[1].dependencies_on_other_groups == [0]
526+
assert submissions.task_groups[2].dependencies_on_other_groups == [1]
527+
assert submissions.task_groups[3].dependencies_on_other_groups == [0]
528+
assert submissions.task_groups[4].dependencies_on_other_groups == [3]

tilebox-workflows/tilebox/workflows/jobs/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ def submit(
7878
)
7979

8080
task_submissions = [FutureTask(i, task, [], slugs[i], max_retries) for i, task in enumerate(tasks)]
81+
submissions_merged = merge_future_tasks_to_submissions(task_submissions, default_cluster)
82+
if submissions_merged is None:
83+
raise ValueError("At least one task must be submitted.")
8184

8285
with self._tracer.start_as_current_span(f"job/{job_name}"):
8386
trace_parent = get_trace_parent_of_current_span()
84-
return self._service.submit(
85-
job_name, trace_parent, merge_future_tasks_to_submissions(task_submissions, default_cluster)
86-
)
87+
return self._service.submit(job_name, trace_parent, submissions_merged)
8788

8889
def retry(self, job_or_id: JobIDLike) -> int:
8990
"""Retry a job.

tilebox-workflows/tilebox/workflows/runner/task_runner.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,20 @@ def __init__(self, runner: TaskRunner, task: Task, job_cache: JobCache) -> None:
526526
def submit_subtask(
527527
self,
528528
task: TaskInstance,
529-
depends_on: list[FutureTask] | None = None,
529+
depends_on: FutureTask | list[FutureTask] | None = None,
530530
cluster: str | None = None,
531531
max_retries: int = 0,
532532
) -> FutureTask:
533533
dependencies: list[int] = []
534-
for dep in depends_on or []:
534+
535+
if depends_on is None:
536+
depends_on = []
537+
elif isinstance(depends_on, FutureTask):
538+
depends_on = [depends_on]
539+
elif not isinstance(depends_on, list):
540+
raise TypeError(f"Invalid dependency. Expected FutureTask or list[FutureTask], got {type(depends_on)}")
541+
542+
for dep in depends_on:
535543
if not isinstance(dep, FutureTask):
536544
raise TypeError(f"Invalid dependency. Expected FutureTask, got {type(dep)}")
537545
if dep.index >= len(self._sub_tasks):
@@ -553,7 +561,7 @@ def submit_subtasks(
553561
tasks: Sequence[TaskInstance],
554562
cluster: str | None = None,
555563
max_retries: int = 0,
556-
depends_on: list[FutureTask] | None = None,
564+
depends_on: FutureTask | list[FutureTask] | None = None,
557565
) -> list[FutureTask]:
558566
return [
559567
self.submit_subtask(task, cluster=cluster, max_retries=max_retries, depends_on=depends_on) for task in tasks

tilebox-workflows/tilebox/workflows/task.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,10 @@ def append_if_unique(self, value: _T) -> int:
265265
return index
266266

267267

268-
def merge_future_tasks_to_submissions(future_tasks: list[FutureTask], fallback_cluster: str) -> TaskSubmissions:
268+
def merge_future_tasks_to_submissions(future_tasks: list[FutureTask], fallback_cluster: str) -> TaskSubmissions | None:
269+
if len(future_tasks) == 0:
270+
return None
271+
269272
dependants = defaultdict(set)
270273
for task in future_tasks:
271274
for dep in task.depends_on:
@@ -292,23 +295,25 @@ def merge_future_tasks_to_submissions(future_tasks: list[FutureTask], fallback_c
292295
group_index = group_keys.append_if_unique(group_key)
293296
if group_index == len(groups): # it was a new unique group
294297
groups.append(TaskSubmissionGroup(dependencies_on_other_groups=task.depends_on))
295-
296-
group = groups[group_index]
297298
task_index_to_group[task.index] = group_index
298299

299-
group.inputs.append(task.input())
300-
group.identifier_pointers.append(identifiers.append_if_unique(task.identifier()))
301-
group.cluster_slug_pointers.append(cluster_slugs.append_if_unique(task.cluster or fallback_cluster))
302-
group.display_pointers.append(displays.append_if_unique(task.display()))
303-
group.max_retries_values.append(task.max_retries)
304-
305300
for i in range(len(groups)):
306301
group = groups[i]
307302
group.dependencies_on_other_groups = list(
308303
# convert the task dependencies to group dependencies, deduplicate and sort them
309304
{task_index_to_group[dep] for dep in group.dependencies_on_other_groups}
310305
)
311306

307+
for task in future_tasks:
308+
group_index = task_index_to_group[task.index]
309+
group = groups[group_index]
310+
311+
group.inputs.append(task.input())
312+
group.identifier_pointers.append(identifiers.append_if_unique(task.identifier()))
313+
group.cluster_slug_pointers.append(cluster_slugs.append_if_unique(task.cluster or fallback_cluster))
314+
group.display_pointers.append(displays.append_if_unique(task.display()))
315+
group.max_retries_values.append(task.max_retries)
316+
312317
return TaskSubmissions(
313318
task_groups=groups,
314319
cluster_slug_lookup=cluster_slugs.values,

0 commit comments

Comments
 (0)