diff --git a/docs/source/model.md b/docs/source/model.md index 6661adf3d..9e6cc9f84 100644 --- a/docs/source/model.md +++ b/docs/source/model.md @@ -44,7 +44,7 @@ Description of model evaluation | `pipeline_name` | `Optional[str]` | Pipeline name (Pipeline names must exist in Processing.pipelines) | | `start_date_time` | `datetime (timezone-aware)` | Start date time | | `end_date_time` | `Optional[datetime (timezone-aware)]` | End date time | -| `output_path` | `Optional[AssetPath]` | Output path (Path to processing outputs, if stored.) | +| `output_path` | `Optional[List[AssetPath]]` | Output path (Path to processing outputs, if stored.) | | `output_parameters` | `Optional[dict]` | Outputs (Output parameters) | | `notes` | `Optional[str]` | Notes | | `resources` | Optional[[ResourceUsage](processing.md#resourceusage)] | Process resource usage | @@ -76,7 +76,7 @@ Description of model training | `pipeline_name` | `Optional[str]` | Pipeline name (Pipeline names must exist in Processing.pipelines) | | `start_date_time` | `datetime (timezone-aware)` | Start date time | | `end_date_time` | `Optional[datetime (timezone-aware)]` | End date time | -| `output_path` | `Optional[AssetPath]` | Output path (Path to processing outputs, if stored.) | +| `output_path` | `Optional[List[AssetPath]]` | Output path (Path to processing outputs, if stored.) | | `output_parameters` | `Optional[dict]` | Outputs (Output parameters) | | `notes` | `Optional[str]` | Notes | | `resources` | Optional[[ResourceUsage](processing.md#resourceusage)] | Process resource usage | diff --git a/docs/source/processing.md b/docs/source/processing.md index 3f990eb26..920ac9a43 100644 --- a/docs/source/processing.md +++ b/docs/source/processing.md @@ -43,7 +43,7 @@ Description of a single processing step | `pipeline_name` | `Optional[str]` | Pipeline name (Pipeline names must exist in Processing.pipelines) | | `start_date_time` | `datetime (timezone-aware)` | Start date time | | `end_date_time` | `Optional[datetime (timezone-aware)]` | End date time | -| `output_path` | `Optional[AssetPath]` | Output path (Path to processing outputs, if stored.) | +| `output_path` | `Optional[List[AssetPath]]` | Output path (Path to processing outputs, if stored.) | | `output_parameters` | `Optional[dict]` | Outputs (Output parameters) | | `notes` | `Optional[str]` | Notes | | `resources` | Optional[[ResourceUsage](processing.md#resourceusage)] | Process resource usage | diff --git a/examples/model.py b/examples/model.py index 438d57f02..63a6b62e6 100644 --- a/examples/model.py +++ b/examples/model.py @@ -46,7 +46,7 @@ "augmentation": True, }, ), - output_path="./trained_model.h5", + output_path=["./trained_model.h5"], start_date_time=now, end_date_time=now, train_performance=[ diff --git a/examples/processing.py b/examples/processing.py index 42b26c733..1592c89a3 100644 --- a/examples/processing.py +++ b/examples/processing.py @@ -61,7 +61,7 @@ stage=ProcessStage.PROCESSING, start_date_time=t, end_date_time=t, - output_path="/path/to/outputs", + output_path=["./path/to/output"], pipeline_name="Imaging processing pipeline", code=example_code.model_copy( update=dict( @@ -90,7 +90,7 @@ stage=ProcessStage.PROCESSING, start_date_time=t, end_date_time=t, - output_path="/path/to/outputs", + output_path=["./path/to/output"], code=example_code.model_copy( update=dict( parameters={"u": 7, "z": True}, @@ -104,7 +104,7 @@ stage=ProcessStage.PROCESSING, start_date_time=t, end_date_time=t, - output_path="/path/to/output", + output_path=["./path/to/output"], code=example_code.model_copy( update=dict( parameters={"a": 2, "b": -2}, @@ -117,7 +117,7 @@ process_type=ProcessName.ANALYSIS, start_date_time=t, end_date_time=t, - output_path="/path/to/outputs", + output_path=["./path/to/output"], code=example_code.model_copy( update=dict( parameters={"size": 7}, @@ -131,7 +131,7 @@ process_type=ProcessName.ANALYSIS, start_date_time=t, end_date_time=t, - output_path="/path/to/outputs", + output_path=["./path/to/output"], code=example_code.model_copy( update=dict( parameters={"u": 7, "z": True}, diff --git a/src/aind_data_schema/core/processing.py b/src/aind_data_schema/core/processing.py index dfeb31490..87af23ed1 100644 --- a/src/aind_data_schema/core/processing.py +++ b/src/aind_data_schema/core/processing.py @@ -68,13 +68,23 @@ class DataProcess(DataModel): end_date_time: Optional[Annotated[AwareDatetimeWithDefault, TimeValidation.AFTER]] = Field( default=None, title="End date time" ) - output_path: Optional[AssetPath] = Field( + output_path: Optional[List[AssetPath]] = Field( default=None, title="Output path", description="Path to processing outputs, if stored." ) output_parameters: Optional[GenericModel] = Field(default=None, description="Output parameters", title="Outputs") notes: Optional[str] = Field(default=None, title="Notes", validate_default=True) resources: Optional[ResourceUsage] = Field(default=None, title="Process resource usage") + @field_validator("output_path", mode="before") + def validate_output_path(cls, value) -> Optional[List[AssetPath]]: + """Validator for output_path to ensure it's a list even if a single path is provided + """ + if value is None: + return value + if not isinstance(value, list): + value = [value] + return [AssetPath(path) for path in value] + @field_validator("notes", mode="after") def validate_other(cls, value: Optional[str], info: ValidationInfo) -> Optional[str]: """Validator for other/notes""" diff --git a/tests/test_composability_merge.py b/tests/test_composability_merge.py index a8c991685..b201ce3b3 100644 --- a/tests/test_composability_merge.py +++ b/tests/test_composability_merge.py @@ -254,7 +254,7 @@ def test_add_processing_objects(self): experimenters=["Dr. Dan"], process_type=ProcessName.ANALYSIS, stage=ProcessStage.PROCESSING, - output_path="/path/to/outputs1", + output_path=["./path/to/output"], start_date_time=t, end_date_time=t, code=Code( @@ -272,7 +272,7 @@ def test_add_processing_objects(self): experimenters=["Dr. Jane"], process_type=ProcessName.COMPRESSION, stage=ProcessStage.PROCESSING, - output_path="/path/to/outputs2", + output_path=["./path/to/output"], start_date_time=t, end_date_time=t, code=Code( diff --git a/tests/test_imaging.py b/tests/test_imaging.py index 26e42bcc1..8d1b3efde 100644 --- a/tests/test_imaging.py +++ b/tests/test_imaging.py @@ -113,7 +113,7 @@ def test_registration(self): experimenters=["Dr. Dan"], start_date_time=datetime.now(tz=timezone.utc), end_date_time=datetime.now(tz=timezone.utc), - output_path="/some/path", + output_path=["./some/path"], code=Code( url="https://github.com/abcd", parameters=parameters, diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 3dc2f8fc1..2eb42470e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -99,7 +99,7 @@ def setUpClass(cls) -> None: experimenters=["Dr. Dan"], process_type=ProcessName.ANALYSIS, stage=ProcessStage.ANALYSIS, - output_path="/path/to/outputs", + output_path=["./path/to/output"], start_date_time=t, end_date_time=t, code=Code( @@ -828,7 +828,7 @@ def test_validate_time_constraints_processing(self): experimenters=["Dr. Dan"], process_type=ProcessName.ANALYSIS, stage=ProcessStage.ANALYSIS, - output_path="/path/to/outputs", + output_path=["./path/to/output"], start_date_time=datetime(2023, 4, 3, 20, 0, 0, tzinfo=timezone.utc), # After acquisition end_date_time=datetime(2023, 4, 3, 21, 0, 0, tzinfo=timezone.utc), code=Code( @@ -855,7 +855,7 @@ def test_validate_time_constraints_processing(self): experimenters=["Dr. Dan"], process_type=ProcessName.ANALYSIS, stage=ProcessStage.ANALYSIS, - output_path="/path/to/outputs", + output_path=["./path/to/output"], start_date_time=datetime(2023, 4, 3, 17, 0, 0, tzinfo=timezone.utc), # Before acquisition start end_date_time=datetime(2023, 4, 3, 21, 0, 0, tzinfo=timezone.utc), code=Code( diff --git a/tests/test_processing.py b/tests/test_processing.py index 7f59defdc..4fccbb3c6 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -8,6 +8,7 @@ from aind_data_schema_models.units import MemoryUnit from aind_data_schema.components.identifiers import Code, DataAsset +from aind_data_schema.components.wrappers import AssetPath from aind_data_schema.core.processing import ( DataProcess, Processing, @@ -118,6 +119,57 @@ def test_resource_usage_unit_validators(self): # p = Processing(data_processes=[]) # self.assertIsNotNone(p) + def _make_data_process(self, **kwargs) -> DataProcess: + """Helper to create a minimal DataProcess with overridable fields""" + defaults = dict( + experimenters=["Dr. Dan"], + process_type=ProcessName.COMPRESSION, + stage=ProcessStage.PROCESSING, + code=code, + start_date_time=t, + end_date_time=t, + ) + defaults.update(kwargs) + return DataProcess(**defaults) + + def test_output_path_none(self): + """output_path defaults to None when not provided""" + dp = self._make_data_process() + self.assertIsNone(dp.output_path) + + def test_output_path_single_string(self): + """A single string is wrapped into a one-element list of AssetPath""" + dp = self._make_data_process(output_path="./outputs/result.nwb") + self.assertIsInstance(dp.output_path, list) + self.assertEqual(len(dp.output_path), 1) + self.assertIsInstance(dp.output_path[0], AssetPath) + self.assertEqual(str(dp.output_path[0]), "outputs/result.nwb") + + def test_output_path_single_asset_path(self): + """A single AssetPath is wrapped into a one-element list""" + dp = self._make_data_process(output_path=AssetPath("outputs/result.nwb")) + self.assertIsInstance(dp.output_path, list) + self.assertEqual(len(dp.output_path), 1) + self.assertIsInstance(dp.output_path[0], AssetPath) + + def test_output_path_list_of_strings(self): + """A list of strings is converted to a list of AssetPath""" + paths = ["outputs/a.nwb", "outputs/b.nwb"] + dp = self._make_data_process(output_path=paths) + self.assertIsInstance(dp.output_path, list) + self.assertEqual(len(dp.output_path), 2) + for item in dp.output_path: + self.assertIsInstance(item, AssetPath) + self.assertEqual([str(p) for p in dp.output_path], paths) + + def test_output_path_list_of_asset_paths(self): + """A list of AssetPath objects passes through unchanged""" + paths = [AssetPath("outputs/a.nwb"), AssetPath("outputs/b.nwb")] + dp = self._make_data_process(output_path=paths) + self.assertEqual(len(dp.output_path), 2) + for item in dp.output_path: + self.assertIsInstance(item, AssetPath) + def test_unique_process_names(self): """Test that process names are unique within a Processing object"""