Skip to content

Commit ce862b2

Browse files
Stellatsuu and aldbr authored
feat: automatic job grouping (#95)
* feat: add input_data to TransformationSubmissionModel, updated CLI
* chore: add test workflows for transformation grouping
* feat: added Job Grouping
* chore: rename
* chore: added job grouping test
* fix: updated transformation hooks group_size
* refactor: rename
* refactor: updated tests
* refactor: updated group_size definition and values
* fix: fixed JobWrapper tests where task.cwl was not deleted
* refactor: removed input_name from get_input_query
* fix: file location
* feat: add input_data to ProductionSubmissionModel
* fix: fixed tests and jobwrapper imports
* fix: fixed imports
* fix: removed pytest warnings
* fix: fixed test
* fix: fixed mypy error (?)
* fix: removed fixture and fixed imports
* fix: updated fixture and tests
* fix: added input_data to ExecHints instead of SubModels
* docs: updated README.md
* feat: chunk CLI parameter

---------

Co-authored-by: aldbr <aldbr@outlook.com>
1 parent db72f9b commit ce862b2

20 files changed

Lines changed: 408 additions & 167 deletions

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,17 @@ Inside the Pixi environment:
7070
pixi shell
7171

7272
# Submit
73-
dirac-cwl job submit <workflow_path> [--parameter-path <input_path>] [--metadata-path <metadata_path>]
73+
dirac-cwl job submit <workflow_path> [--input-files <input_path>]
7474

75-
dirac-cwl transformation submit <workflow_path> [--metadata-path <metadata_path>]
75+
dirac-cwl transformation submit <workflow_path> [--inputs-file <input_path>]
7676

77-
dirac-cwl production submit <workflow_path> [--steps-metadata-path <steps_metadata_path>]
77+
dirac-cwl production submit <workflow_path> [--inputs-file <input_path> --chunk <param>=<size>]
7878
```
7979

8080
Or prefix individual commands:
8181

8282
```bash
83-
pixi run dirac-cwl job submit <workflow_path> --parameter-path <input_path>
83+
pixi run dirac-cwl job submit <workflow_path> --input-files <input_path>
8484
```
8585

8686
Common tasks are defined in `pyproject.toml` and can be run with Pixi:

pixi.lock

Lines changed: 2 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scripts/generate_schemas.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,19 @@
2222
import json
2323
import logging
2424
from pathlib import Path
25-
from typing import Any, Dict, List, cast
25+
from typing import Any, Dict, List, Type, cast
2626

2727
import yaml
28+
from pydantic import BaseModel
2829

2930
# Configure logging
3031
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
3132
logger = logging.getLogger(__name__)
3233

3334

34-
def collect_pydantic_models() -> Dict[str, Any]:
35+
def collect_pydantic_models() -> Dict[str, Type[BaseModel]]:
3536
"""Collect all Pydantic models from the metadata system."""
36-
models = {}
37+
models: Dict[str, Type[BaseModel]] = {}
3738

3839
# Import core models
3940
try:
@@ -44,14 +45,10 @@ def collect_pydantic_models() -> Dict[str, Any]:
4445
TransformationExecutionHooksHint,
4546
)
4647

47-
models.update(
48-
{
49-
"ExecutionHooksBasePlugin": ExecutionHooksBasePlugin,
50-
"ExecutionHooks": ExecutionHooksHint,
51-
"Scheduling": SchedulingHint,
52-
"TransformationExecutionHooks": TransformationExecutionHooksHint,
53-
}
54-
)
48+
models["ExecutionHooksBasePlugin"] = ExecutionHooksBasePlugin
49+
models["ExecutionHooks"] = ExecutionHooksHint
50+
models["Scheduling"] = SchedulingHint
51+
models["TransformationExecutionHooks"] = TransformationExecutionHooksHint
5552
logger.info("Collected core metadata models")
5653
except ImportError as e:
5754
logger.error("Failed to import core models: %s", e)
@@ -65,14 +62,10 @@ def collect_pydantic_models() -> Dict[str, Any]:
6562
TransformationSubmissionModel,
6663
)
6764

68-
models.update(
69-
{
70-
"JobInputModel": JobInputModel,
71-
"JobSubmissionModel": JobSubmissionModel,
72-
"TransformationSubmissionModel": TransformationSubmissionModel,
73-
"ProductionSubmissionModel": ProductionSubmissionModel,
74-
}
75-
)
65+
models["JobInputModel"] = JobInputModel
66+
models["JobSubmissionModel"] = JobSubmissionModel
67+
models["TransformationSubmissionModel"] = TransformationSubmissionModel
68+
models["ProductionSubmissionModel"] = ProductionSubmissionModel
7669
logger.info("Collected submission models")
7770
except ImportError as e:
7871
logger.error("Failed to import submission models: %s", e)

src/dirac_cwl/execution_hooks/core.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
Union,
2424
)
2525

26+
from cwl_utils.parser import File, save
2627
from DIRAC.DataManagementSystem.Client.DataManager import ( # type: ignore[import-untyped]
2728
DataManager,
2829
)
2930
from DIRACCommon.Core.Utilities.ReturnValues import ( # type: ignore[import-untyped]
3031
returnSingleResult,
3132
)
32-
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
33+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
3334

3435
from dirac_cwl.commands import PostProcessCommand, PreProcessCommand
3536
from dirac_cwl.mocks.data_manager import MockDataManager
@@ -139,12 +140,12 @@ async def store_output(
139140
if res and not res["OK"]:
140141
raise RuntimeError(f"Could not save file {src} with LFN {str(lfn)} : {res['Message']}")
141142

142-
def get_input_query(self, input_name: str, **kwargs: Any) -> Union[Path, List[Path], None]:
143+
def get_input_query(self, **kwargs: Any) -> Union[Path, List[Path], None]:
143144
"""Generate LFN-based input query path.
144145
145146
Accepts and ignores extra kwargs for interface compatibility.
146147
"""
147-
# Build LFN: /query_root/vo/campaign/site/data_type/input_name
148+
# Build LFN: /query_root/vo/campaign/site/data_type
148149
pass
149150

150151
@classmethod
@@ -332,10 +333,40 @@ def from_cwl(cls, cwl_object: Any) -> Self:
332333
descriptor = descriptor.model_copy(update=hint_data)
333334
return descriptor
334335

336+
@classmethod
337+
def update_cwl(cls, cwl_object: Any, descriptor: Self) -> None:
338+
"""Update CWL object with metadata descriptor."""
339+
hints = getattr(cwl_object, "hints", []) or []
340+
for hint in hints:
341+
if hint.get("class") == "dirac:ExecutionHooks":
342+
hint.update(descriptor.model_dump())
343+
return
344+
345+
# Create the ExecutionHooks hint if it doesn't exist
346+
hints.append({"class": "dirac:ExecutionHooks", **descriptor.model_dump()})
347+
cwl_object.hints = hints
348+
335349

336350
class TransformationExecutionHooksHint(ExecutionHooksHint):
337351
"""Extended data manager for transformations."""
338352

339-
group_size: Optional[Dict[str, int]] = Field(
340-
default=None, description="Input grouping configuration for transformation jobs"
353+
group_size: Optional[int] = Field(default=None, description="Number of input files per job")
354+
input_data: Optional[Dict[str, List[str]]] = Field(
355+
default=None, description="Static input file lists, keyed by CWL input parameter name"
341356
)
357+
input_query: Optional[Dict] = Field(default=None, description="Dynamic input query for transformation jobs")
358+
359+
@field_validator("input_data", mode="before")
360+
@classmethod
361+
def convert_input_data(cls, value):
362+
"""Convert an input data dict containing a list of Files to a list of strings.
363+
364+
:param value: Input data dict to convert.
365+
:return: Converted input data dict.
366+
"""
367+
if value:
368+
return {
369+
key: [save(item)["location"] if isinstance(item, File) else item for item in items]
370+
for key, items in value.items()
371+
}
372+
return None

src/dirac_cwl/execution_hooks/plugins/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ class QueryBasedPlugin(ExecutionHooksBasePlugin):
3434
campaign: Optional[str] = Field(default=None, description="Campaign name for LFN path")
3535
data_type: Optional[str] = Field(default=None, description="Data type classification")
3636

37-
def get_input_query(self, input_name: str, **kwargs: Any) -> Union[Path, List[Path], None]:
37+
def get_input_query(self, **kwargs: Any) -> Union[Path, List[Path], None]:
3838
"""Generate LFN-based input query path.
3939
4040
Accepts and ignores extra kwargs for interface compatibility.
4141
"""
42-
# Build LFN: /query_root/vo/campaign/site/data_type/input_name
42+
# Build LFN: /query_root/vo/campaign/site/data_type
4343
path_parts = []
4444

4545
if self.vo:
@@ -53,6 +53,6 @@ def get_input_query(self, input_name: str, **kwargs: Any) -> Union[Path, List[Pa
5353
path_parts.append(self.data_type)
5454

5555
if len(path_parts) > 0: # More than just VO
56-
return Path(self.query_root) / Path(*path_parts) / Path(input_name)
56+
return Path(self.query_root) / Path(*path_parts)
5757

58-
return Path(self.query_root) / Path(input_name)
58+
return Path(self.query_root)

src/dirac_cwl/job/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from cwl_utils.parser.cwl_v1_2 import (
1414
File,
1515
)
16-
from cwl_utils.parser.cwl_v1_2_utils import load_inputfile
16+
from cwl_utils.parser.utils import load_inputfile
1717
from diracx.cli.utils import AsyncTyper
1818
from rich import print_json
1919
from rich.console import Console
@@ -40,7 +40,7 @@
4040
@app.async_command("submit")
4141
async def submit_job_client(
4242
task_path: str = typer.Argument(..., help="Path to the CWL file"),
43-
parameter_path: list[str] | None = typer.Option(None, help="Path to the files containing the metadata"),
43+
input_files: list[str] | None = typer.Option(None, help="Paths to the CWL input files"),
4444
# Specific parameter for the purpose of the prototype
4545
local: bool | None = typer.Option(True, help="Run the job locally instead of submitting it to the router"),
4646
):
@@ -70,35 +70,35 @@ async def submit_job_client(
7070
console.print(f"\t[green]:heavy_check_mark:[/green] Task {task_path}")
7171
console.print("\t[green]:heavy_check_mark:[/green] Hints")
7272

73-
# Extract parameters if any
74-
parameters = []
75-
if parameter_path:
76-
for parameter_p in parameter_path:
73+
# Extract inputs if any
74+
inputs = []
75+
if input_files:
76+
for file in input_files:
7777
try:
78-
parameter = load_inputfile(parameter_p)
78+
input_file = load_inputfile(task.cwlVersion, file)
7979
except Exception as ex:
8080
console.print(
81-
f"[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] Failed to validate the parameter:\n{ex}"
81+
f"[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] Failed to validate the input file:\n{ex}"
8282
)
8383
return typer.Exit(code=1)
8484

8585
# Prepare files for the ISB
86-
isb_file_paths = prepare_input_sandbox(parameter)
86+
isb_file_paths = prepare_input_sandbox(input_file)
8787

88-
# Upload parameter sandbox
88+
# Upload input file sandbox
8989
sandbox_id = await submission_client.create_sandbox(isb_file_paths)
9090

91-
parameters.append(
91+
inputs.append(
9292
JobInputModel(
9393
sandbox=[sandbox_id] if sandbox_id else None,
94-
cwl=parameter,
94+
cwl=input_file,
9595
)
9696
)
97-
console.print(f"\t[green]:heavy_check_mark:[/green] Parameter {parameter_p}")
97+
console.print(f"\t[green]:heavy_check_mark:[/green] File {file}")
9898

9999
job = JobSubmissionModel(
100100
task=task,
101-
inputs=parameters,
101+
inputs=inputs,
102102
)
103103
console.print("[green]:heavy_check_mark:[/green] [bold]CLI:[/bold] Job(s) validated.")
104104

src/dirac_cwl/job/job_wrapper_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import DIRAC # type: ignore[import-untyped]
1212
from cwl_utils.parser import load_document_by_uri
13-
from cwl_utils.parser.cwl_v1_2_utils import load_inputfile
13+
from cwl_utils.parser.utils import load_inputfile
1414
from ruamel.yaml import YAML
1515

1616
if os.getenv("DIRAC_PROTO_LOCAL") != "1":
@@ -41,7 +41,7 @@ async def main():
4141
task_obj = load_document_by_uri(f.name)
4242

4343
if job_model_dict["input"]:
44-
cwl_inputs_obj = load_inputfile(job_model_dict["input"]["cwl"])
44+
cwl_inputs_obj = load_inputfile(task_obj.cwlVersion, job_model_dict["input"]["cwl"])
4545
job_model_dict["input"]["cwl"] = cwl_inputs_obj
4646
job_model_dict["task"] = task_obj
4747

src/dirac_cwl/production/__init__.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
WorkflowInputParameter,
2121
WorkflowStep,
2222
)
23+
from cwl_utils.parser.utils import load_inputfile
2324
from rich import print_json
2425
from rich.console import Console
2526
from schema_salad.exceptions import ValidationException
@@ -74,6 +75,8 @@
7475
@app.command("submit")
7576
def submit_production_client(
7677
task_path: str = typer.Argument(..., help="Path to the CWL file"),
78+
inputs_file: str | None = typer.Option(None, help="Path to the CWL inputs file"),
79+
chunk: str | None = typer.Option(None, help="Split an array input into jobs: PARAM=SIZE (e.g., input-data=3)"),
7780
# Specific parameter for the purpose of the prototype
7881
local: Optional[bool] = typer.Option(True, help="Run the job locally instead of submitting it to the router"),
7982
):
@@ -86,10 +89,49 @@ def submit_production_client(
8689
"""
8790
os.environ["DIRAC_PROTO_LOCAL"] = "0"
8891

92+
# --chunk and --inputs-file must be used together
93+
if chunk and not inputs_file:
94+
console.print("[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] --chunk requires --inputs-file.")
95+
return typer.Exit(code=1)
96+
if inputs_file and not chunk:
97+
console.print("[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] --inputs-file requires --chunk.")
98+
return typer.Exit(code=1)
99+
89100
# Validate the workflow
90101
console.print("[blue]:information_source:[/blue] [bold]CLI:[/bold] Validating the production...")
91102
try:
92103
task = load_document(pack(task_path))
104+
105+
# Load Production inputs and inject into the first step's hint
106+
if inputs_file and chunk:
107+
from dirac_cwl.transformation import _parse_chunk
108+
109+
all_inputs = load_inputfile(task.cwlVersion, inputs_file)
110+
chunk_param, chunk_size = _parse_chunk(chunk)
111+
112+
if chunk_param not in all_inputs:
113+
console.print(
114+
f"[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] "
115+
f"Parameter '{chunk_param}' not found in inputs file. "
116+
f"Available parameters: {list(all_inputs.keys())}"
117+
)
118+
return typer.Exit(code=1)
119+
if not isinstance(all_inputs[chunk_param], list):
120+
console.print(
121+
f"[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] "
122+
f"Parameter '{chunk_param}' must be an array type for --chunk."
123+
)
124+
return typer.Exit(code=1)
125+
126+
input_data = {chunk_param: all_inputs[chunk_param]}
127+
128+
# Inject into the first step's hint
129+
if task.steps:
130+
from dirac_cwl.execution_hooks import TransformationExecutionHooksHint
131+
132+
hint_update = TransformationExecutionHooksHint(group_size=chunk_size, input_data=input_data)
133+
TransformationExecutionHooksHint.update_cwl(task.steps[0].run, hint_update)
134+
93135
except FileNotFoundError as ex:
94136
console.print(f"[red]:heavy_multiplication_x:[/red] [bold]CLI:[/bold] Failed to load the task:\n{ex}")
95137
return typer.Exit(code=1)
@@ -99,11 +141,10 @@ def submit_production_client(
99141
console.print(f"\t[green]:heavy_check_mark:[/green] Task {task_path}")
100142
console.print("\t[green]:heavy_check_mark:[/green] Metadata")
101143

102-
# Create the production
103144
production = ProductionSubmissionModel(task=task)
104145
console.print("[green]:heavy_check_mark:[/green] [bold]CLI:[/bold] Production validated.")
105146

106-
# Submit the tranaformation
147+
# Submit the transformation
107148
console.print("[blue]:information_source:[/blue] [bold]CLI:[/bold] Submitting the production...")
108149
print_json(production.model_dump_json(indent=4))
109150
if not submit_production_router(production):
@@ -161,12 +202,7 @@ def _get_transformations(
161202

162203
for step in production.task.steps:
163204
step_task = _create_subworkflow(step, str(production.task.cwlVersion), production.task.inputs)
164-
165-
transformations.append(
166-
TransformationSubmissionModel(
167-
task=step_task,
168-
)
169-
)
205+
transformations.append(TransformationSubmissionModel(task=step_task))
170206
return transformations
171207

172208

src/dirac_cwl/submission_models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ class JobModel(BaseJobModel):
102102
# -----------------------------------------------------------------------------
103103
# Transformation models
104104
# -----------------------------------------------------------------------------
105-
106-
107105
class TransformationSubmissionModel(BaseModel):
108106
"""Transformation definition sent to the router."""
109107

0 commit comments

Comments
 (0)