From 0846c5ff6f9aa817230ca26dc94fe8b4f8937d8a Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 4 Jun 2026 20:46:58 +0200 Subject: [PATCH 01/11] :rocket: improve validation for ProblemDefinition and Infos --- docs/source/concepts/check.md | 3 + docs/source/concepts/dataset.md | 8 +- docs/source/concepts/infos.md | 12 +- docs/source/concepts/problem_definition.md | 40 +-- docs/source/tutorials/storage.md | 12 +- docs/zensical.toml | 3 +- examples/infos_example.py | 8 +- examples/problem_definition_example.py | 37 +-- src/plaid/cli/plaidcheck.py | 139 +++++++++-- src/plaid/infos.py | 34 +-- src/plaid/problem_definition.py | 147 +++++------ src/plaid/storage/common/reader.py | 14 +- src/plaid/storage/writer.py | 8 +- tests/cli/test_plaidcheck.py | 98 ++++++++ tests/conftest.py | 2 + tests/containers/test_utils.py | 4 + tests/storage/test_cgns_init.py | 14 +- tests/storage/test_hf_datasets_init.py | 8 +- tests/storage/test_storage.py | 12 +- tests/storage/test_zarr_init.py | 8 +- tests/test_problem_definition.py | 276 ++++++++------------- tests/utils/test_info.py | 75 ++++-- 22 files changed, 585 insertions(+), 377 deletions(-) create mode 100644 docs/source/concepts/check.md diff --git a/docs/source/concepts/check.md b/docs/source/concepts/check.md new file mode 100644 index 00000000..f387297d --- /dev/null +++ b/docs/source/concepts/check.md @@ -0,0 +1,3 @@ +# Dataset check + +`plaid-check` is a \ No newline at end of file diff --git a/docs/source/concepts/dataset.md b/docs/source/concepts/dataset.md index a455ac51..270bb5bf 100644 --- a/docs/source/concepts/dataset.md +++ b/docs/source/concepts/dataset.md @@ -120,7 +120,13 @@ from plaid import ProblemDefinition from plaid.infos import DataDescription, Infos, Legal from plaid.storage import save_to_disk -pb_def = ProblemDefinition(name="regression_1") +pb_def = ProblemDefinition( + name="regression_1", + input_features=["Global/input"], + output_features=["Base/Zone/VertexFields/pressure"], + train_split={"train": [0, 1, 2]}, + test_split={"test": "all"}, +) infos = Infos( legal=Legal(owner="CompanyX", license="proprietary"), data_description=DataDescription(number_of_samples=3), diff --git a/docs/source/concepts/infos.md b/docs/source/concepts/infos.md index 06019e2e..5a0889b3 100644 --- a/docs/source/concepts/infos.md +++ b/docs/source/concepts/infos.md @@ -14,8 +14,8 @@ In the current API, infos stores: hardware, contact, or location - `data_description`, for optional dataset description entries such as the number of samples, DOE, inputs, and outputs -- `num_samples`, as a dictionary keyed by split name -- `storage_backend` +- `num_samples`, as a required dictionary keyed by split name +- `storage_backend`, as a required storage backend identifier ## Basic usage @@ -24,6 +24,8 @@ from plaid.infos import DataProduction, Infos, Legal infos = Infos( legal=Legal(owner="Safran", license="proprietary"), + num_samples={"train": 10, "test": 5}, + storage_backend="zarr", data_production=DataProduction( type="simulation", physics="fluid dynamics", @@ -42,6 +44,8 @@ infos = Infos.from_mapping( "owner": "Safran", "license": "proprietary", }, + "num_samples": {"train": 10, "test": 5}, + "storage_backend": "zarr", } ) ``` @@ -81,7 +85,7 @@ payload = infos.to_dict() ## Notes -- `legal.owner` and `legal.license` are required when validating complete infos. -- `num_samples` and `storage_backend` are automatically filled when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. +- `legal.owner`, `legal.license`, `num_samples`, and `storage_backend` are required when validating complete infos. +- `num_samples` and `storage_backend` are overwritten with the actual saved dataset values when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. - Unknown keys are rejected during validation. - `save_to_file(...)` writes YAML using the standard infos key order. diff --git a/docs/source/concepts/problem_definition.md b/docs/source/concepts/problem_definition.md index 239b384b..ee6bb01c 100644 --- a/docs/source/concepts/problem_definition.md +++ b/docs/source/concepts/problem_definition.md @@ -8,39 +8,41 @@ title: Problem definition In the current API, a problem definition stores: -- `name` -- `input_features` (`list[str]`) -- `output_features` (`list[str]`) -- `train_split` and `test_split` +- `name` (`str`, required) +- `input_features` (`list[str]`, required and non-empty) +- `output_features` (`list[str]`, required and non-empty) +- `train_split` and `test_split` (required) ## Basic usage ```python from plaid import ProblemDefinition -pb = ProblemDefinition(name="regression_1") - -pb.add_input_features([ - "Base/Zone/GridCoordinates/CoordinateX", - "Base/Zone/GridCoordinates/CoordinateY", -]) -pb.add_output_features([ - "Base/Zone/VertexFields/pressure", -]) - -pb.train_split = {"train": [0, 1, 2]} -pb.test_split = {"test": [3, 4]} +pb = ProblemDefinition( + name="regression_1", + input_features=[ + "Base/Zone/GridCoordinates/CoordinateX", + "Base/Zone/GridCoordinates/CoordinateY", + ], + output_features=[ + "Base/Zone/VertexFields/pressure", + ], + train_split={"train": [0, 1, 2]}, + test_split={"test": [3, 4]}, +) ``` Feature lists are normalized by the model: entries are converted to strings, -sorted, and checked for duplicates. +sorted, checked for duplicates, and rejected if empty. ## Loading from disk Load a definition from a dataset path: ```python -pb = ProblemDefinition.from_path("/path/to/plaid_dataset", name="regression_1") +pb = ProblemDefinition.from_path( + "/path/to/plaid_dataset/problem_definitions/regression_1.yaml" +) ``` At storage level, problem definitions are loaded as a dictionary keyed by name: @@ -66,4 +68,4 @@ pb.save_to_file("problem_definitions/regression_1.yaml") - Splits are represented by `train_split` and `test_split` dictionaries. - Split values can be explicit index sequences or the string `"all"`. - `add_input_features(...)` and `add_output_features(...)` accept either a - single string or a sequence of strings. \ No newline at end of file + single string or a sequence of strings after initialization. \ No newline at end of file diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index b7ef5500..e826c956 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -123,11 +123,13 @@ output_features = [ ] -pb_def = ProblemDefinition(name="regression_1") -pb_def.add_input_features(input_features) -pb_def.add_output_features(output_features) -pb_def.train_split = {"train":"all"} -pb_def.test_split = {"test":"all"} +pb_def = ProblemDefinition( + name="regression_1", + input_features=input_features, + output_features=output_features, + train_split={"train": "all"}, + test_split={"test": "all"}, +) #--------------------------------------------------------------- # Define a simple function that takes a single identifier and returns a Sample. diff --git a/docs/zensical.toml b/docs/zensical.toml index 068031ee..87591ece 100644 --- a/docs/zensical.toml +++ b/docs/zensical.toml @@ -55,7 +55,8 @@ nav = [ { "Downloadable samples" = "notebooks/downloadable_example/sample_example.md" }, { "Conversion tutorial" = "tutorials/storage.md" }, ] }, - { "Viewer" = "concepts/viewer.md" }, + { "`plaid-viewer`" = "concepts/viewer.md" }, + { "`plaid-check`" = "concepts/check.md" }, # >>> AUTO-GENERATED API REFERENCE START # The block below is overwritten by `python docs/generate_api_stubs.py`. # Edit that script (or the markers) instead of changing this section by hand. diff --git a/examples/infos_example.py b/examples/infos_example.py index a1dc2214..cd7ed37b 100644 --- a/examples/infos_example.py +++ b/examples/infos_example.py @@ -50,6 +50,8 @@ print("#---# Infos") infos = Infos( legal=Legal(owner="PLAID", license="MIT"), + num_samples={"train": 2, "test": 2}, + storage_backend="cgns", ) print(f"{infos = }") @@ -64,6 +66,8 @@ "license": "MIT", }, "data_description": "Example metadata for a PLAID dataset.", + "num_samples": {"train": 2, "test": 2}, + "storage_backend": "cgns", } ) print(f"{infos_from_mapping = }") @@ -101,8 +105,8 @@ # %% infos.data_description = "Example dataset generated for the Infos example." -infos.num_samples = {"train": 2, "test": 2} -infos.storage_backend = "cgns" +infos.num_samples = {"train": 3, "test": 1} +infos.storage_backend = "zarr" print(f"{infos.data_description = }") print(f"{infos.num_samples = }") diff --git a/examples/problem_definition_example.py b/examples/problem_definition_example.py index 42a5bfef..0e1a4d7b 100644 --- a/examples/problem_definition_example.py +++ b/examples/problem_definition_example.py @@ -18,7 +18,7 @@ # # This Jupyter Notebook demonstrates the usage of the ProblemDefinition class for defining machine learning problems using the PLAID library. It includes examples of: # -# 1. Initializing an empty ProblemDefinition +# 1. Initializing a complete ProblemDefinition # 2. Configuring problem characteristics and retrieve data # 3. Saving and loading problem definitions # @@ -37,38 +37,43 @@ from plaid import ProblemDefinition # %% [markdown] -# ## Section 1: Initializing an Empty ProblemDefinition +# ## Section 1: Initializing a ProblemDefinition # # This section demonstrates how to initialize a ProblemDefinition and add # input/output feature identifiers with the current API. # %% [markdown] -# ### Initialize and print ProblemDefinition +# ### Initialize feature identifiers # %% -print("#---# Empty ProblemDefinition") -problem = ProblemDefinition() -print(f"{problem = }") - -# %% -# ### Initialize some feature identifiers scalar_1_feat_id = "Global/scalar_1" scalar_2_feat_id = "Global/scalar_2" scalar_3_feat_id = "Global/scalar_3" field_1_feat_id = "Base_2_2/Zone/CellCenterFields/field_1" field_2_feat_id = "Base_2_2/Zone/VertexFields/field_2" +# %% +print("#---# ProblemDefinition") +problem = ProblemDefinition( + name="my_problem_definition", + input_features=[scalar_3_feat_id, field_1_feat_id], + output_features=[field_2_feat_id], + train_split={"train": [0, 1]}, + test_split={"test": [2, 3]}, +) +print(f"{problem = }") + # %% [markdown] # ### Add inputs / outputs to a Problem Definition # %% -# Add unique input and output feature identifiers +# Add unique input and output feature identifiers after initialization #problem.add_input_features(scalar_1_feat_id) #problem.add_output_features(scalar_2_feat_id) # or Add list of input and output feature identifiers -problem.add_input_features([scalar_3_feat_id, field_1_feat_id]) -problem.add_output_features([field_2_feat_id]) +problem.add_input_features([scalar_1_feat_id]) +problem.add_output_features([scalar_2_feat_id]) print(f"{problem.input_features = }") print( @@ -84,17 +89,14 @@ # ### Set Problem Definition name # %% -problem.name = "my_problem_definition" print(f"{problem.name = }") # %% [markdown] # ### Set Problem Definition split # %% -# Current API uses `train_split` and `test_split` fields. +# Current API uses required `train_split` and `test_split` fields. # Note: each split field currently expects a dictionary with a single entry. -problem.train_split = {"train": [0, 1]} -problem.test_split = {"test": [2, 3]} print(f"{problem.train_split = }") print(f"{problem.test_split = }") @@ -136,6 +138,5 @@ # ### Load a ProblemDefinition from a YAML file # %% -problem = ProblemDefinition() -problem._load_from_file_(pb_def_save_fname) +problem = ProblemDefinition.from_path(pb_def_save_fname) print(problem) diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index a2f33f3e..b0a4b922 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -158,6 +158,81 @@ def _check_numeric_content(value: Any) -> Optional[str]: return None +def _format_missing_split_message(split: object) -> str: + """Return an actionable message for missing split declarations. + + Args: + split: Split name/key reported by a low-level ``KeyError``. + + Returns: + Human-readable explanation suitable for a checker error. + """ + return ( + f"Split {split!r} exists in the stored dataset or metadata but is missing " + "from infos.yaml > num_samples. Add this split to num_samples, or remove " + "the corresponding split data/metadata from the dataset." + ) + + +def _check_num_samples_declares_splits( + num_samples: dict[str, Any], + split_names: set[str], + report: CheckReport, +) -> None: + """Validate that ``infos.yaml > num_samples`` declares known disk splits. + + Args: + num_samples: Mapping loaded from ``infos.yaml``. + split_names: Split names discovered from storage files/metadata. + report: Report updated with missing declaration errors. + + Returns: + None. + """ + declared_splits = {str(split) for split in num_samples} + for split in sorted(split_names - declared_splits): + report.add( + "error", + "NUM_SAMPLES_MISSING_SPLIT", + "infos.yaml", + _format_missing_split_message(split), + ) + + +def _discover_split_names_from_disk( + path: Path, + backend: Optional[str], + flat_cst: dict[str, Any], + constant_schema: dict[str, Any], +) -> set[str]: + """Discover split names from files/metadata without building converters. + + Args: + path: Dataset root. + backend: Declared storage backend, if valid. + flat_cst: Flattened constants keyed by split for non-CGNS backends. + constant_schema: Constant schema keyed by split for non-CGNS backends. + + Returns: + Split names discovered from the on-disk dataset structure and metadata. + """ + split_names: set[str] = set() + data_path = path / "data" + if data_path.exists(): + if backend in {"zarr", "cgns"}: + split_names.update(p.name for p in data_path.iterdir() if p.is_dir()) + elif backend == "hf_datasets": + split_names.update(flat_cst.keys()) + split_names.update(constant_schema.keys()) + if backend != "cgns": + constants_path = path / "constants" + if constants_path.exists(): + split_names.update(p.name for p in constants_path.iterdir() if p.is_dir()) + split_names.update(flat_cst.keys()) + split_names.update(constant_schema.keys()) + return {str(split) for split in split_names} + + def _is_branch_without_data(sample: Any, path: str) -> bool: """Return True when `path` points to a branch node with no direct value. @@ -273,6 +348,25 @@ def check_dataset( if report.has_errors(): return report + # Validate top-level dataset declarations from infos.yaml before calling + # init_from_disk(), because storage initialization indexes num_samples by + # split and otherwise reports missing entries as opaque KeyError messages. + declared_backend = infos.get("storage_backend") + if not isinstance(declared_backend, str): + report.add( + "error", + "BACKEND_MISSING", + "infos.yaml", + "Missing or invalid 'storage_backend' in infos.yaml", + ) + + num_samples = infos.get("num_samples", {}) + if not isinstance(num_samples, dict): + report.add( + "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict" + ) + num_samples = {} + # Load metadata when the backend defines it. The CGNS backend stores # self-contained samples and intentionally writes no derived metadata. if declared_backend_for_layout == "cgns": @@ -288,28 +382,29 @@ def check_dataset( report.add("error", "METADATA_READ_ERROR", str(path), str(exc)) return report - try: - datasetdict, converterdict = init_from_disk(path) - except Exception as exc: - report.add("error", "DATASET_INIT_ERROR", str(path), str(exc)) + discovered_splits = _discover_split_names_from_disk( + path, + declared_backend_for_layout, + flat_cst, + constant_schema, + ) + _check_num_samples_declares_splits(num_samples, discovered_splits, report) + if report.has_errors(): return report - # Validate top-level dataset declarations from infos.yaml. - declared_backend = infos.get("storage_backend") - if not isinstance(declared_backend, str): + try: + datasetdict, converterdict = init_from_disk(path) + except KeyError as exc: report.add( "error", - "BACKEND_MISSING", + "NUM_SAMPLES_MISSING_SPLIT", "infos.yaml", - "Missing or invalid 'storage_backend' in infos.yaml", + _format_missing_split_message(exc.args[0] if exc.args else str(exc)), ) - - num_samples = infos.get("num_samples", {}) - if not isinstance(num_samples, dict): - report.add( - "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict" - ) - num_samples = {} + return report + except Exception as exc: + report.add("error", "DATASET_INIT_ERROR", str(path), str(exc)) + return report # Resolve the user-requested splits against the splits actually available. dataset_splits = set(datasetdict.keys()) @@ -527,9 +622,6 @@ def check_dataset( f"Out-of-range indices in {split_dict_name} (first 10): {bad[:10]}", ) - # Emit an explicit success message when no errors or warnings were found. - if not report.messages: - report.add("info", "OK", str(path), "No issue detected") return report @@ -577,8 +669,13 @@ def main(argv: Optional[list[str]] = None) -> int: if args.json: print(report.to_json()) else: - for msg in report.messages: - print(f"[{msg.severity.upper()}] {msg.code} {msg.location}: {msg.message}") + if not report.messages: + print(f"[OK] {args.path}: No issue detected") + else: + for msg in report.messages: + print( + f"[{msg.severity.upper()}] {msg.code} {msg.location}: {msg.message}" + ) counts = report.counts() print( f"Summary: {counts['error']} error(s), " diff --git a/src/plaid/infos.py b/src/plaid/infos.py index a4958057..54ca32b2 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -2,13 +2,12 @@ from __future__ import annotations -import copy import logging from pathlib import Path from typing import Any, Union import yaml -from pydantic import BaseModel, ConfigDict, Field, ValidationError +from pydantic import BaseModel, ConfigDict, ValidationError from pydantic.dataclasses import dataclass logger = logging.getLogger(__name__) @@ -59,26 +58,25 @@ class Infos(BaseModel): legal: Legal data_production: DataProduction | None = None data_description: str | None = None - num_samples: dict[str, int] = Field(default_factory=dict) - storage_backend: str | None = None - - @classmethod - def _normalize_top_level(cls, infos: dict[str, Any]) -> dict[str, Any]: - # Drop legacy/unsupported top-level sections silently before validation. - normalized = copy.deepcopy(infos) - normalized.pop("plaid", None) - return normalized + num_samples: dict[str, int] + storage_backend: str @classmethod def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": """Validate schema/authorized keys without enforcing required sections.""" - normalized = cls._normalize_top_level(infos) + normalized = dict(infos) had_legal = "legal" in normalized + had_num_samples = "num_samples" in normalized + had_storage_backend = "storage_backend" in normalized if not had_legal: normalized["legal"] = { "owner": "__placeholder__", "license": "__placeholder__", } + if not had_num_samples: + normalized["num_samples"] = {} + if not had_storage_backend: + normalized["storage_backend"] = "__placeholder__" try: model = cls.model_validate(normalized) except ValidationError as exc: @@ -93,19 +91,21 @@ def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": if not had_legal: model.legal = Legal(owner="", license="") + if not had_num_samples: + model.num_samples = {} + if not had_storage_backend: + model.storage_backend = "" return model @classmethod def validate_required_only(cls, infos: dict[str, Any]) -> None: """Validate required entries using pydantic-required fields.""" - normalized = cls._normalize_top_level(infos) - cls.model_validate(normalized) + cls.model_validate(infos) @classmethod def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: """Validate and return a normalized deep copy of infos.""" - normalized = cls._normalize_top_level(infos) - model = cls.model_validate(normalized) + model = cls.model_validate(infos) return model.model_dump(exclude_none=True) # ------------------------------------------------------------------ @@ -115,7 +115,7 @@ def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: @classmethod def from_mapping(cls, infos: dict[str, Any]) -> "Infos": """Build a validated :class:`Infos` from a plain mapping.""" - return cls.model_validate(cls._normalize_top_level(infos)) + return cls.model_validate(infos) @classmethod def from_path(cls, path: Union[str, Path]) -> "Infos": diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index 4585b09d..ff68a766 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -12,10 +12,10 @@ import logging from pathlib import Path -from typing import Any, Literal, Optional, Sequence, Union, cast +from typing import Any, Literal, Sequence, Union, cast import yaml -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, field_validator from .types import IndexArrayType @@ -35,66 +35,55 @@ class ProblemDefinition(BaseModel): revalidate_instances="always", validate_assignment=True, extra="forbid" ) - name: Optional[str] = Field(default=None) - input_features: list[str] = Field(default_factory=list) - output_features: list[str] = Field(default_factory=list) - train_split: Optional[dict[str, Sequence[int] | Literal["all"]]] = Field( - default=None - ) - test_split: Optional[dict[str, Sequence[int] | Literal["all"]]] = Field( - default=None - ) - - @staticmethod - def from_path( - path: str | Path, name: str | None = None, **overrides: Any - ) -> "ProblemDefinition": - """Load a problem definition from a YAML file located at the specified path. + name: str + input_features: list[str] + output_features: list[str] + train_split: dict[str, Sequence[int] | Literal["all"]] + test_split: dict[str, Sequence[int] | Literal["all"]] - The YAML file should contain one or more problem definitions, and the desired definition can be selected by its name. + @classmethod + def from_mapping(cls, data: dict[str, Any]) -> "ProblemDefinition": + """Build a validated :class:`ProblemDefinition` from a plain mapping. Args: - path (str | Path): The file path to the YAML file containing problem definitions. - name (str | None, optional): The name of the problem definition to load. If None, it will attempt to load the - only problem definition available in the file. Defaults to None. - **overrides: Additional keyword arguments to override specific fields in the loaded problem definition. + data: YAML-like mapping containing one problem definition. - Raises: - ValueError: If the specified name is not found in the YAML file. - RuntimeError: If multiple problem definitions are present without a specified name. + Returns: + Validated problem definition instance. + """ + return cls.model_validate(data) + + @classmethod + def from_path(cls, path: str | Path) -> "ProblemDefinition": + """Load and validate one problem definition from a YAML file. + + Args: + path: Path to the problem-definition YAML file. If no suffix is + provided, ``.yaml`` is appended. Returns: - ProblemDefinition: The loaded problem definition. + Validated problem definition instance. + + Raises: + FileNotFoundError: If the resolved YAML file does not exist. """ - from plaid.storage import load_problem_definitions_from_disk - - all_pb_def = load_problem_definitions_from_disk(path=Path(path)) - available = ", ".join(sorted(all_pb_def)) - if name is not None: - if name not in all_pb_def: - raise ValueError( - f"Problem definition '{name}' not found in {path}. " - f"Available definitions: {available}" - ) - data2 = all_pb_def[name].model_dump() - data2.update(overrides) - data = data2 - else: - if len(all_pb_def) > 1: - raise RuntimeError( - f"Non name specified, but more than one Problem definition. Available definitions: {available}" - ) - else: - data2 = next(iter(all_pb_def.values())).model_dump() - data2.update(overrides) - data = data2 + path = Path(path) + if path.suffix != ".yaml": + path = path.with_suffix(".yaml") + if not path.exists(): + raise FileNotFoundError(f'File "{path}" does not exist. Abort') - return ProblemDefinition(**data) + with path.open("r", encoding="utf-8") as file: + data = yaml.safe_load(file) or {} + + return cls.from_mapping(data) @field_validator("input_features", mode="before") @classmethod def normalize_input_features(cls, v): """Normalize input features identifiers by ensuring they are unique and sorted.""" + if not v: + raise ValueError("input_features must not be empty") if len(set(v)) != len(v): raise ValueError("duplicated values in input_features") return _normalize_list(v) @@ -103,6 +92,8 @@ def normalize_input_features(cls, v): @classmethod def normalize_output_features(cls, v): """Normalize output features identifiers by ensuring they are unique and sorted.""" + if not v: + raise ValueError("output_features must not be empty") if len(set(v)) != len(v): raise ValueError("duplicated values in output_features") return _normalize_list(v) @@ -186,7 +177,13 @@ def add_input_features(self, inputs: Union[str, Sequence[str]]) -> None: .. code-block:: python from plaid.problem_definition import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + name="example", + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) input_features = ['omega', 'pressure'] problem.add_input_features(input_features) @@ -222,7 +219,13 @@ def add_output_features(self, outputs: Union[str, Sequence[str]]) -> None: .. code-block:: python from plaid.problem_definition import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + name="example", + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) output_features = ['omega', 'pressure'] problem.add_output_features(output_features) @@ -255,7 +258,13 @@ def save_to_file(self, path: Union[str, Path]) -> None: .. code-block:: python from plaid import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + name="example", + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) problem.save_to_file("/path/to/save_file") """ path = Path(path) @@ -285,37 +294,3 @@ def save_to_file(self, path: Union[str, Path]) -> None: sort_keys=False, allow_unicode=True, ) - - def _load_from_file_(self, path: Union[str, Path]) -> None: - """Load problem information, inputs, outputs, and split from the specified file in YAML format. - - Args: - path (Union[str,Path]): The filepath from which to load the problem information. - - Raises: - FileNotFoundError: Triggered if the provided file does not exist. - - Example: - .. code-block:: python - - from plaid import ProblemDefinition - problem = ProblemDefinition() - problem._load_from_file_("/path/to/load_file") - """ - path = Path(path) - - if path.suffix != ".yaml": - path = path.with_suffix(".yaml") - - if not path.exists(): - raise FileNotFoundError(f'File "{path}" does not exist. Abort') - - with path.open("r") as file: - data = yaml.safe_load(file) - - model_fields = type(self).model_fields.keys() - for key, value in data.items(): - if key in model_fields: - setattr(self, key, value) - else: - logger.warning(f" Data ignored! : {key}: {value}") diff --git a/src/plaid/storage/common/reader.py b/src/plaid/storage/common/reader.py index 3a99fcc5..1f20b4e5 100644 --- a/src/plaid/storage/common/reader.py +++ b/src/plaid/storage/common/reader.py @@ -61,8 +61,8 @@ def load_problem_definitions_from_disk( ``problem_definitions/`` subdirectory under ``path`` and reconstructs them into ``ProblemDefinition`` objects. - Each file is loaded using ``ProblemDefinition._load_from_file_`` and inserted - into a dictionary keyed by the problem definition name. + Each file is loaded using ``ProblemDefinition.from_path`` and inserted into + a dictionary keyed by the problem definition name. Expected local layout: / @@ -92,8 +92,14 @@ def load_problem_definitions_from_disk( pb_defs = {} for p in pb_def_dir.iterdir(): if p.is_file(): - pb_def = ProblemDefinition() - pb_def._load_from_file_(pb_def_dir / Path(p.name)) + pb_def_path = pb_def_dir / Path(p.name) + try: + pb_def = ProblemDefinition.from_path(pb_def_path) + except Exception as exc: + raise ValueError( + f"Failed to load problem definition file " + f"'{pb_def_path.name}': {exc}" + ) from exc pb_name = pb_def.name if isinstance(pb_def.name, str) else p.stem pb_defs[pb_name] = pb_def return pb_defs diff --git a/src/plaid/storage/writer.py b/src/plaid/storage/writer.py index a0acfdd9..6f972395 100644 --- a/src/plaid/storage/writer.py +++ b/src/plaid/storage/writer.py @@ -156,6 +156,8 @@ def sample_constructor(file_path): }, infos=Infos( legal={"owner": "owner", "license": "license"}, + num_samples={}, + storage_backend="hf_datasets", ), num_proc=6, ) @@ -236,7 +238,11 @@ def sample_constructor(file_path): # written ``infos.yaml`` always reflects how the dataset was saved, # overriding any inherited values from the input ``infos``. if infos is None: - infos = Infos(legal=Legal(owner="unknown", license="unknown")) + infos = Infos( + legal=Legal(owner="unknown", license="unknown"), + num_samples={}, + storage_backend=backend, + ) infos_data = infos.to_dict() infos_data["num_samples"] = num_samples infos_data["storage_backend"] = backend diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index 840a3730..a66f391d 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -60,6 +60,26 @@ def test_check_dataset_missing_infos(tmp_path: Path, dataset_name: str) -> None: assert any(msg.code == "MISSING_PATH" for msg in report.messages) +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_check_dataset_rejects_extra_infos_key( + tmp_path: Path, dataset_name: str +) -> None: + """Extra infos.yaml keys should be reported through infos validation.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + infos_path = dataset_path / "infos.yaml" + original = infos_path.read_text(encoding="utf-8") + infos_path.write_text( + f"{original}\nplaid:\n version: 0.1.13.dev36+g21db6656e.d20260302\n", + encoding="utf-8", + ) + + report = check_dataset(dataset_path) + + assert report.has_errors() + assert any(msg.code == "INFOS_READ_ERROR" for msg in report.messages) + assert any("plaid" in msg.message for msg in report.messages) + + @pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) def test_check_dataset_num_samples_mismatch(tmp_path: Path, dataset_name: str) -> None: """Tampering with num_samples should raise split mismatch errors.""" @@ -95,6 +115,21 @@ def test_main_json_output_and_exit_code( assert "messages" in payload +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_main_text_success_does_not_count_ok_as_info( + tmp_path: Path, capsys, dataset_name: str +) -> None: + """Successful text output should show OK without adding an info count.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + + code = main([str(dataset_path)]) + out = capsys.readouterr().out + + assert code == 0 + assert f"[OK] {dataset_path}: No issue detected" in out + assert "Summary: 0 error(s), 0 warning(s), 0 info message(s)" in out + + @pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) def test_main_strict_fails_on_warning(tmp_path: Path, dataset_name: str) -> None: """In strict mode, warnings should make the command fail.""" @@ -496,6 +531,31 @@ def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> N assert any(msg.code == "SAMPLE_CONVERSION_ERROR" for msg in report.messages) +def test_check_dataset_missing_num_samples_split_is_clear( + tmp_path: Path, monkeypatch +) -> None: + """Missing split declarations should not be reported as opaque KeyErrors.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "data" / "OOD").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None), # noqa: ARG005 + ) + + report = check_dataset(dataset) + + assert any(msg.code == "NUM_SAMPLES_MISSING_SPLIT" for msg in report.messages) + assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + assert any("OOD" in msg.message for msg in report.messages) + + def test_check_dataset_problem_definition_validation_paths( tmp_path: Path, monkeypatch ) -> None: @@ -564,6 +624,44 @@ def __init__(self, train_split, test_split): assert any(msg.code == "PB_DEF_OUT_OF_RANGE_INDICES" for msg in report_pb.messages) +def test_check_dataset_problem_definition_read_error_names_yaml_file( + tmp_path: Path, monkeypatch +) -> None: + """Problem-definition read errors should identify the offending YAML file.""" + dataset = _make_minimal_layout(tmp_path) + pb_def_dir = dataset / "problem_definitions" + pb_def_dir.mkdir() + (pb_def_dir / "bad_definition.yaml").write_text( + "name: bad\n" + "input_features: [in]\n" + "output_features: [out]\n" + "unexpected_key: value\n", + encoding="utf-8", + ) + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: {"storage_backend": "zarr", "num_samples": {"train": 0}}, # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Known": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(0)}, {"train": _FakeConverter([])}), # noqa: ARG005 + ) + + report = check_dataset(dataset) + + assert any(msg.code == "PB_DEF_READ_ERROR" for msg in report.messages) + assert any("bad_definition.yaml" in msg.message for msg in report.messages) + assert any("extra_forbidden" in msg.message for msg in report.messages) + + def test_main_strict_returns_warning_exit_code( monkeypatch, tmp_path: Path, capsys ) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index fa3ee858..88a3355d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,6 +95,8 @@ def infos(): { "legal": {"owner": "PLAID2", "license": "BSD-3"}, "data_production": {"type": "simulation", "simulator": "Z-set"}, + "num_samples": {}, + "storage_backend": "zarr", } ) diff --git a/tests/containers/test_utils.py b/tests/containers/test_utils.py index a7cbd352..b119c9d4 100644 --- a/tests/containers/test_utils.py +++ b/tests/containers/test_utils.py @@ -145,6 +145,8 @@ def test_get_feature_details_from_path(self, url, expected): def test_validate_required_only(self): infos = { "legal": {"owner": "Joh Doe", "license": "cc-by-sa-4.0"}, + "num_samples": {"train": 1}, + "storage_backend": "zarr", } Infos.validate_required_only(infos) @@ -152,6 +154,8 @@ def test_validate_required_only(self): "legal": { "owner": "Joh Doe", }, + "num_samples": {"train": 1}, + "storage_backend": "zarr", } with pytest.raises(ValueError): Infos.validate_required_only(infos_missing_license) diff --git a/tests/storage/test_cgns_init.py b/tests/storage/test_cgns_init.py index 292e9f92..a3fc99e4 100644 --- a/tests/storage/test_cgns_init.py +++ b/tests/storage/test_cgns_init.py @@ -170,7 +170,11 @@ def test_cgns_backend_configure_dataset_card_requires_local_dir(): CgnsBackend.configure_dataset_card( repo_id="dummy/repo", infos=Infos.from_mapping( - {"legal": {"owner": "owner", "license": "cc-by-4.0"}} + { + "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "num_samples": {}, + "storage_backend": "cgns", + } ), ) @@ -184,7 +188,13 @@ def fake_configure_dataset_card(**kwargs): monkeypatch.setattr(cgns, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.from_mapping( + { + "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "num_samples": {}, + "storage_backend": "cgns", + } + ) CgnsBackend.configure_dataset_card( repo_id="dummy/repo", infos=infos, diff --git a/tests/storage/test_hf_datasets_init.py b/tests/storage/test_hf_datasets_init.py index 5f92b315..3771e74e 100644 --- a/tests/storage/test_hf_datasets_init.py +++ b/tests/storage/test_hf_datasets_init.py @@ -210,7 +210,13 @@ def fake_configure_dataset_card( hf_datasets, "configure_dataset_card", fake_configure_dataset_card ) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.from_mapping( + { + "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "num_samples": {}, + "storage_backend": "hf_datasets", + } + ) HFBackend.configure_dataset_card("dummy/repo", infos) assert call == {"repo_id": "dummy/repo", "infos": infos, "local_dir": None} diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index b9f2f449..8495c9e6 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -163,11 +163,13 @@ def main_splits() -> dict: @pytest.fixture() def problem_definition(main_splits) -> ProblemDefinition: - problem_definition = ProblemDefinition() - problem_definition.add_input_features(["feature_name_1", "feature_name_2"]) - problem_definition.train_split = {"train": main_splits["train"]} - problem_definition.test_split = {"test": main_splits["test"]} - return problem_definition + return ProblemDefinition( + name="problem_definition", + input_features=["feature_name_1", "feature_name_2"], + output_features=["feature_name_1"], + train_split={"train": main_splits["train"]}, + test_split={"test": main_splits["test"]}, + ) @pytest.fixture() diff --git a/tests/storage/test_zarr_init.py b/tests/storage/test_zarr_init.py index ea8c320b..220743be 100644 --- a/tests/storage/test_zarr_init.py +++ b/tests/storage/test_zarr_init.py @@ -176,7 +176,13 @@ def fake_configure_dataset_card( monkeypatch.setattr(zarr, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.from_mapping( + { + "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "num_samples": {}, + "storage_backend": "zarr", + } + ) result = ZarrBackend.configure_dataset_card("dummy/repo", infos, "/tmp/local") assert result == "configured" diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index 38b23842..b462421a 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -14,7 +14,13 @@ @pytest.fixture() def problem_definition() -> ProblemDefinition: - return ProblemDefinition() + return ProblemDefinition.model_construct( + name=None, + input_features=[], + output_features=[], + train_split=None, + test_split=None, + ) @pytest.fixture() @@ -77,153 +83,131 @@ class Test_ProblemDefinition: def test__init__(self, problem_definition): print(problem_definition) - def test__init__both_path_and_directory_path(self, current_directory): - d_path = current_directory / "problem_definition" - with pytest.raises(ValueError): - ProblemDefinition(path=d_path, directory_path=d_path) + def test_required_fields(self): + with pytest.raises(ValidationError, match="Field required"): + ProblemDefinition() + + def test_feature_lists_must_not_be_empty(self): + base = { + "name": "pb", + "train_split": {"train": "all"}, + "test_split": {"test": "all"}, + } + with pytest.raises(ValidationError, match="input_features must not be empty"): + ProblemDefinition( + **base, + input_features=[], + output_features=["out"], + ) + with pytest.raises(ValidationError, match="output_features must not be empty"): + ProblemDefinition( + **base, + input_features=["in"], + output_features=[], + ) # -------------------------------------------------------------------------# - def test_from_path_single_definition(self, monkeypatch, tmp_path): - expected = ProblemDefinition( - name="pb_single", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - - def fake_loader(path): - assert path == tmp_path - return {"pb_single": expected} - - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", fake_loader + def test_from_mapping_validates_and_normalizes(self): + loaded = ProblemDefinition.from_mapping( + { + "name": "pb_single", + "input_features": ["in_b", "in_a"], + "output_features": ["out_b", "out_a"], + "train_split": {"train_0": [0, 1]}, + "test_split": {"test_0": [2]}, + } ) - loaded = ProblemDefinition.from_path(tmp_path) assert loaded.name == "pb_single" - assert loaded.input_features == ["in_a"] - assert loaded.output_features == ["out_a"] + assert loaded.input_features == ["in_a", "in_b"] + assert loaded.output_features == ["out_a", "out_b"] assert loaded.get_train_split_name() == "train_0" assert loaded.get_test_split_name() == "test_0" assert loaded.get_train_split_indices() == [0, 1] assert loaded.get_test_split_indices() == [2] - def test_from_path_single_definition_with_override(self, monkeypatch, tmp_path): - expected = ProblemDefinition( - name="pb_single", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"pb_single": expected}, # noqa: ARG005 - ) - - loaded = ProblemDefinition.from_path(tmp_path) - assert loaded.name == "pb_single" - - def test_from_path_named_definition_and_override(self, monkeypatch, tmp_path): - pb_1 = ProblemDefinition( - name="pb_1", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - pb_2 = ProblemDefinition( - name="pb_2", - input_features=["in_b"], - output_features=["out_b"], - train_split={"train_1": [3, 4]}, - test_split={"test_1": [5]}, - ) - - def fake_loader(path): - assert path == tmp_path - return {"pb_1": pb_1, "pb_2": pb_2} - - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", fake_loader - ) - - loaded = ProblemDefinition.from_path( - tmp_path, - name="pb_2", - ) - assert loaded.name == "pb_2" - - def test_from_path_unknown_name_raises(self, monkeypatch, tmp_path): - pb = ProblemDefinition( - name="existing", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, + def test_from_path_loads_single_yaml_file(self, tmp_path: Path): + file_path = tmp_path / "problem.yaml" + file_path.write_text( + "name: pb\n" + "input_features:\n" + " - in_1\n" + "output_features:\n" + " - out_1\n" + "train_split:\n" + " train: [0]\n" + "test_split:\n" + " test: [1]\n", + encoding="utf-8", ) - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"existing": pb}, # noqa: ARG005 - ) + loaded = ProblemDefinition.from_path(file_path) - with pytest.raises(ValueError, match="Problem definition 'missing' not found"): - ProblemDefinition.from_path(tmp_path, name="missing") + assert loaded.name == "pb" + assert loaded.get_train_split_name() == "train" + assert loaded.get_test_split_name() == "test" - def test_from_path_requires_name_when_multiple(self, monkeypatch, tmp_path): - pb_1 = ProblemDefinition( - name="pb_1", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - pb_2 = ProblemDefinition( - name="pb_2", - input_features=["in_b"], - output_features=["out_b"], - train_split={"train_1": [3, 4]}, - test_split={"test_1": [5]}, + def test_from_path_adds_yaml_suffix(self, tmp_path: Path): + file_path = tmp_path / "problem.yaml" + file_path.write_text( + "name: pb\n" + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n", + encoding="utf-8", ) - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"pb_1": pb_1, "pb_2": pb_2}, # noqa: ARG005 - ) + loaded = ProblemDefinition.from_path(tmp_path / "problem") - with pytest.raises(RuntimeError, match="more than one Problem definition"): - ProblemDefinition.from_path(tmp_path) + assert loaded.name == "pb" - def test_from_path_error_lists_sorted_available_names(self, monkeypatch, tmp_path): - pb_a = ProblemDefinition( - name="a", input_features=["in"], output_features=["out"] - ) - pb_b = ProblemDefinition( - name="b", input_features=["in"], output_features=["out"] + def test_from_path_unknown_key_raises(self, tmp_path: Path): + file_path = tmp_path / "problem_with_unknown.yaml" + file_path.write_text( + "name: pb\n" + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n" + "unknown_key: value\n", + encoding="utf-8", ) - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"b": pb_b, "a": pb_a}, # noqa: ARG005 - ) + with pytest.raises(ValidationError, match="extra_forbidden"): + ProblemDefinition.from_path(file_path) - with pytest.raises(ValueError, match="Available definitions: a, b"): - ProblemDefinition.from_path(tmp_path, name="missing") + def test_from_path_non_existing_file(self): + with pytest.raises(FileNotFoundError): + ProblemDefinition.from_path(Path("non_existing_path")) def test_feature_validators_reject_duplicates(self): with pytest.raises( ValidationError, match="duplicated values in input_features" ): - ProblemDefinition(input_features=["a", "a"]) + ProblemDefinition( + name="pb", + input_features=["a", "a"], + output_features=["out"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) with pytest.raises( ValidationError, match="duplicated values in output_features" ): - ProblemDefinition(output_features=["a", "a"]) + ProblemDefinition( + name="pb", + input_features=["in"], + output_features=["a", "a"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) def test_non_overwritable_attributes_raise(self, problem_definition): problem_definition.name = "problem_a" @@ -280,34 +264,6 @@ def test_split(self, problem_definition): assert problem_definition.train_split == {"train_0": [0, 1, 2]} assert problem_definition.test_split == {"test-1": [3, 4]} - def test__load_from_file_( - self, problem_definition_full: ProblemDefinition, tmp_path: Path - ): - - path = tmp_path / "pb_def" - problem_definition_full.save_to_file(path) - problem = ProblemDefinition() - problem._load_from_file_(path) - assert set(problem.input_features) == set( - [ - "Base_2_2/Zone/PointData/sig12", - "Base_2_2/Zone/PointData/U1", - "Base_2_2/Zone/PointData/U2", - "Global/predict_feature", - "Global/test_feature", - "Global/feature", - ] - ) - assert set(problem.output_features) == set( - [ - "Global/predict_feature", - "Base_2_2/Zone/PointData/sig12", - "Global/feature", - "Base_2_2/Zone/PointData/U2", - "Global/test_feature", - ] - ) - def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): problem = ProblemDefinition( name="pb", @@ -319,36 +275,8 @@ def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" problem.save_to_file(file_path) - loaded = ProblemDefinition() - loaded._load_from_file_(file_path) + loaded = ProblemDefinition.from_path(file_path) assert loaded.name == "pb" assert loaded.get_train_split_name() == "train" assert loaded.get_test_split_name() == "test" - - def test__load_from_file__unknown_field_warns_and_raises( - self, tmp_path: Path, caplog - ): - file_path = tmp_path / "problem_with_unknown.yaml" - file_path.write_text( - "name: pb\n" - "task: regression\n" - "input_features:\n" - " - in_1\n" - "output_features:\n" - " - out_1\n" - "unknown_key: value\n", - encoding="utf-8", - ) - - problem = ProblemDefinition() - with caplog.at_level("WARNING"): - problem._load_from_file_(file_path) - - assert "Data ignored! : unknown_key: value" in caplog.text - - def test__load_from_file__non_existing_file(self): - problem = ProblemDefinition() - non_existing_path = Path("non_existing_path") - with pytest.raises(FileNotFoundError): - problem._load_from_file_(non_existing_path) diff --git a/tests/utils/test_info.py b/tests/utils/test_info.py index 3dea8f13..e356fc61 100644 --- a/tests/utils/test_info.py +++ b/tests/utils/test_info.py @@ -28,16 +28,16 @@ def test_validate_required_only_missing_required_key(): Infos(**{"legal": {"owner": "someone"}}) -def test_normalize_infos_strips_legacy_plaid_section_and_copies(): +def test_normalize_infos_rejects_legacy_plaid_section_and_copies(): infos = { "legal": {"owner": "owner", "license": "cc-by-4.0"}, "plaid": {"version": "x"}, } - normalized = Infos.normalize_mapping(infos) - # The legacy ``plaid`` section is dropped from the validated payload. - assert "plaid" not in normalized - # And the input mapping is not mutated. + with pytest.raises(ValidationError, match="extra_forbidden"): + Infos.normalize_mapping(infos) + + # The input mapping is not mutated before validation raises. assert "plaid" in infos @@ -95,6 +95,8 @@ def test_infos_save_and_load_roundtrip(tmp_path): def test_infos_from_path_directory(tmp_path): infos = { "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "num_samples": {"train": 10}, + "storage_backend": "zarr", } Infos.from_mapping(infos).save_to_file(tmp_path / "infos.yaml") reloaded = Infos.from_path(tmp_path) @@ -134,7 +136,13 @@ def test_validate_authorized_only_reraises_other_validation_errors(): def test_validate_required_only_accepts_valid_mapping(): - Infos.validate_required_only({"legal": {"owner": "o", "license": "l"}}) + Infos.validate_required_only( + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {"train": 1}, + "storage_backend": "zarr", + } + ) def test_validate_required_only_missing_legal(): @@ -144,7 +152,11 @@ def test_validate_required_only_missing_legal(): def test_to_dict_returns_plain_mapping(): model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, + "storage_backend": "zarr", + } ) d = model.to_dict() assert d["legal"] == {"owner": "o", "license": "l"} @@ -155,6 +167,7 @@ def test_getitem_returns_plain_value_and_unwraps_nested_dataclasses(): model = Infos.from_mapping( { "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, "storage_backend": "zarr", } ) @@ -166,14 +179,24 @@ def test_getitem_returns_plain_value_and_unwraps_nested_dataclasses(): def test_getitem_raises_key_error_for_unknown_field(): - model = Infos.from_mapping({"legal": {"owner": "o", "license": "l"}}) + model = Infos.from_mapping( + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, + "storage_backend": "zarr", + } + ) with pytest.raises(KeyError): model["does_not_exist"] def test_contains_handles_strings_and_non_strings(): model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, + "storage_backend": "zarr", + } ) assert "legal" in model assert "storage_backend" in model @@ -186,27 +209,45 @@ def test_contains_handles_strings_and_non_strings(): def test_get_returns_default_when_missing(): - model = Infos.from_mapping({"legal": {"owner": "o", "license": "l"}}) - assert model.get("storage_backend", "fallback") == "fallback" + model = Infos.from_mapping( + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, + "storage_backend": "zarr", + } + ) + assert model.get("data_description", "fallback") == "fallback" assert model.get("legal")["owner"] == "o" def test_save_to_file_treats_suffixless_path_as_directory(tmp_path): target = tmp_path / "myinfos" - Infos(legal=Legal(owner="o", license="l")).save_to_file(target) + Infos( + legal=Legal(owner="o", license="l"), + num_samples={}, + storage_backend="zarr", + ).save_to_file(target) # Suffix-less, non-existing paths are treated as directories that # will hold an ``infos.yaml``. assert (target / "infos.yaml").is_file() def test_save_to_file_into_existing_directory(tmp_path): - Infos(legal=Legal(owner="o", license="l")).save_to_file(tmp_path) + Infos( + legal=Legal(owner="o", license="l"), + num_samples={}, + storage_backend="zarr", + ).save_to_file(tmp_path) assert (tmp_path / "infos.yaml").is_file() def test_save_to_file_replaces_non_yaml_suffix(tmp_path): target = tmp_path / "weird.txt" - Infos(legal=Legal(owner="o", license="l")).save_to_file(target) + Infos( + legal=Legal(owner="o", license="l"), + num_samples={}, + storage_backend="zarr", + ).save_to_file(target) # ``.txt`` suffix is replaced by ``.yaml``. assert (tmp_path / "weird.yaml").is_file() assert not target.exists() @@ -223,7 +264,11 @@ def test_save_to_file_preserves_unknown_future_keys(tmp_path, monkeypatch): monkeypatch.setattr(infos_mod, "_KEY_ORDER", ()) model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} + { + "legal": {"owner": "o", "license": "l"}, + "num_samples": {}, + "storage_backend": "zarr", + } ) target = tmp_path / "out.yaml" model.save_to_file(target) From 84200bf1bec7400b7047a867cc362a1e57eeea92 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 4 Jun 2026 21:10:05 +0200 Subject: [PATCH 02/11] wip --- src/plaid/cli/plaidcheck.py | 143 +++++++++++++++++++++++++++++++++-- tests/cli/test_plaidcheck.py | 114 +++++++++++++++++++++++++++- 2 files changed, 249 insertions(+), 8 deletions(-) diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index b0a4b922..d3c23740 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -8,6 +8,7 @@ import CGNS.PAT.cgnsutils as CGU import numpy as np +from tqdm import tqdm from plaid.constants import CGNS_FIELD_LOCATIONS from plaid.storage import init_from_disk @@ -276,6 +277,102 @@ def _is_branch_without_data_in_mapping( return any(name.startswith(prefix) for name in feat_map) +def _progress( + iterable: Any, *, desc: str, show_progress: bool, total: int | None = None +): + """Wrap an iterable in a tqdm progress bar when requested. + + Args: + iterable: Iterable to wrap. + desc: Progress bar description. + show_progress: Whether the progress bar is enabled. + total: Optional total length. + + Returns: + Iterable, possibly wrapped by :class:`tqdm.tqdm`. + """ + return tqdm(iterable, desc=desc, total=total, disable=not show_progress) + + +def _resolve_problem_split_indices( + split_ids: Any, + split_len: int, +) -> list[int]: + """Resolve a problem-definition split declaration into concrete indices. + + Args: + split_ids: Either the special all-samples marker or an iterable of indices. + split_len: Number of samples available in the referenced split. + + Returns: + Concrete sample indices to instantiate. + """ + if split_ids == "all": + return list(range(split_len)) + return list(split_ids) + + +def _check_problem_definition_sample_features( + *, + pb_name: str, + split_dict_name: str, + split_name: str, + idx: int, + dataset: Any, + converter: Any, + features: list[str], + report: CheckReport, +) -> None: + """Instantiate and validate one problem-definition sample view. + + The sample is instantiated with the exact feature subset requested by the + problem definition, then each requested feature is read back and checked for + invalid content (None, NaN, Inf, empty arrays, object arrays containing None). + + Args: + pb_name: Problem-definition name. + split_dict_name: ``train_split`` or ``test_split``. + split_name: Dataset split name. + idx: Sample index. + dataset: Backend dataset object. + converter: Storage converter exposing ``to_plaid``. + features: Feature paths to request and validate. + report: Report updated with detected errors/warnings. + """ + location = f"problem_definitions/{pb_name}/{split_dict_name}/{split_name}[{idx}]" + try: + sample = converter.to_plaid(dataset, idx, features=features) + except Exception as exc: + report.add( + "error", + "PB_DEF_SAMPLE_CONVERSION_ERROR", + location, + str(exc), + ) + return + + for feature in features: + try: + value = sample.get_feature_by_path(feature) + except Exception as exc: + report.add( + "error", + "PB_DEF_FEATURE_READ_ERROR", + f"{location} {feature}", + str(exc), + ) + continue + + issue = _check_numeric_content(value) + if issue is not None: + report.add( + "warning", + "PB_DEF_INVALID_FEATURE_VALUE", + f"{location} {feature}", + issue, + ) + + def compute_checksum(sample: Any) -> str: """Compute a SHA-256 checksum for a converted sample representation. @@ -296,6 +393,7 @@ def compute_checksum(sample: Any) -> str: def check_dataset( path: Path, splits: Optional[list[str]] = None, + show_progress: bool = True, ) -> CheckReport: """Run integrity checks on a local PLAID dataset. @@ -316,6 +414,7 @@ def check_dataset( Args: path: Dataset directory. splits: Optional selected split names. + show_progress: Whether to display tqdm progress bars for expensive checks. Returns: A populated :class:`CheckReport`. @@ -454,7 +553,12 @@ def check_dataset( ) # Deep-check to validate content and detect non valide data in fields (nan inf) - for idx in range(actual_n): + for idx in _progress( + range(actual_n), + desc=f"Checking split {split}", + show_progress=show_progress, + total=actual_n, + ): try: sample = converter.to_plaid(dataset, idx) except Exception as exc: @@ -602,9 +706,8 @@ def check_dataset( f"Unknown split in {split_dict_name}: {split_name}", ) continue - if split_ids == "all": - continue - ids_list = list(split_ids) + split_len = len(datasetdict[split_name]) + ids_list = _resolve_problem_split_indices(split_ids, split_len) if len(ids_list) != len(set(ids_list)): report.add( "error", @@ -612,7 +715,6 @@ def check_dataset( f"problem_definitions/{pb_name}", f"Duplicated indices in {split_dict_name}", ) - split_len = len(datasetdict[split_name]) bad = [i for i in ids_list if i < 0 or i >= split_len] if bad: report.add( @@ -621,6 +723,31 @@ def check_dataset( f"problem_definitions/{pb_name}", f"Out-of-range indices in {split_dict_name} (first 10): {bad[:10]}", ) + continue + + if split_dict_name == "train_split": + features = list(pb_def.input_features) + list( + pb_def.output_features + ) + else: + features = list(pb_def.input_features) + + for idx in _progress( + ids_list, + desc=f"Checking problem {pb_name} {split_dict_name}", + show_progress=show_progress, + total=len(ids_list), + ): + _check_problem_definition_sample_features( + pb_name=pb_name, + split_dict_name=split_dict_name, + split_name=split_name, + idx=idx, + dataset=datasetdict[split_name], + converter=converterdict[split_name], + features=features, + report=report, + ) return report @@ -664,7 +791,11 @@ def main(argv: Optional[list[str]] = None) -> int: parser = _build_parser() args = parser.parse_args(argv) - report = check_dataset(path=args.path, splits=args.split) + report = check_dataset( + path=args.path, + splits=args.split, + show_progress=not args.json, + ) if args.json: print(report.to_json()) diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index a66f391d..8260f14e 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -234,12 +234,14 @@ def __init__( global_names: list[str] | None = None, tree: Any = None, checksum: str = "same", + features: dict[str, Any] | None = None, ) -> None: self._global_value = global_value self._field_value = field_value self._global_names = ["G"] if global_names is None else global_names self._tree = tree self._checksum = checksum + self._features = {} if features is None else features def get_zone_names(self, base: str, time: float) -> list[str]: # noqa: ARG002 """Return deterministic zone names for checker traversal. @@ -266,6 +268,8 @@ def get_feature_by_path(self, path: str) -> Any: # noqa: ARG002 Returns: Global value payload. """ + if path in self._features: + return self._features[path] return self._global_value def get_tree(self): @@ -335,8 +339,14 @@ def __init__( ) -> None: self._samples = samples self._fail_indices = set() if fail_indices is None else fail_indices + self.feature_requests: list[list[str] | None] = [] - def to_plaid(self, dataset: _FakeDataset, idx: int) -> Any: # noqa: ARG002 + def to_plaid( + self, + dataset: _FakeDataset, # noqa: ARG002 + idx: int, + features: list[str] | None = None, + ) -> Any: """Return fake sample or raise conversion error. Args: @@ -346,6 +356,7 @@ def to_plaid(self, dataset: _FakeDataset, idx: int) -> Any: # noqa: ARG002 Returns: Fake sample instance. """ + self.feature_requests.append(features) if idx in self._fail_indices: raise RuntimeError("boom") return self._samples[idx] @@ -624,6 +635,82 @@ def __init__(self, train_split, test_split): assert any(msg.code == "PB_DEF_OUT_OF_RANGE_INDICES" for msg in report_pb.messages) +def test_check_dataset_problem_definition_instantiates_filtered_features( + tmp_path: Path, monkeypatch +) -> None: + """Problem definitions should instantiate exact train/test feature subsets.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: { # noqa: ARG005 + "storage_backend": "zarr", + "num_samples": {"train": 1, "test": 1}, + }, + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ( # noqa: ARG005 + {"train": {}, "test": {}}, + {"Input": {}, "Output": {}}, + {"train": {}, "test": {}}, + None, + ), + ) + + train_converter = _FakeConverter( + [ + _FakeSampleForCheck( + features={"Input": np.array([1.0]), "Output": np.array([np.nan])} + ) + ] + ) + test_converter = _FakeConverter( + [ + _FakeSampleForCheck( + features={"Input": np.array([2.0]), "Output": np.array([np.nan])} + ) + ] + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ( # noqa: ARG005 + {"train": _FakeDataset(1), "test": _FakeDataset(1)}, + {"train": train_converter, "test": test_converter}, + ), + ) + + class _PBDef: + input_features = ["Input"] + output_features = ["Output"] + train_split = {"train": [0]} + test_split = {"test": [0]} + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: {"pb": _PBDef()}, # noqa: ARG005 + ) + + report = check_dataset(dataset, show_progress=False) + + assert train_converter.feature_requests[-1] == ["Input", "Output"] + assert test_converter.feature_requests[-1] == ["Input"] + invalid = [ + msg for msg in report.messages if msg.code == "PB_DEF_INVALID_FEATURE_VALUE" + ] + assert any( + "train_split" in msg.location and "Output" in msg.location for msg in invalid + ) + assert not any( + "test_split" in msg.location and "Output" in msg.location for msg in invalid + ) + + def test_check_dataset_problem_definition_read_error_names_yaml_file( tmp_path: Path, monkeypatch ) -> None: @@ -669,8 +756,31 @@ def test_main_strict_returns_warning_exit_code( report = CheckReport(messages=[]) report.add("warning", "W", "loc", "msg") - monkeypatch.setattr(plaidcheck, "check_dataset", lambda path, splits=None: report) # noqa: ARG005 + monkeypatch.setattr( + plaidcheck, + "check_dataset", + lambda path, splits=None, show_progress=True: report, # noqa: ARG005 + ) code = main([str(tmp_path), "--strict"]) _ = capsys.readouterr().out assert code == 2 + + +def test_main_json_disables_progress(monkeypatch, tmp_path: Path, capsys) -> None: + """JSON mode should disable progress bars in check_dataset.""" + seen: dict[str, bool] = {} + report = CheckReport(messages=[]) + + def _fake_check_dataset(path, splits=None, show_progress=True): # noqa: ARG001 + seen["show_progress"] = show_progress + return report + + monkeypatch.setattr(plaidcheck, "check_dataset", _fake_check_dataset) + + code = main([str(tmp_path), "--json"]) + payload = json.loads(capsys.readouterr().out) + + assert code == 0 + assert payload["counts"] == {"error": 0, "warning": 0, "info": 0} + assert seen["show_progress"] is False From ccde36be350f4e83cf79696d6848b21d54b7c4c4 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 4 Jun 2026 21:58:50 +0200 Subject: [PATCH 03/11] wip --- docs/source/concepts/check.md | 62 +++++++- docs/source/concepts/infos.md | 14 +- docs/source/concepts/problem_definition.md | 19 ++- docs/source/tutorials/storage.md | 2 +- examples/infos_example.py | 13 +- examples/problem_definition_example.py | 6 +- src/plaid/cli/plaidcheck.py | 32 +++- src/plaid/infos.py | 44 +----- src/plaid/problem_definition.py | 61 +------- src/plaid/storage/cgns/reader.py | 2 +- src/plaid/storage/cgns/writer.py | 2 +- src/plaid/storage/hf_datasets/reader.py | 2 +- src/plaid/storage/hf_datasets/writer.py | 2 +- src/plaid/storage/reader.py | 10 +- src/plaid/storage/writer.py | 6 +- src/plaid/storage/zarr/reader.py | 2 +- src/plaid/storage/zarr/writer.py | 2 +- .../viewer/services/plaid_dataset_service.py | 5 +- tests/cli/test_plaidcheck.py | 148 ++++++++++++++++-- tests/conftest.py | 2 +- tests/storage/test_cgns_init.py | 4 +- tests/storage/test_hf_datasets_init.py | 2 +- tests/storage/test_storage.py | 4 +- tests/storage/test_zarr_init.py | 2 +- tests/test_problem_definition.py | 34 ++-- tests/utils/test_info.py | 66 ++------ 26 files changed, 304 insertions(+), 244 deletions(-) diff --git a/docs/source/concepts/check.md b/docs/source/concepts/check.md index f387297d..d88c9514 100644 --- a/docs/source/concepts/check.md +++ b/docs/source/concepts/check.md @@ -1,3 +1,63 @@ +--- +title: Dataset check +--- + # Dataset check -`plaid-check` is a \ No newline at end of file +`plaid-check` validates the integrity of a local PLAID dataset. + +It checks: + +- required on-disk files and directories; +- `infos.yaml`, metadata and split sample counts; +- sample conversion through the declared storage backend; +- invalid numeric values such as `None`, empty arrays, `NaN` and `Inf`; +- duplicated samples; +- optional `problem_definitions/` feature names, splits and indices. + +## Basic usage + +```bash +plaid-check /path/to/plaid_dataset +``` + +A valid dataset prints an `[OK]` line and returns exit code `0`. + +## Options + +Check only selected splits: + +```bash +plaid-check /path/to/plaid_dataset --split train --split test +``` + +Check only selected problem definitions: + +```bash +plaid-check /path/to/plaid_dataset --problem-definition regression_500 +``` + +Emit a machine-readable report: + +```bash +plaid-check /path/to/plaid_dataset --json +``` + +Make warnings fail the command: + +```bash +plaid-check /path/to/plaid_dataset --strict +``` + +## Report format + +Messages are reported with a severity, a stable code, a location and a short +description. Errors return exit code `1`; warnings return exit code `2` only in +strict mode. + +## Notes + +- For CGNS datasets, only `infos.yaml` and `data/` are required at the root. +- For other backends, metadata files and `constants/` are checked as well. +- Without `--problem-definition`, all discovered problem definitions are checked. +- In JSON mode, progress bars are disabled to keep the output parseable. \ No newline at end of file diff --git a/docs/source/concepts/infos.md b/docs/source/concepts/infos.md index 5a0889b3..be1550f6 100644 --- a/docs/source/concepts/infos.md +++ b/docs/source/concepts/infos.md @@ -38,7 +38,7 @@ infos = Infos( Infos can also be built from a plain mapping, for instance after reading YAML: ```python -infos = Infos.from_mapping( +infos = Infos.model_validate( { "legal": { "owner": "Safran", @@ -72,15 +72,15 @@ infos.save_to_file("/path/to/plaid_dataset/infos.yaml") If a directory path is provided, the file is saved as `infos.yaml` inside that directory. -## Mapping-style access +## Typed access and serialization -`Infos` provides read-only mapping-style helpers for compatibility with code -expecting a YAML-like dictionary: +`Infos` is a Pydantic model. Access metadata through typed attributes and use +Pydantic serialization when a plain mapping is needed: ```python -owner = infos["legal"]["owner"] -backend = infos.get("storage_backend") -payload = infos.to_dict() +owner = infos.legal.owner +backend = infos.storage_backend +payload = infos.model_dump(exclude_none=True) ``` ## Notes diff --git a/docs/source/concepts/problem_definition.md b/docs/source/concepts/problem_definition.md index ee6bb01c..c462bdf7 100644 --- a/docs/source/concepts/problem_definition.md +++ b/docs/source/concepts/problem_definition.md @@ -35,6 +35,21 @@ pb = ProblemDefinition( Feature lists are normalized by the model: entries are converted to strings, sorted, checked for duplicates, and rejected if empty. +Problem definitions can also be validated from a plain mapping, for instance +after reading YAML: + +```python +pb = ProblemDefinition.model_validate( + { + "name": "regression_1", + "input_features": ["Base/Zone/GridCoordinates/CoordinateX"], + "output_features": ["Base/Zone/VertexFields/pressure"], + "train_split": {"train": [0, 1, 2]}, + "test_split": {"test": [3, 4]}, + } +) +``` + ## Loading from disk Load a definition from a dataset path: @@ -64,8 +79,8 @@ pb.save_to_file("problem_definitions/regression_1.yaml") ## Notes -- Input/output features are plain strings correspond to CGNS paths. -- Splits are represented by `train_split` and `test_split` dictionaries. +- Input/output features are plain strings corresponding to CGNS paths. +- Splits are represented by `train_split` and `test_split` dictionaries and are accessed directly as model attributes. - Split values can be explicit index sequences or the string `"all"`. - `add_input_features(...)` and `add_output_features(...)` accept either a single string or a sequence of strings after initialization. \ No newline at end of file diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index e826c956..b95011c5 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -96,7 +96,7 @@ curated_test_ids = curated_test_ids[:10] #--------------------------------------------------------------- # infos and problem definition must be define to correctly populate the dataset's metadata -infos = Infos.from_mapping( +infos = Infos.model_validate( { "legal": { "owner": "NeuralOperator (https://zenodo.org/records/13993629)", diff --git a/examples/infos_example.py b/examples/infos_example.py index cd7ed37b..dada6711 100644 --- a/examples/infos_example.py +++ b/examples/infos_example.py @@ -24,7 +24,7 @@ # 3. Saving and loading infos # # This notebook provides examples of using the Infos class to define dataset -# metadata, access entries with mapping-style helpers, and save/load infos. +# metadata, access entries with typed attributes, and save/load infos. # # **Each section is documented and explained.** @@ -59,7 +59,7 @@ # ### Initialize Infos from a plain mapping # %% -infos_from_mapping = Infos.from_mapping( +infos_from_mapping = Infos.model_validate( { "legal": { "owner": "PLAID", @@ -113,12 +113,13 @@ print(f"{infos.storage_backend = }") # %% [markdown] -# ### Retrieve data with mapping-style helpers +# ### Retrieve data with Pydantic attributes # %% -print(f"{infos['legal'] = }") -print(f"{infos.get('storage_backend') = }") -print(f"{infos.to_dict() = }") +print(f"{infos.legal = }") +print(f"{infos.legal.owner = }") +print(f"{infos.storage_backend = }") +print(f"{infos.model_dump(exclude_none=True) = }") # %% [markdown] # ## Section 3: Saving and Loading Infos diff --git a/examples/problem_definition_example.py b/examples/problem_definition_example.py index 0e1a4d7b..4b01641f 100644 --- a/examples/problem_definition_example.py +++ b/examples/problem_definition_example.py @@ -101,10 +101,10 @@ print(f"{problem.test_split = }") # %% -split_names = [problem.get_train_split_name(), problem.get_test_split_name()] +split_names = [next(iter(problem.train_split)), next(iter(problem.test_split))] split_indices = { - problem.get_train_split_name(): problem.get_train_split_indices(), - problem.get_test_split_name(): problem.get_test_split_indices(), + next(iter(problem.train_split)): next(iter(problem.train_split.values())), + next(iter(problem.test_split)): next(iter(problem.test_split.values())), } print(f"{split_names = }") print(f"{split_indices = }") diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index d3c23740..41b6a312 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -394,6 +394,7 @@ def check_dataset( path: Path, splits: Optional[list[str]] = None, show_progress: bool = True, + problem_definitions: Optional[list[str]] = None, ) -> CheckReport: """Run integrity checks on a local PLAID dataset. @@ -415,6 +416,8 @@ def check_dataset( path: Dataset directory. splits: Optional selected split names. show_progress: Whether to display tqdm progress bars for expensive checks. + problem_definitions: Optional selected problem-definition names. When + omitted, all discovered problem definitions are checked. Returns: A populated :class:`CheckReport`. @@ -436,7 +439,7 @@ def check_dataset( report.add("error", "INFOS_READ_ERROR", "infos.yaml", str(exc)) return report - declared_backend_for_layout = infos.get("storage_backend") + declared_backend_for_layout = infos.storage_backend if not isinstance(declared_backend_for_layout, str): declared_backend_for_layout = None @@ -450,7 +453,7 @@ def check_dataset( # Validate top-level dataset declarations from infos.yaml before calling # init_from_disk(), because storage initialization indexes num_samples by # split and otherwise reports missing entries as opaque KeyError messages. - declared_backend = infos.get("storage_backend") + declared_backend = infos.storage_backend if not isinstance(declared_backend, str): report.add( "error", @@ -459,7 +462,7 @@ def check_dataset( "Missing or invalid 'storage_backend' in infos.yaml", ) - num_samples = infos.get("num_samples", {}) + num_samples = infos.num_samples if not isinstance(num_samples, dict): report.add( "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict" @@ -663,7 +666,23 @@ def check_dataset( # still run). validate_pb_def_features = declared_backend_for_layout != "cgns" + target_pb_names = ( + set(problem_definitions) if problem_definitions else set(pb_defs) + ) + unknown_pb_names = target_pb_names - set(pb_defs) + for pb_name in sorted(unknown_pb_names): + available = " and ".join(f'"{x}"' for x in sorted(pb_defs)) + report.add( + "error", + "PB_DEF_UNKNOWN", + f"problem_definitions/{pb_name}", + f"Problem definition not found, available are {available}", + ) + target_pb_names = target_pb_names & set(pb_defs) + for pb_name, pb_def in pb_defs.items(): + if pb_name not in target_pb_names: + continue if validate_pb_def_features: for feat in pb_def.input_features: if feat not in all_features: @@ -776,6 +795,12 @@ def _build_parser() -> argparse.ArgumentParser: action="store_true", help="Treat warnings as failure", ) + parser.add_argument( + "--problem-definition", + action="append", + default=None, + help="Problem definition to check (can be provided multiple times)", + ) return parser @@ -795,6 +820,7 @@ def main(argv: Optional[list[str]] = None) -> int: path=args.path, splits=args.split, show_progress=not args.json, + problem_definitions=args.problem_definition, ) if args.json: diff --git a/src/plaid/infos.py b/src/plaid/infos.py index 54ca32b2..769f1c98 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -112,11 +112,6 @@ def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: # Disk I/O # ------------------------------------------------------------------ - @classmethod - def from_mapping(cls, infos: dict[str, Any]) -> "Infos": - """Build a validated :class:`Infos` from a plain mapping.""" - return cls.model_validate(infos) - @classmethod def from_path(cls, path: Union[str, Path]) -> "Infos": """Load and validate an :class:`Infos` from a YAML file. @@ -140,44 +135,7 @@ def from_path(cls, path: Union[str, Path]) -> "Infos": with path.open("r", encoding="utf-8") as file: data = yaml.safe_load(file) or {} - return cls.from_mapping(data) - - # ------------------------------------------------------------------ - # Mapping-like accessors (read-only convenience) - # ------------------------------------------------------------------ - - def to_dict(self) -> dict[str, Any]: - """Return a plain ``dict`` representation of these infos.""" - return self.model_dump(exclude_none=True) - - def __getitem__(self, key: str) -> Any: - """Return the value associated with ``key`` using mapping-style access.""" - if not hasattr(self, key): - raise KeyError(key) - value = getattr(self, key) - # Unwrap nested dataclasses to plain dicts when accessed by key, so that - # callers expecting a YAML-like mapping continue to work transparently. - if hasattr(value, "__pydantic_fields__"): - return { - f: getattr(value, f) - for f in value.__pydantic_fields__ - if getattr(value, f) is not None - } - return value - - def __contains__(self, key: object) -> bool: - """Return whether ``key`` is a known field with a non-``None`` value.""" - if not isinstance(key, str): - return False - if not hasattr(self, key): - return False - return getattr(self, key) is not None - - def get(self, key: str, default: Any = None) -> Any: - """Return ``self[key]`` when present, otherwise ``default``.""" - if key in self: - return self[key] - return default + return cls.model_validate(data) def save_to_file(self, path: Union[str, Path]) -> None: """Save infos to ``path`` as a YAML file. diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index ff68a766..cfeba76e 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -12,13 +12,11 @@ import logging from pathlib import Path -from typing import Any, Literal, Sequence, Union, cast +from typing import Any, Literal, Sequence, Union import yaml from pydantic import BaseModel, ConfigDict, field_validator -from .types import IndexArrayType - # %% Globals logger = logging.getLogger(__name__) @@ -41,18 +39,6 @@ class ProblemDefinition(BaseModel): train_split: dict[str, Sequence[int] | Literal["all"]] test_split: dict[str, Sequence[int] | Literal["all"]] - @classmethod - def from_mapping(cls, data: dict[str, Any]) -> "ProblemDefinition": - """Build a validated :class:`ProblemDefinition` from a plain mapping. - - Args: - data: YAML-like mapping containing one problem definition. - - Returns: - Validated problem definition instance. - """ - return cls.model_validate(data) - @classmethod def from_path(cls, path: str | Path) -> "ProblemDefinition": """Load and validate one problem definition from a YAML file. @@ -76,7 +62,7 @@ def from_path(cls, path: str | Path) -> "ProblemDefinition": with path.open("r", encoding="utf-8") as file: data = yaml.safe_load(file) or {} - return cls.from_mapping(data) + return cls.model_validate(data) @field_validator("input_features", mode="before") @classmethod @@ -121,49 +107,6 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) - # # -------------------------------------------------------------------------# - def get_train_split_name(self) -> str: - """Return the name of the train split.""" - if self.train_split is None: - raise ValueError("train_split is not defined.") - return list(self.train_split.keys())[0] - - def get_train_split_indices(self) -> IndexArrayType | Literal["all"]: - """Return the indices associated with the train split. - - Raises: - ValueError: If `train_split` is not defined. - - Returns: - IndexArrayType | Literal["all"]: The indices associated with the train split. - """ - if self.train_split is None: - raise ValueError("train_split is not defined.") - return cast( - IndexArrayType | Literal["all"], next(iter(self.train_split.values())) - ) - - def get_test_split_name(self) -> str: - """Return the name of the test split.""" - if self.test_split is None: - raise ValueError("test_split is not defined.") - return list(self.test_split.keys())[0] - - def get_test_split_indices(self) -> IndexArrayType | Literal["all"]: - """Return the indices associated with the test split. - - Raises: - ValueError: If `test_split` is not defined. - - Returns: - IndexArrayType | Literal["all"]: The indices associated with the test split. - """ - if self.test_split is None: - raise ValueError("test_split is not defined.") - return cast( - IndexArrayType | Literal["all"], next(iter(self.test_split.values())) - ) - def add_input_features(self, inputs: Union[str, Sequence[str]]) -> None: """Add input features identifiers to the problem. diff --git a/src/plaid/storage/cgns/reader.py b/src/plaid/storage/cgns/reader.py index 4241a055..13ab5481 100644 --- a/src/plaid/storage/cgns/reader.py +++ b/src/plaid/storage/cgns/reader.py @@ -264,7 +264,7 @@ def init_datasetdict_streaming_from_hub( else: infos = load_infos_from_hub(repo_id=repo_id) selected_ids = { - split: range(n_samples) for split, n_samples in infos["num_samples"].items() + split: range(n_samples) for split, n_samples in infos.num_samples.items() } return { diff --git a/src/plaid/storage/cgns/writer.py b/src/plaid/storage/cgns/writer.py index 656b584c..a1095813 100644 --- a/src/plaid/storage/cgns/writer.py +++ b/src/plaid/storage/cgns/writer.py @@ -185,7 +185,7 @@ def configure_dataset_card( None: This function does not return a value; it pushes the dataset card directly to Hugging Face Hub. """ - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) dataset_card_str = """--- task_categories: - graph-ml diff --git a/src/plaid/storage/hf_datasets/reader.py b/src/plaid/storage/hf_datasets/reader.py index 956ae24c..6650dc87 100644 --- a/src/plaid/storage/hf_datasets/reader.py +++ b/src/plaid/storage/hf_datasets/reader.py @@ -94,7 +94,7 @@ def download_datasetdict_from_hub( local_dir=tmp_dir, ) infos = load_infos_from_hub(repo_id=repo_id) - split_names = list(infos["num_samples"].keys()) + split_names = list(infos.num_samples.keys()) base = Path(tmp_dir) / "data" data_files = {sn: str(base / f"{sn}*.parquet") for sn in split_names} datasetdict = load_dataset("parquet", data_files=data_files, cache_dir=tmp_dir) diff --git a/src/plaid/storage/hf_datasets/writer.py b/src/plaid/storage/hf_datasets/writer.py index e39295b1..049d931d 100644 --- a/src/plaid/storage/hf_datasets/writer.py +++ b/src/plaid/storage/hf_datasets/writer.py @@ -221,7 +221,7 @@ def configure_dataset_card( None: This function does not return a value; it updates the dataset card directly on Hugging Face Hub. """ - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) readme_path = hf_hub_download( repo_id=repo_id, filename="README.md", repo_type="dataset" ) diff --git a/src/plaid/storage/reader.py b/src/plaid/storage/reader.py index 3ce5bb56..0b265142 100644 --- a/src/plaid/storage/reader.py +++ b/src/plaid/storage/reader.py @@ -244,8 +244,8 @@ def init_from_disk( """ infos = load_infos_from_disk(local_dir) - backend = infos["storage_backend"] - num_samples = infos["num_samples"] + backend = infos.storage_backend + num_samples = infos.num_samples datasetdict = get_backend(backend).init_from_disk(path=local_dir) @@ -299,7 +299,7 @@ def download_from_hub( infos = load_infos_from_hub(repo_id) pb_defs = load_problem_definitions_from_hub(repo_id) - backend = infos["storage_backend"] + backend = infos.storage_backend backend_spec = get_backend(backend) backend_spec.download_from_hub(repo_id, local_dir, split_ids, features, overwrite) @@ -340,8 +340,8 @@ def init_streaming_from_hub( ) infos = load_infos_from_hub(repo_id) - backend = infos["storage_backend"] - num_samples = infos["num_samples"] + backend = infos.storage_backend + num_samples = infos.num_samples backend_spec = get_backend(backend) datasetdict = backend_spec.init_datasetdict_streaming_from_hub( diff --git a/src/plaid/storage/writer.py b/src/plaid/storage/writer.py index 6f972395..de7f699b 100644 --- a/src/plaid/storage/writer.py +++ b/src/plaid/storage/writer.py @@ -243,10 +243,10 @@ def sample_constructor(file_path): num_samples={}, storage_backend=backend, ) - infos_data = infos.to_dict() + infos_data = infos.model_dump(exclude_none=True) infos_data["num_samples"] = num_samples infos_data["storage_backend"] = backend - infos = Infos.from_mapping(infos_data) + infos = Infos.model_validate(infos_data) save_infos_to_disk(output_folder, infos) @@ -297,7 +297,7 @@ def push_to_hub( """ infos = load_infos_from_disk(local_dir) - backend = infos["storage_backend"] + backend = infos.storage_backend backend_spec = get_backend(backend) backend_spec.push_local_to_hub(repo_id, local_dir, num_workers=num_workers) diff --git a/src/plaid/storage/zarr/reader.py b/src/plaid/storage/zarr/reader.py index 95bf6853..557840a1 100644 --- a/src/plaid/storage/zarr/reader.py +++ b/src/plaid/storage/zarr/reader.py @@ -321,7 +321,7 @@ def init_datasetdict_streaming_from_hub( else: infos = load_infos_from_hub(repo_id=repo_id) selected_ids = { - split: range(n_samples) for split, n_samples in infos["num_samples"].items() + split: range(n_samples) for split, n_samples in infos.num_samples.items() } return { diff --git a/src/plaid/storage/zarr/writer.py b/src/plaid/storage/zarr/writer.py index e50c10e7..e3323567 100644 --- a/src/plaid/storage/zarr/writer.py +++ b/src/plaid/storage/zarr/writer.py @@ -321,7 +321,7 @@ def configure_dataset_card( None: This function does not return a value; it pushes the dataset card directly to Hugging Face Hub. """ - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) dataset_card_str = """--- task_categories: - graph-ml diff --git a/src/plaid/viewer/services/plaid_dataset_service.py b/src/plaid/viewer/services/plaid_dataset_service.py index 1c291def..6c5d973f 100644 --- a/src/plaid/viewer/services/plaid_dataset_service.py +++ b/src/plaid/viewer/services/plaid_dataset_service.py @@ -455,15 +455,16 @@ def _load_feature_metadata(self, dataset_id: str) -> tuple[list[str], list[str]] # writes no derived constant/variable schema metadata. Detect it # from ``infos.yaml`` and return empty feature catalogues so the # viewer falls back to inspecting samples directly. + infos = None try: if self._is_hub_dataset(dataset_id): infos = load_infos_from_hub(dataset_id) else: infos = load_infos_from_disk(str(self._dataset_dir(dataset_id))) except Exception: # pragma: no cover - defensive - infos = {} + pass if ( - infos.get("storage_backend") == "cgns" + infos is not None and infos.storage_backend == "cgns" ): # pragma: no cover - exercised via integration tests metadata = ([], []) self._feature_metadata[dataset_id] = metadata diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index 8260f14e..1478c61d 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -17,11 +17,19 @@ check_dataset, main, ) -from plaid.infos import Infos +from plaid.infos import Infos, Legal _REFERENCE_DATASETS = ("dataset_cgns", "dataset_hf") +def _infos(num_samples: dict[str, int], storage_backend: str = "zarr") -> Infos: + return Infos( + legal=Legal(owner="owner", license="license"), + num_samples=num_samples, + storage_backend=storage_backend, + ) + + def _copy_reference_dataset(tmp_path: Path, name: str = "dataset_cgns") -> Path: """Copy a reference dataset (CGNS or HF) used by container tests. @@ -418,7 +426,7 @@ def test_check_dataset_loader_failures_and_header_validations( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + lambda path: _infos({"train": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -449,7 +457,11 @@ def test_check_dataset_loader_failures_and_header_validations( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": 12, "num_samples": "bad"}, # noqa: ARG005 + lambda path: Infos.model_construct( # noqa: ARG005 + legal=Legal(owner="owner", license="license"), + storage_backend=12, + num_samples="bad", + ), ) report_header = check_dataset(dataset) assert any(msg.code == "BACKEND_MISSING" for msg in report_header.messages) @@ -465,7 +477,7 @@ def test_check_dataset_split_and_data_warnings_and_duplicates( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 3}}, # noqa: ARG005 + lambda path: _infos({"train": 3}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, "load_metadata_from_disk", lambda _path: ({}, {"Var": {}}, {}, None) @@ -521,7 +533,7 @@ def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> N monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + lambda path: _infos({"train": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -552,7 +564,7 @@ def test_check_dataset_missing_num_samples_split_is_clear( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + lambda path: _infos({"train": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -577,7 +589,7 @@ def test_check_dataset_problem_definition_validation_paths( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 2}}, # noqa: ARG005 + lambda path: _infos({"train": 2}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -645,10 +657,7 @@ def test_check_dataset_problem_definition_instantiates_filtered_features( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: { # noqa: ARG005 - "storage_backend": "zarr", - "num_samples": {"train": 1, "test": 1}, - }, + lambda path: _infos({"train": 1, "test": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -711,6 +720,104 @@ class _PBDef: ) +def test_check_dataset_filters_problem_definitions(tmp_path: Path, monkeypatch) -> None: + """Selected problem definitions should be checked without checking others.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Input": {}, "Output": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ( # noqa: ARG005 + {"train": _FakeDataset(1)}, + {"train": _FakeConverter([_FakeSampleForCheck()])}, + ), + ) + + class _PBDef: + def __init__(self, input_features, output_features, train_split): + self.input_features = input_features + self.output_features = output_features + self.train_split = train_split + self.test_split = None + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: { # noqa: ARG005 + "selected": _PBDef(["Input"], ["Output"], {"train": [0]}), + "skipped": _PBDef(["UnknownInput"], ["UnknownOutput"], {"ghost": [0]}), + }, + ) + + report = check_dataset( + dataset, + show_progress=False, + problem_definitions=["selected"], + ) + + assert not any( + "problem_definitions/skipped" in msg.location for msg in report.messages + ) + assert not any(msg.code == "PB_DEF_UNKNOWN_INPUT" for msg in report.messages) + assert not any(msg.code == "PB_DEF_UNKNOWN_SPLIT" for msg in report.messages) + + +def test_check_dataset_reports_unknown_requested_problem_definition( + tmp_path: Path, monkeypatch +) -> None: + """Unknown requested problem definitions should be reported explicitly.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 0}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Input": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(0)}, {"train": _FakeConverter([])}), # noqa: ARG005 + ) + + class _PBDef: + input_features = ["Input"] + output_features = [] + train_split = {"train": []} + test_split = None + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: {"known": _PBDef()}, # noqa: ARG005 + ) + + report = check_dataset( + dataset, + show_progress=False, + problem_definitions=["ghost"], + ) + + assert any(msg.code == "PB_DEF_UNKNOWN" for msg in report.messages) + assert any("known" in msg.message for msg in report.messages) + + def test_check_dataset_problem_definition_read_error_names_yaml_file( tmp_path: Path, monkeypatch ) -> None: @@ -729,7 +836,7 @@ def test_check_dataset_problem_definition_read_error_names_yaml_file( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 0}}, # noqa: ARG005 + lambda path: _infos({"train": 0}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -759,7 +866,7 @@ def test_main_strict_returns_warning_exit_code( monkeypatch.setattr( plaidcheck, "check_dataset", - lambda path, splits=None, show_progress=True: report, # noqa: ARG005 + lambda path, splits=None, show_progress=True, problem_definitions=None: report, # noqa: ARG005 ) code = main([str(tmp_path), "--strict"]) _ = capsys.readouterr().out @@ -768,19 +875,26 @@ def test_main_strict_returns_warning_exit_code( def test_main_json_disables_progress(monkeypatch, tmp_path: Path, capsys) -> None: - """JSON mode should disable progress bars in check_dataset.""" - seen: dict[str, bool] = {} + """JSON mode should disable progress bars and forward CLI filters.""" + seen: dict[str, Any] = {} report = CheckReport(messages=[]) - def _fake_check_dataset(path, splits=None, show_progress=True): # noqa: ARG001 + def _fake_check_dataset( + path, # noqa: ARG001 + splits=None, # noqa: ARG001 + show_progress=True, + problem_definitions=None, + ): seen["show_progress"] = show_progress + seen["problem_definitions"] = problem_definitions return report monkeypatch.setattr(plaidcheck, "check_dataset", _fake_check_dataset) - code = main([str(tmp_path), "--json"]) + code = main([str(tmp_path), "--json", "--problem-definition", "pb"]) payload = json.loads(capsys.readouterr().out) assert code == 0 assert payload["counts"] == {"error": 0, "warning": 0, "info": 0} assert seen["show_progress"] is False + assert seen["problem_definitions"] == ["pb"] diff --git a/tests/conftest.py b/tests/conftest.py index 88a3355d..fb9b2a94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,7 +91,7 @@ def other_samples(nb_samples: int, zone_name: str, base_name: str) -> list[Sampl @pytest.fixture() def infos(): - return Infos.from_mapping( + return Infos.model_validate( { "legal": {"owner": "PLAID2", "license": "BSD-3"}, "data_production": {"type": "simulation", "simulator": "Z-set"}, diff --git a/tests/storage/test_cgns_init.py b/tests/storage/test_cgns_init.py index a3fc99e4..d66746f8 100644 --- a/tests/storage/test_cgns_init.py +++ b/tests/storage/test_cgns_init.py @@ -169,7 +169,7 @@ def test_cgns_backend_configure_dataset_card_requires_local_dir(): with pytest.raises(ValueError, match="local_dir must be provided for cgns backend"): CgnsBackend.configure_dataset_card( repo_id="dummy/repo", - infos=Infos.from_mapping( + infos=Infos.model_validate( { "legal": {"owner": "owner", "license": "cc-by-4.0"}, "num_samples": {}, @@ -188,7 +188,7 @@ def fake_configure_dataset_card(**kwargs): monkeypatch.setattr(cgns, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping( + infos = Infos.model_validate( { "legal": {"owner": "owner", "license": "cc-by-4.0"}, "num_samples": {}, diff --git a/tests/storage/test_hf_datasets_init.py b/tests/storage/test_hf_datasets_init.py index 3771e74e..5dd2c816 100644 --- a/tests/storage/test_hf_datasets_init.py +++ b/tests/storage/test_hf_datasets_init.py @@ -210,7 +210,7 @@ def fake_configure_dataset_card( hf_datasets, "configure_dataset_card", fake_configure_dataset_card ) - infos = Infos.from_mapping( + infos = Infos.model_validate( { "legal": {"owner": "owner", "license": "cc-by-4.0"}, "num_samples": {}, diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index 8495c9e6..ca9fce37 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -185,8 +185,8 @@ def _sample_constructor(id): @pytest.fixture() def split_ids(problem_definition) -> dict: return { - "train": problem_definition.get_train_split_indices(), - "test": problem_definition.get_test_split_indices(), + "train": problem_definition.train_split["train"], + "test": problem_definition.test_split["test"], } diff --git a/tests/storage/test_zarr_init.py b/tests/storage/test_zarr_init.py index 220743be..d94d9436 100644 --- a/tests/storage/test_zarr_init.py +++ b/tests/storage/test_zarr_init.py @@ -176,7 +176,7 @@ def fake_configure_dataset_card( monkeypatch.setattr(zarr, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping( + infos = Infos.model_validate( { "legal": {"owner": "owner", "license": "cc-by-4.0"}, "num_samples": {}, diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index b462421a..2bd7ebf7 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -109,7 +109,7 @@ def test_feature_lists_must_not_be_empty(self): # -------------------------------------------------------------------------# def test_from_mapping_validates_and_normalizes(self): - loaded = ProblemDefinition.from_mapping( + loaded = ProblemDefinition.model_validate( { "name": "pb_single", "input_features": ["in_b", "in_a"], @@ -122,10 +122,8 @@ def test_from_mapping_validates_and_normalizes(self): assert loaded.name == "pb_single" assert loaded.input_features == ["in_a", "in_b"] assert loaded.output_features == ["out_a", "out_b"] - assert loaded.get_train_split_name() == "train_0" - assert loaded.get_test_split_name() == "test_0" - assert loaded.get_train_split_indices() == [0, 1] - assert loaded.get_test_split_indices() == [2] + assert loaded.train_split == {"train_0": [0, 1]} + assert loaded.test_split == {"test_0": [2]} def test_from_path_loads_single_yaml_file(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" @@ -145,8 +143,8 @@ def test_from_path_loads_single_yaml_file(self, tmp_path: Path): loaded = ProblemDefinition.from_path(file_path) assert loaded.name == "pb" - assert loaded.get_train_split_name() == "train" - assert loaded.get_test_split_name() == "test" + assert loaded.train_split == {"train": [0]} + assert loaded.test_split == {"test": [1]} def test_from_path_adds_yaml_suffix(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" @@ -221,24 +219,12 @@ def test_split_replacement_logs_warning(self, problem_definition, caplog): assert "already exists -> data will be replaced" in caplog.text - def test_get_split_paths(self, problem_definition): + def test_split_fields_are_plain_dictionaries(self, problem_definition): problem_definition.train_split = {"train_0": [0, 1, 2]} problem_definition.test_split = {"test_0": [3, 4]} - assert problem_definition.get_train_split_name() == "train_0" - assert problem_definition.get_test_split_name() == "test_0" - assert problem_definition.get_train_split_indices() == [0, 1, 2] - assert problem_definition.get_test_split_indices() == [3, 4] - - def test_get_split_paths_raise_when_not_defined(self, problem_definition): - with pytest.raises(ValueError, match="train_split is not defined"): - problem_definition.get_train_split_name() - with pytest.raises(ValueError, match="train_split is not defined"): - problem_definition.get_train_split_indices() - with pytest.raises(ValueError, match="test_split is not defined"): - problem_definition.get_test_split_name() - with pytest.raises(ValueError, match="test_split is not defined"): - problem_definition.get_test_split_indices() + assert problem_definition.train_split == {"train_0": [0, 1, 2]} + assert problem_definition.test_split == {"test_0": [3, 4]} def test_add_feature_identifiers_duplicate_checks(self, problem_definition): problem_definition.add_input_features(["in_1", "in_2"]) @@ -278,5 +264,5 @@ def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): loaded = ProblemDefinition.from_path(file_path) assert loaded.name == "pb" - assert loaded.get_train_split_name() == "train" - assert loaded.get_test_split_name() == "test" + assert loaded.train_split == {"train": [0]} + assert loaded.test_split == {"test": [1]} diff --git a/tests/utils/test_info.py b/tests/utils/test_info.py index e356fc61..bfe9df20 100644 --- a/tests/utils/test_info.py +++ b/tests/utils/test_info.py @@ -80,7 +80,7 @@ def test_infos_save_and_load_roundtrip(tmp_path): "num_samples": {"train": 10}, "storage_backend": "zarr", } - model = Infos.from_mapping(infos) + model = Infos.model_validate(infos) target = tmp_path / "infos.yaml" model.save_to_file(target) @@ -98,7 +98,7 @@ def test_infos_from_path_directory(tmp_path): "num_samples": {"train": 10}, "storage_backend": "zarr", } - Infos.from_mapping(infos).save_to_file(tmp_path / "infos.yaml") + Infos.model_validate(infos).save_to_file(tmp_path / "infos.yaml") reloaded = Infos.from_path(tmp_path) assert reloaded.legal.license == "cc-by-4.0" @@ -150,74 +150,30 @@ def test_validate_required_only_missing_legal(): Infos.validate_required_only({}) -def test_to_dict_returns_plain_mapping(): - model = Infos.from_mapping( +def test_model_dump_returns_plain_mapping(): + model = Infos.model_validate( { "legal": {"owner": "o", "license": "l"}, "num_samples": {}, "storage_backend": "zarr", } ) - d = model.to_dict() + d = model.model_dump(exclude_none=True) assert d["legal"] == {"owner": "o", "license": "l"} assert d["storage_backend"] == "zarr" -def test_getitem_returns_plain_value_and_unwraps_nested_dataclasses(): - model = Infos.from_mapping( - { - "legal": {"owner": "o", "license": "l"}, - "num_samples": {}, - "storage_backend": "zarr", - } - ) - # Plain field. - assert model["storage_backend"] == "zarr" - # Nested dataclass is returned as a dict, with None entries dropped. - legal_dict = model["legal"] - assert legal_dict == {"owner": "o", "license": "l"} - - -def test_getitem_raises_key_error_for_unknown_field(): - model = Infos.from_mapping( - { - "legal": {"owner": "o", "license": "l"}, - "num_samples": {}, - "storage_backend": "zarr", - } - ) - with pytest.raises(KeyError): - model["does_not_exist"] - - -def test_contains_handles_strings_and_non_strings(): - model = Infos.from_mapping( +def test_attribute_access_returns_typed_values(): + model = Infos.model_validate( { "legal": {"owner": "o", "license": "l"}, "num_samples": {}, "storage_backend": "zarr", } ) - assert "legal" in model - assert "storage_backend" in model - # Field exists but is None -> not "in" the model. - assert "data_production" not in model - # Unknown attribute name. - assert "nope" not in model - # Non-string keys are never in the model. - assert (123 in model) is False - - -def test_get_returns_default_when_missing(): - model = Infos.from_mapping( - { - "legal": {"owner": "o", "license": "l"}, - "num_samples": {}, - "storage_backend": "zarr", - } - ) - assert model.get("data_description", "fallback") == "fallback" - assert model.get("legal")["owner"] == "o" + assert model.storage_backend == "zarr" + assert model.legal.owner == "o" + assert model.legal.license == "l" def test_save_to_file_treats_suffixless_path_as_directory(tmp_path): @@ -263,7 +219,7 @@ def test_save_to_file_preserves_unknown_future_keys(tmp_path, monkeypatch): from plaid import infos as infos_mod monkeypatch.setattr(infos_mod, "_KEY_ORDER", ()) - model = Infos.from_mapping( + model = Infos.model_validate( { "legal": {"owner": "o", "license": "l"}, "num_samples": {}, From 8ef24f96ac202779aef39eb59129d33c4f731a61 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 4 Jun 2026 22:10:48 +0200 Subject: [PATCH 04/11] wip --- docs/source/examples_tutorials.md | 2 +- docs/source/tutorials/downloadable_example.md | 16 ++++++++++++++++ docs/zensical.toml | 6 +++--- 3 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 docs/source/tutorials/downloadable_example.md diff --git a/docs/source/examples_tutorials.md b/docs/source/examples_tutorials.md index cfcb048a..095e2048 100644 --- a/docs/source/examples_tutorials.md +++ b/docs/source/examples_tutorials.md @@ -12,8 +12,8 @@ You can find here detailed examples for different parts of plaid, explained in J * [Sample](notebooks/containers/sample_example.md) * [Problem definition](notebooks/problem_definition_example.md) * [Infos](notebooks/infos_example.md) -* [Downloadable samples](notebooks/downloadable_example/sample_example.md) ## Tutorials +* [Downloadable samples](tutorials/downloadable_example.md) * [Conversion tutorial](tutorials/storage.md) diff --git a/docs/source/tutorials/downloadable_example.md b/docs/source/tutorials/downloadable_example.md new file mode 100644 index 00000000..8adb9e2a --- /dev/null +++ b/docs/source/tutorials/downloadable_example.md @@ -0,0 +1,16 @@ +--- +title: Downloadable samples +--- + +# Downloadable samples + +Retrieving sample examples is as easy as: + +```python +from plaid.downloadable_examples import AVAILABLE_EXAMPLES, samples + +print(AVAILABLE_EXAMPLES) +print("samples.vki_ls59:", samples.vki_ls59) +``` + +The first call to `samples.vki_ls59` triggers a download and takes a few seconds, whereas subsequent calls are instantaneous because they reuse the cached sample. diff --git a/docs/zensical.toml b/docs/zensical.toml index 87591ece..c4ff7ef7 100644 --- a/docs/zensical.toml +++ b/docs/zensical.toml @@ -52,11 +52,11 @@ nav = [ { "Sample" = "notebooks/containers/sample_example.md" }, { "Problem definition" = "notebooks/problem_definition_example.md" }, { "Infos" = "notebooks/infos_example.md" }, - { "Downloadable samples" = "notebooks/downloadable_example/sample_example.md" }, + { "Downloadable samples" = "tutorials/downloadable_example.md" }, { "Conversion tutorial" = "tutorials/storage.md" }, ] }, - { "`plaid-viewer`" = "concepts/viewer.md" }, - { "`plaid-check`" = "concepts/check.md" }, + { "plaid-viewer" = "concepts/viewer.md" }, + { "plaid-check" = "concepts/check.md" }, # >>> AUTO-GENERATED API REFERENCE START # The block below is overwritten by `python docs/generate_api_stubs.py`. # Edit that script (or the markers) instead of changing this section by hand. From 7a226615e7ff1b63b6d51651eea4a7bd2627b3fb Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 4 Jun 2026 22:19:50 +0200 Subject: [PATCH 05/11] wip --- src/plaid/containers/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plaid/containers/sample.py b/src/plaid/containers/sample.py index 2ede2cc8..2a707e52 100644 --- a/src/plaid/containers/sample.py +++ b/src/plaid/containers/sample.py @@ -1211,7 +1211,7 @@ def get_zone_type( if zone_node is None: raise KeyError( - f"there is no base/zone <{base}/{zone}>, you should first create one with `Sample.init_zone({zone=},{base=})`" + f"There is no base/zone <{base}/{zone}>, you should first create one with `Sample.init_zone({zone=},{base=})`." ) return CGU.getValueByPath(zone_node, "ZoneType").tobytes().decode() From f817f5412ef5db21aeedea1e36e49f9d7da8e4f4 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 06:28:34 +0200 Subject: [PATCH 06/11] wip --- tests/cli/test_plaidcheck.py | 87 ++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index 1478c61d..d5dcb6a3 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -12,6 +12,7 @@ from plaid.cli.plaidcheck import ( CheckReport, _check_numeric_content, + _check_problem_definition_sample_features, _is_branch_without_data, _is_branch_without_data_in_mapping, check_dataset, @@ -370,6 +371,30 @@ def to_plaid( return self._samples[idx] +class _FakeSampleWithFeatureFailure: + """Sample-like object that fails for selected feature paths.""" + + def __init__(self, values: dict[str, Any], failing_features: set[str]) -> None: + self._values = values + self._failing_features = failing_features + + def get_feature_by_path(self, path: str) -> Any: + """Return configured values or raise for configured paths. + + Args: + path: Feature path to read. + + Returns: + Configured feature value. + + Raises: + RuntimeError: If the path is configured as failing. + """ + if path in self._failing_features: + raise RuntimeError(f"cannot read {path}") + return self._values[path] + + def test_check_numeric_content_all_remaining_branches() -> None: """Numeric checker should report all remaining invalid content cases.""" assert _check_numeric_content([]) == "value is empty" @@ -381,6 +406,68 @@ def test_check_numeric_content_all_remaining_branches() -> None: ) +def test_check_problem_definition_sample_reports_conversion_error() -> None: + """Problem-definition sample conversion failures should be reported.""" + report = CheckReport(messages=[]) + converter = _FakeConverter([_FakeSampleForCheck()], fail_indices={0}) + + _check_problem_definition_sample_features( + pb_name="pb", + split_dict_name="train_split", + split_name="train", + idx=0, + dataset=_FakeDataset(1), + converter=converter, + features=["Input"], + report=report, + ) + + assert len(report.messages) == 1 + msg = report.messages[0] + assert msg.severity == "error" + assert msg.code == "PB_DEF_SAMPLE_CONVERSION_ERROR" + assert msg.location == "problem_definitions/pb/train_split/train[0]" + assert msg.message == "boom" + + +def test_check_problem_definition_sample_reports_feature_read_error_and_continues() -> ( + None +): + """Feature read failures should be reported without stopping later checks.""" + report = CheckReport(messages=[]) + sample = _FakeSampleWithFeatureFailure( + values={"Good": np.array([1.0]), "BadValue": np.array([np.nan])}, + failing_features={"Unreadable"}, + ) + converter = _FakeConverter([sample]) + + _check_problem_definition_sample_features( + pb_name="pb", + split_dict_name="test_split", + split_name="test", + idx=0, + dataset=_FakeDataset(1), + converter=converter, + features=["Unreadable", "Good", "BadValue"], + report=report, + ) + + assert any( + msg.severity == "error" + and msg.code == "PB_DEF_FEATURE_READ_ERROR" + and msg.location == "problem_definitions/pb/test_split/test[0] Unreadable" + and msg.message == "cannot read Unreadable" + for msg in report.messages + ) + assert any( + msg.severity == "warning" + and msg.code == "PB_DEF_INVALID_FEATURE_VALUE" + and msg.location == "problem_definitions/pb/test_split/test[0] BadValue" + and msg.message == "contains NaN" + for msg in report.messages + ) + + def test_is_branch_without_data_false_variants(monkeypatch) -> None: """Branch helper should return False for missing tree/node/children layout.""" From 9ba33d0896896025a1137509a1aa7132a5ad02bf Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 07:14:15 +0200 Subject: [PATCH 07/11] wip --- docs/source/concepts/dataset.md | 3 +- docs/source/concepts/infos.md | 26 ++- docs/source/concepts/problem_definition.md | 11 +- docs/source/tutorials/storage.md | 3 +- examples/infos_example.py | 16 +- examples/problem_definition_example.py | 9 +- src/plaid/cli/plaidcheck.py | 7 +- src/plaid/infos.py | 72 ++++--- src/plaid/problem_definition.py | 17 +- src/plaid/storage/cgns/writer.py | 2 +- src/plaid/storage/common/reader.py | 7 +- src/plaid/storage/common/writer.py | 19 +- src/plaid/storage/hf_datasets/writer.py | 2 +- src/plaid/storage/writer.py | 22 +-- src/plaid/storage/zarr/writer.py | 2 +- tests/cli/test_plaidcheck.py | 9 +- .../problem_definitions/regression_1.yaml | 1 - .../problem_definitions/regression_1.yaml | 1 - tests/storage/test_storage.py | 7 +- tests/test_problem_definition.py | 39 ++-- tests/utils/test_info.py | 175 ++++++++++-------- 21 files changed, 237 insertions(+), 213 deletions(-) diff --git a/docs/source/concepts/dataset.md b/docs/source/concepts/dataset.md index 270bb5bf..0cd9a317 100644 --- a/docs/source/concepts/dataset.md +++ b/docs/source/concepts/dataset.md @@ -121,7 +121,6 @@ from plaid.infos import DataDescription, Infos, Legal from plaid.storage import save_to_disk pb_def = ProblemDefinition( - name="regression_1", input_features=["Global/input"], output_features=["Base/Zone/VertexFields/pressure"], train_split={"train": [0, 1, 2]}, @@ -138,7 +137,7 @@ save_to_disk( sample_constructor=sample_constructor, ids={"train": [0, 1, 2]}, infos=infos, - pb_defs=pb_def, + pb_defs={"regression_1": pb_def}, ) ``` diff --git a/docs/source/concepts/infos.md b/docs/source/concepts/infos.md index be1550f6..333be889 100644 --- a/docs/source/concepts/infos.md +++ b/docs/source/concepts/infos.md @@ -14,8 +14,8 @@ In the current API, infos stores: hardware, contact, or location - `data_description`, for optional dataset description entries such as the number of samples, DOE, inputs, and outputs -- `num_samples`, as a required dictionary keyed by split name -- `storage_backend`, as a required storage backend identifier +- `num_samples`, as a dictionary keyed by split name, populated by storage writers +- `storage_backend`, as a storage backend identifier, populated by storage writers ## Basic usage @@ -24,8 +24,6 @@ from plaid.infos import DataProduction, Infos, Legal infos = Infos( legal=Legal(owner="Safran", license="proprietary"), - num_samples={"train": 10, "test": 5}, - storage_backend="zarr", data_production=DataProduction( type="simulation", physics="fluid dynamics", @@ -44,22 +42,31 @@ infos = Infos.model_validate( "owner": "Safran", "license": "proprietary", }, - "num_samples": {"train": 10, "test": 5}, - "storage_backend": "zarr", } ) ``` +`num_samples` and `storage_backend` are derived from the chosen storage backend +and the saved split contents. They can be omitted when creating an `Infos` +object that will later be passed to `save_to_disk(...)`; PLAID fills them before +writing `infos.yaml`. + ## Loading from disk -Load infos from a dataset path or directly from an `infos.yaml` file: +Load infos from a complete dataset path or directly from an `infos.yaml` file: ```python infos = Infos.from_path("/path/to/plaid_dataset") ``` When a directory is provided, `Infos.from_path(...)` looks for `infos.yaml` -inside that directory. +inside that directory. By default, loading from disk requires the persisted +storage metadata (`num_samples` and `storage_backend`) to be present. To load a +draft infos file that has not been produced by `save_to_disk(...)`, use: + +```python +infos = Infos.from_path("/path/to/draft/infos.yaml", require_persisted=False) +``` ## Saving @@ -85,7 +92,8 @@ payload = infos.model_dump(exclude_none=True) ## Notes -- `legal.owner`, `legal.license`, `num_samples`, and `storage_backend` are required when validating complete infos. +- `legal.owner` and `legal.license` are required when creating infos. +- `num_samples` and `storage_backend` are required when loading persisted dataset infos. - `num_samples` and `storage_backend` are overwritten with the actual saved dataset values when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. - Unknown keys are rejected during validation. - `save_to_file(...)` writes YAML using the standard infos key order. diff --git a/docs/source/concepts/problem_definition.md b/docs/source/concepts/problem_definition.md index c462bdf7..86fd9a6d 100644 --- a/docs/source/concepts/problem_definition.md +++ b/docs/source/concepts/problem_definition.md @@ -8,18 +8,19 @@ title: Problem definition In the current API, a problem definition stores: -- `name` (`str`, required) - `input_features` (`list[str]`, required and non-empty) - `output_features` (`list[str]`, required and non-empty) - `train_split` and `test_split` (required) +The problem identifier is not stored in the model. On disk, it is the YAML +filename stem; in memory, it is the dictionary key used for the definition. + ## Basic usage ```python from plaid import ProblemDefinition pb = ProblemDefinition( - name="regression_1", input_features=[ "Base/Zone/GridCoordinates/CoordinateX", "Base/Zone/GridCoordinates/CoordinateY", @@ -41,7 +42,6 @@ after reading YAML: ```python pb = ProblemDefinition.model_validate( { - "name": "regression_1", "input_features": ["Base/Zone/GridCoordinates/CoordinateX"], "output_features": ["Base/Zone/VertexFields/pressure"], "train_split": {"train": [0, 1, 2]}, @@ -60,7 +60,7 @@ pb = ProblemDefinition.from_path( ) ``` -At storage level, problem definitions are loaded as a dictionary keyed by name: +At storage level, problem definitions are loaded as a dictionary keyed by YAML filename stem: ```python from plaid.storage import load_problem_definitions_from_disk @@ -77,6 +77,9 @@ Save to YAML: pb.save_to_file("problem_definitions/regression_1.yaml") ``` +This writes no `name:` key; `regression_1` is inferred from the filename by the +storage loader. + ## Notes - Input/output features are plain strings corresponding to CGNS paths. diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index b95011c5..2fef6142 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -124,7 +124,6 @@ output_features = [ pb_def = ProblemDefinition( - name="regression_1", input_features=input_features, output_features=output_features, train_split={"train": "all"}, @@ -180,7 +179,7 @@ for backend in all_backends: ids=ids, backend=backend, infos=infos, - pb_defs=pb_def, + pb_defs={"regression_1": pb_def}, num_proc=N_PROC, overwrite=True, verbose=True) diff --git a/examples/infos_example.py b/examples/infos_example.py index dada6711..fbcb790c 100644 --- a/examples/infos_example.py +++ b/examples/infos_example.py @@ -50,8 +50,6 @@ print("#---# Infos") infos = Infos( legal=Legal(owner="PLAID", license="MIT"), - num_samples={"train": 2, "test": 2}, - storage_backend="cgns", ) print(f"{infos = }") @@ -66,8 +64,6 @@ "license": "MIT", }, "data_description": "Example metadata for a PLAID dataset.", - "num_samples": {"train": 2, "test": 2}, - "storage_backend": "cgns", } ) print(f"{infos_from_mapping = }") @@ -101,16 +97,14 @@ print(f"{infos.data_production = }") # %% [markdown] -# ### Set data description, sample counts, and storage backend +# ### Set data description # %% infos.data_description = "Example dataset generated for the Infos example." -infos.num_samples = {"train": 3, "test": 1} -infos.storage_backend = "zarr" print(f"{infos.data_description = }") -print(f"{infos.num_samples = }") -print(f"{infos.storage_backend = }") +print(f"{infos.num_samples = }") # Populated by save_to_disk for saved datasets. +print(f"{infos.storage_backend = }") # Populated by save_to_disk for saved datasets. # %% [markdown] # ### Retrieve data with Pydantic attributes @@ -144,12 +138,12 @@ # ### Load Infos from a YAML file # %% -loaded_infos = Infos.from_path(infos_save_fname) +loaded_infos = Infos.from_path(infos_save_fname, require_persisted=False) print(loaded_infos) # %% [markdown] # ### Load Infos from a directory containing infos.yaml # %% -loaded_infos_from_dir = Infos.from_path(test_pth) +loaded_infos_from_dir = Infos.from_path(test_pth, require_persisted=False) print(loaded_infos_from_dir) \ No newline at end of file diff --git a/examples/problem_definition_example.py b/examples/problem_definition_example.py index 4b01641f..9dc27db0 100644 --- a/examples/problem_definition_example.py +++ b/examples/problem_definition_example.py @@ -55,7 +55,6 @@ # %% print("#---# ProblemDefinition") problem = ProblemDefinition( - name="my_problem_definition", input_features=[scalar_3_feat_id, field_1_feat_id], output_features=[field_2_feat_id], train_split={"train": [0, 1]}, @@ -86,10 +85,10 @@ # This section demonstrates how to handle and configure ProblemDefinition objects and access data. # %% [markdown] -# ### Set Problem Definition name - -# %% -print(f"{problem.name = }") +# ### Problem Definition identifier +# +# ProblemDefinition has no stored name. When saved in a dataset, its identifier +# is the YAML filename stem or the key in a `dict[str, ProblemDefinition]`. # %% [markdown] # ### Set Problem Definition split diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index 41b6a312..ae73efee 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -11,14 +11,19 @@ from tqdm import tqdm from plaid.constants import CGNS_FIELD_LOCATIONS +from plaid.infos import Infos from plaid.storage import init_from_disk from plaid.storage.common.reader import ( - load_infos_from_disk, load_metadata_from_disk, load_problem_definitions_from_disk, ) +def load_infos_from_disk(path: Path) -> Infos: + """Load infos for checker diagnostics without persisted-field enforcement.""" + return Infos.from_path(path, require_persisted=False) + + @dataclass class CheckMessage: """One integrity check message. diff --git a/src/plaid/infos.py b/src/plaid/infos.py index 769f1c98..88f65c33 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -7,7 +7,7 @@ from typing import Any, Union import yaml -from pydantic import BaseModel, ConfigDict, ValidationError +from pydantic import BaseModel, ConfigDict, Field, ValidationError from pydantic.dataclasses import dataclass logger = logging.getLogger(__name__) @@ -17,14 +17,6 @@ ) -@dataclass(config=_PD_CONFIG) -class Legal: - """Legal ownership and licensing metadata.""" - - owner: str - license: str - - @dataclass(config=_PD_CONFIG) class DataProduction: """Dataset production context metadata.""" @@ -42,7 +34,8 @@ class DataProduction: # Order used when serializing to YAML. _KEY_ORDER = ( - "legal", + "owner", + "license", "data_production", "data_description", "num_samples", @@ -55,24 +48,39 @@ class Infos(BaseModel): model_config = _PD_CONFIG - legal: Legal + owner: str + license: str data_production: DataProduction | None = None data_description: str | None = None - num_samples: dict[str, int] - storage_backend: str + num_samples: dict[str, int] = Field(default_factory=dict) + storage_backend: str | None = None + + def require_persisted(self) -> "Infos": + """Validate fields that must exist in persisted dataset infos. + + ``num_samples`` and ``storage_backend`` are derived by storage writers + when a dataset is saved, so they are optional while users prepare an + ``Infos`` object. Once infos are loaded from disk or the Hub, however, + readers need both fields to select the backend and split sizes. + """ + if "num_samples" not in self.model_fields_set: + raise ValueError("Missing required persisted infos field: 'num_samples'") + if "storage_backend" not in self.model_fields_set or not self.storage_backend: + raise ValueError("Missing required persisted infos field: 'storage_backend'") + return self @classmethod def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": """Validate schema/authorized keys without enforcing required sections.""" normalized = dict(infos) - had_legal = "legal" in normalized + had_owner = "owner" in normalized + had_license = "license" in normalized had_num_samples = "num_samples" in normalized had_storage_backend = "storage_backend" in normalized - if not had_legal: - normalized["legal"] = { - "owner": "__placeholder__", - "license": "__placeholder__", - } + if not had_owner: + normalized["owner"] = "__placeholder__" + if not had_license: + normalized["license"] = "__placeholder__" if not had_num_samples: normalized["num_samples"] = {} if not had_storage_backend: @@ -89,8 +97,10 @@ def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": raise KeyError(f"Unauthorized infos key: {loc!r}") from exc raise - if not had_legal: - model.legal = Legal(owner="", license="") + if not had_owner: + model.owner = "" + if not had_license: + model.license = "" if not had_num_samples: model.num_samples = {} if not had_storage_backend: @@ -99,8 +109,13 @@ def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": @classmethod def validate_required_only(cls, infos: dict[str, Any]) -> None: - """Validate required entries using pydantic-required fields.""" - cls.model_validate(infos) + """Validate entries required for persisted dataset infos.""" + cls.model_validate(infos).require_persisted() + + @classmethod + def validate_persisted(cls, infos: dict[str, Any]) -> "Infos": + """Validate and return complete infos loaded from persisted storage.""" + return cls.model_validate(infos).require_persisted() @classmethod def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: @@ -113,12 +128,14 @@ def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: # ------------------------------------------------------------------ @classmethod - def from_path(cls, path: Union[str, Path]) -> "Infos": + def from_path(cls, path: Union[str, Path], require_persisted: bool = True) -> "Infos": """Load and validate an :class:`Infos` from a YAML file. Args: path: Path to the YAML file (typically ``infos.yaml``) or to a directory containing it. + require_persisted: When True, require storage-derived metadata + fields expected in a complete on-disk dataset. Returns: Validated :class:`Infos` instance. @@ -135,7 +152,10 @@ def from_path(cls, path: Union[str, Path]) -> "Infos": with path.open("r", encoding="utf-8") as file: data = yaml.safe_load(file) or {} - return cls.model_validate(data) + infos = cls.model_validate(data) + if require_persisted: + infos.require_persisted() + return infos def save_to_file(self, path: Union[str, Path]) -> None: """Save infos to ``path`` as a YAML file. @@ -155,7 +175,7 @@ def save_to_file(self, path: Union[str, Path]) -> None: path.parent.mkdir(parents=True, exist_ok=True) - data = self.model_dump(exclude_none=True) + data = self.model_dump(exclude_none=True, exclude_unset=True) ordered_data = {key: data[key] for key in _KEY_ORDER if key in data} # Preserve any future fields. for key, value in data.items(): diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index cfeba76e..cb83c943 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -33,7 +33,6 @@ class ProblemDefinition(BaseModel): revalidate_instances="always", validate_assignment=True, extra="forbid" ) - name: str input_features: list[str] output_features: list[str] train_split: dict[str, Sequence[int] | Literal["all"]] @@ -85,17 +84,7 @@ def normalize_output_features(cls, v): return _normalize_list(v) def __setattr__(self, name: str, value: Any) -> None: - """Override the default attribute setting behavior to enforce immutability for certain fields and log warnings for others.""" - # to set the name, task, score_function only once and oly once - if name in ["name"]: - current_value = getattr(self, name, None) - if ( - current_value is not None - and value is not None - and current_value != value - ): - raise AttributeError(f"'{name}' is already set and cannot be changed.") - # warning if set + """Override attribute setting to log warnings when split fields are replaced.""" if name in ["train_split", "test_split"]: current_value = getattr(self, name, None) if ( @@ -121,7 +110,6 @@ def add_input_features(self, inputs: Union[str, Sequence[str]]) -> None: from plaid.problem_definition import ProblemDefinition problem = ProblemDefinition( - name="example", input_features=["angle"], output_features=["pressure"], train_split={"train": "all"}, @@ -163,7 +151,6 @@ def add_output_features(self, outputs: Union[str, Sequence[str]]) -> None: from plaid.problem_definition import ProblemDefinition problem = ProblemDefinition( - name="example", input_features=["angle"], output_features=["pressure"], train_split={"train": "all"}, @@ -202,7 +189,6 @@ def save_to_file(self, path: Union[str, Path]) -> None: from plaid import ProblemDefinition problem = ProblemDefinition( - name="example", input_features=["angle"], output_features=["pressure"], train_split={"train": "all"}, @@ -219,7 +205,6 @@ def save_to_file(self, path: Union[str, Path]) -> None: data = self.model_dump() key_order = [ - "name", "input_features", "output_features", "train_split", diff --git a/src/plaid/storage/cgns/writer.py b/src/plaid/storage/cgns/writer.py index a1095813..5f98d0e3 100644 --- a/src/plaid/storage/cgns/writer.py +++ b/src/plaid/storage/cgns/writer.py @@ -224,7 +224,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 6 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 lines.insert(count, "viewer: false") count += 1 diff --git a/src/plaid/storage/common/reader.py b/src/plaid/storage/common/reader.py index 1f20b4e5..dbca8407 100644 --- a/src/plaid/storage/common/reader.py +++ b/src/plaid/storage/common/reader.py @@ -62,7 +62,7 @@ def load_problem_definitions_from_disk( into ``ProblemDefinition`` objects. Each file is loaded using ``ProblemDefinition.from_path`` and inserted into - a dictionary keyed by the problem definition name. + a dictionary keyed by the YAML filename stem. Expected local layout: / @@ -77,7 +77,7 @@ def load_problem_definitions_from_disk( Returns: dict[str, ProblemDefinition]: - Mapping from problem definition names to loaded ``ProblemDefinition`` + Mapping from problem definition filename stems to loaded ``ProblemDefinition`` objects. Raises: @@ -100,8 +100,7 @@ def load_problem_definitions_from_disk( f"Failed to load problem definition file " f"'{pb_def_path.name}': {exc}" ) from exc - pb_name = pb_def.name if isinstance(pb_def.name, str) else p.stem - pb_defs[pb_name] = pb_def + pb_defs[p.stem] = pb_def return pb_defs else: raise ValueError( diff --git a/src/plaid/storage/common/writer.py b/src/plaid/storage/common/writer.py index 41231824..fcc03232 100644 --- a/src/plaid/storage/common/writer.py +++ b/src/plaid/storage/common/writer.py @@ -39,25 +39,30 @@ def save_infos_to_disk(path: Union[str, Path], infos: Infos) -> None: def save_problem_definitions_to_disk( path: Union[str, Path], - pb_defs: Union[dict[str, ProblemDefinition], ProblemDefinition], + pb_defs: dict[str, ProblemDefinition], ) -> None: """Save ProblemDefinitions to disk. Args: path (Union[str, Path]): The directory path for saving. - pb_defs (Union[dict[str, ProblemDefinition], ProblemDefinition]): The problem definitions to save. + pb_defs (dict[str, ProblemDefinition]): Mapping from problem definition identifiers to definitions. """ if isinstance(pb_defs, ProblemDefinition): - pb_defs = {pb_defs.name: pb_defs} + raise TypeError( + "pb_defs must be a dict[str, ProblemDefinition]; " + "use the dictionary key as the problem identifier." + ) + if not isinstance(pb_defs, dict): + raise TypeError("pb_defs must be a dict[str, ProblemDefinition]") target_dir = Path(path) / "problem_definitions" target_dir.mkdir(parents=True, exist_ok=True) for name, pb_def in pb_defs.items(): - if name is None: - raise ValueError( - "At least one of the provided pb_defs has no initialized name." - ) + if not isinstance(name, str) or not name: + raise TypeError("Problem definition identifiers must be non-empty strings") + if not isinstance(pb_def, ProblemDefinition): + raise TypeError("pb_defs values must be ProblemDefinition instances") pb_def.save_to_file(target_dir / name) diff --git a/src/plaid/storage/hf_datasets/writer.py b/src/plaid/storage/hf_datasets/writer.py index 049d931d..3a98147b 100644 --- a/src/plaid/storage/hf_datasets/writer.py +++ b/src/plaid/storage/hf_datasets/writer.py @@ -240,7 +240,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 1 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 if viewer is False: lines.insert(count, "viewer: false") diff --git a/src/plaid/storage/writer.py b/src/plaid/storage/writer.py index de7f699b..d5603c7c 100644 --- a/src/plaid/storage/writer.py +++ b/src/plaid/storage/writer.py @@ -27,7 +27,7 @@ from plaid.storage.registry import available_backends, get_backend from ..containers.sample import Sample -from ..infos import Infos, Legal +from ..infos import Infos from ..problem_definition import ProblemDefinition from .common.preprocessor import preprocess from .common.reader import ( @@ -120,7 +120,7 @@ def save_to_disk( ids: Mapping[str, Any], infos: Optional[Infos] = None, backend: str = "hf_datasets", - pb_defs: Optional[Union[dict[str, ProblemDefinition], ProblemDefinition]] = None, + pb_defs: Optional[dict[str, ProblemDefinition]] = None, num_proc: int = 1, verbose: bool = False, overwrite: bool = False, @@ -155,9 +155,8 @@ def sample_constructor(file_path): "test": test_file_paths, }, infos=Infos( - legal={"owner": "owner", "license": "license"}, - num_samples={}, - storage_backend="hf_datasets", + owner="owner", + license="license", ), num_proc=6, ) @@ -175,8 +174,8 @@ def sample_constructor(file_path): ``'zarr'``). infos: Dataset information to save with the dataset. If ``None``, a placeholder :class:`~plaid.Infos` is created with - ``legal=Legal(owner="unknown", license="unknown")``. - pb_defs: Optional problem definitions to save. + ``owner="unknown", license="unknown"``. + pb_defs: Optional mapping from problem definition identifiers to definitions. num_proc: Number of processes to use for parallel writing. When ``num_proc > 1`` PLAID automatically shards the identifier sequences and distributes work across workers. @@ -239,14 +238,13 @@ def sample_constructor(file_path): # overriding any inherited values from the input ``infos``. if infos is None: infos = Infos( - legal=Legal(owner="unknown", license="unknown"), - num_samples={}, - storage_backend=backend, + owner="unknown", + license="unknown", ) infos_data = infos.model_dump(exclude_none=True) infos_data["num_samples"] = num_samples infos_data["storage_backend"] = backend - infos = Infos.model_validate(infos_data) + infos = Infos.validate_persisted(infos_data) save_infos_to_disk(output_folder, infos) @@ -298,6 +296,8 @@ def push_to_hub( infos = load_infos_from_disk(local_dir) backend = infos.storage_backend + if backend is None: + raise ValueError("Missing 'storage_backend' in persisted infos") backend_spec = get_backend(backend) backend_spec.push_local_to_hub(repo_id, local_dir, num_workers=num_workers) diff --git a/src/plaid/storage/zarr/writer.py b/src/plaid/storage/zarr/writer.py index e3323567..71408523 100644 --- a/src/plaid/storage/zarr/writer.py +++ b/src/plaid/storage/zarr/writer.py @@ -360,7 +360,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 6 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 lines.insert(count, "viewer: false") count += 1 diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index d5dcb6a3..cfe8fe5a 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -18,17 +18,13 @@ check_dataset, main, ) -from plaid.infos import Infos, Legal +from plaid.infos import Infos _REFERENCE_DATASETS = ("dataset_cgns", "dataset_hf") def _infos(num_samples: dict[str, int], storage_backend: str = "zarr") -> Infos: - return Infos( - legal=Legal(owner="owner", license="license"), - num_samples=num_samples, - storage_backend=storage_backend, - ) + return Infos(owner="owner", license="license") def _copy_reference_dataset(tmp_path: Path, name: str = "dataset_cgns") -> Path: @@ -913,7 +909,6 @@ def test_check_dataset_problem_definition_read_error_names_yaml_file( pb_def_dir = dataset / "problem_definitions" pb_def_dir.mkdir() (pb_def_dir / "bad_definition.yaml").write_text( - "name: bad\n" "input_features: [in]\n" "output_features: [out]\n" "unexpected_key: value\n", diff --git a/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml b/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml index 7a7e1680..137a7af0 100644 --- a/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml +++ b/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml @@ -1,4 +1,3 @@ -name: regression_1 input_features: - Base_2_3/Zone/Elements_TRI_3/ElementConnectivity - Base_2_3/Zone/GridCoordinates/CoordinateX diff --git a/tests/containers/dataset_hf/problem_definitions/regression_1.yaml b/tests/containers/dataset_hf/problem_definitions/regression_1.yaml index 7a7e1680..137a7af0 100644 --- a/tests/containers/dataset_hf/problem_definitions/regression_1.yaml +++ b/tests/containers/dataset_hf/problem_definitions/regression_1.yaml @@ -1,4 +1,3 @@ -name: regression_1 input_features: - Base_2_3/Zone/Elements_TRI_3/ElementConnectivity - Base_2_3/Zone/GridCoordinates/CoordinateX diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index ca9fce37..f98bbadc 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -164,7 +164,6 @@ def main_splits() -> dict: @pytest.fixture() def problem_definition(main_splits) -> ProblemDefinition: return ProblemDefinition( - name="problem_definition", input_features=["feature_name_1", "feature_name_2"], output_features=["feature_name_1"], train_split={"train": main_splits["train"]}, @@ -266,8 +265,7 @@ def test_hf_datasets( overwrite=False, ) - with pytest.raises(ValueError): - problem_definition.name = None + with pytest.raises(TypeError, match="dict\[str, ProblemDefinition\]"): save_to_disk( output_folder=test_dir, sample_constructor=sample_constructor, @@ -289,7 +287,8 @@ def test_hf_datasets( verbose=True, ) - load_problem_definitions_from_disk(test_dir) + loaded_pb_defs = load_problem_definitions_from_disk(test_dir) + assert set(loaded_pb_defs) == {"pb_def"} with pytest.raises(ValueError): load_problem_definitions_from_disk("dummy") diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index 2bd7ebf7..64dd0809 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -15,7 +15,6 @@ @pytest.fixture() def problem_definition() -> ProblemDefinition: return ProblemDefinition.model_construct( - name=None, input_features=[], output_features=[], train_split=None, @@ -25,8 +24,6 @@ def problem_definition() -> ProblemDefinition: @pytest.fixture() def problem_definition_full(problem_definition: ProblemDefinition) -> ProblemDefinition: - problem_definition.name = "regression_1" - # ---- feature_identifier = "Global/feature" predict_feature_identifier = "Global/predict_feature" @@ -89,7 +86,6 @@ def test_required_fields(self): def test_feature_lists_must_not_be_empty(self): base = { - "name": "pb", "train_split": {"train": "all"}, "test_split": {"test": "all"}, } @@ -111,7 +107,6 @@ def test_feature_lists_must_not_be_empty(self): def test_from_mapping_validates_and_normalizes(self): loaded = ProblemDefinition.model_validate( { - "name": "pb_single", "input_features": ["in_b", "in_a"], "output_features": ["out_b", "out_a"], "train_split": {"train_0": [0, 1]}, @@ -119,7 +114,6 @@ def test_from_mapping_validates_and_normalizes(self): } ) - assert loaded.name == "pb_single" assert loaded.input_features == ["in_a", "in_b"] assert loaded.output_features == ["out_a", "out_b"] assert loaded.train_split == {"train_0": [0, 1]} @@ -128,7 +122,6 @@ def test_from_mapping_validates_and_normalizes(self): def test_from_path_loads_single_yaml_file(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" file_path.write_text( - "name: pb\n" "input_features:\n" " - in_1\n" "output_features:\n" @@ -142,14 +135,12 @@ def test_from_path_loads_single_yaml_file(self, tmp_path: Path): loaded = ProblemDefinition.from_path(file_path) - assert loaded.name == "pb" assert loaded.train_split == {"train": [0]} assert loaded.test_split == {"test": [1]} def test_from_path_adds_yaml_suffix(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" file_path.write_text( - "name: pb\n" "input_features: [in_1]\n" "output_features: [out_1]\n" "train_split:\n" @@ -161,12 +152,27 @@ def test_from_path_adds_yaml_suffix(self, tmp_path: Path): loaded = ProblemDefinition.from_path(tmp_path / "problem") - assert loaded.name == "pb" + assert loaded.input_features == ["in_1"] + + def test_from_path_rejects_old_name_key(self, tmp_path: Path): + file_path = tmp_path / "problem_with_name.yaml" + file_path.write_text( + "name: pb\n" + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n", + encoding="utf-8", + ) + + with pytest.raises(ValidationError, match="extra_forbidden"): + ProblemDefinition.from_path(file_path) def test_from_path_unknown_key_raises(self, tmp_path: Path): file_path = tmp_path / "problem_with_unknown.yaml" file_path.write_text( - "name: pb\n" "input_features: [in_1]\n" "output_features: [out_1]\n" "train_split:\n" @@ -189,7 +195,6 @@ def test_feature_validators_reject_duplicates(self): ValidationError, match="duplicated values in input_features" ): ProblemDefinition( - name="pb", input_features=["a", "a"], output_features=["out"], train_split={"train": "all"}, @@ -200,18 +205,12 @@ def test_feature_validators_reject_duplicates(self): ValidationError, match="duplicated values in output_features" ): ProblemDefinition( - name="pb", input_features=["in"], output_features=["a", "a"], train_split={"train": "all"}, test_split={"test": "all"}, ) - def test_non_overwritable_attributes_raise(self, problem_definition): - problem_definition.name = "problem_a" - with pytest.raises(AttributeError, match="'name' is already set"): - problem_definition.name = "problem_b" - def test_split_replacement_logs_warning(self, problem_definition, caplog): problem_definition.train_split = {"train_0": [0, 1]} with caplog.at_level("WARNING"): @@ -252,7 +251,6 @@ def test_split(self, problem_definition): def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): problem = ProblemDefinition( - name="pb", input_features=["in_1"], output_features=["out_1"], train_split={"train": [0]}, @@ -263,6 +261,7 @@ def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): loaded = ProblemDefinition.from_path(file_path) - assert loaded.name == "pb" + saved_text = file_path.read_text(encoding="utf-8") + assert "name:" not in saved_text assert loaded.train_split == {"train": [0]} assert loaded.test_split == {"test": [1]} diff --git a/tests/utils/test_info.py b/tests/utils/test_info.py index bfe9df20..02f6f983 100644 --- a/tests/utils/test_info.py +++ b/tests/utils/test_info.py @@ -1,16 +1,31 @@ import pytest from pydantic import ValidationError -from plaid.infos import Infos, Legal +from plaid.infos import Infos -def test_verify_info_accepts_special_internal_keys(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, +def _valid_infos(**overrides): + data = { + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {"train": 10}, "storage_backend": "zarr", } - Infos(**infos) + data.update(overrides) + return data + + +def test_verify_info_accepts_flat_owner_license_keys(): + Infos(**_valid_infos()) + + +def test_infos_allows_draft_without_storage_derived_fields(): + model = Infos(owner="owner", license="cc-by-4.0") + + assert model.owner == "owner" + assert model.license == "cc-by-4.0" + assert model.num_samples == {} + assert model.storage_backend is None def test_verify_info_rejects_unknown_category(): @@ -18,21 +33,32 @@ def test_verify_info_rejects_unknown_category(): Infos(**{"unknown": {"x": "y"}}) -def test_verify_info_rejects_unknown_key(): - with pytest.raises(ValidationError): - Infos(**{"legal": {"unknown_key": "v"}}) +def test_verify_info_rejects_legacy_legal_key(): + with pytest.raises(ValidationError, match="extra_forbidden"): + Infos(**{"legal": {"owner": "owner", "license": "cc-by-4.0"}}) def test_validate_required_only_missing_required_key(): with pytest.raises(ValueError): - Infos(**{"legal": {"owner": "someone"}}) + Infos(**{"owner": "someone"}) + + +def test_validate_required_only_requires_persisted_fields(): + with pytest.raises(ValueError, match="num_samples"): + Infos.validate_required_only({"owner": "owner", "license": "cc-by-4.0"}) + + with pytest.raises(ValueError, match="storage_backend"): + Infos.validate_required_only( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + } + ) def test_normalize_infos_rejects_legacy_plaid_section_and_copies(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "plaid": {"version": "x"}, - } + infos = {"owner": "owner", "license": "cc-by-4.0", "plaid": {"version": "x"}} with pytest.raises(ValidationError, match="extra_forbidden"): Infos.normalize_mapping(infos) @@ -45,84 +71,83 @@ def test_model_validate_rejects_plaid_section(): with pytest.raises(ValidationError): Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "plaid": {"version": "x"}, } ) def test_dataset_info_model_validate_success(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "num_samples": {"train": 10}, - "storage_backend": "zarr", - } + model = Infos.model_validate(_valid_infos()) - model = Infos.model_validate(infos) - - assert model.legal.owner == "owner" + assert model.owner == "owner" + assert model.license == "cc-by-4.0" assert model.storage_backend == "zarr" def test_dataset_info_model_validate_rejects_extra_top_level_key(): with pytest.raises(ValueError): - Infos.model_validate( - { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "unknown": {}, - } - ) + Infos.model_validate({"owner": "owner", "license": "cc-by-4.0", "unknown": {}}) def test_infos_save_and_load_roundtrip(tmp_path): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "num_samples": {"train": 10}, - "storage_backend": "zarr", - } - model = Infos.model_validate(infos) + model = Infos.model_validate(_valid_infos()) target = tmp_path / "infos.yaml" model.save_to_file(target) assert target.is_file() reloaded = Infos.from_path(target) - assert reloaded.legal.owner == "owner" + assert reloaded.owner == "owner" + assert reloaded.license == "cc-by-4.0" assert reloaded.storage_backend == "zarr" assert reloaded.num_samples == {"train": 10} def test_infos_from_path_directory(tmp_path): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "num_samples": {"train": 10}, - "storage_backend": "zarr", - } - Infos.model_validate(infos).save_to_file(tmp_path / "infos.yaml") + Infos.model_validate(_valid_infos()).save_to_file(tmp_path / "infos.yaml") reloaded = Infos.from_path(tmp_path) - assert reloaded.legal.license == "cc-by-4.0" + assert reloaded.license == "cc-by-4.0" -def test_validate_authorized_only_allows_missing_legal(): +def test_infos_from_path_requires_persisted_fields_by_default(tmp_path): + Infos(owner="o", license="l").save_to_file(tmp_path) + + with pytest.raises(ValueError, match="num_samples"): + Infos.from_path(tmp_path) + + +def test_infos_from_path_can_load_draft_infos(tmp_path): + Infos(owner="o", license="l").save_to_file(tmp_path) + + reloaded = Infos.from_path(tmp_path, require_persisted=False) + + assert reloaded.owner == "o" + assert reloaded.license == "l" + assert reloaded.num_samples == {} + assert reloaded.storage_backend is None + + +def test_validate_authorized_only_allows_missing_owner_license(): model = Infos.validate_authorized_only( {"num_samples": {"train": 1}, "storage_backend": "zarr"} ) - # Missing legal must be filled with empty placeholder values. - assert model.legal.owner == "" - assert model.legal.license == "" + # Missing user-authored required fields are filled with empty placeholder values. + assert model.owner == "" + assert model.license == "" assert model.storage_backend == "zarr" -def test_validate_authorized_only_with_legal_present(): - model = Infos.validate_authorized_only({"legal": {"owner": "o", "license": "l"}}) - assert model.legal.owner == "o" +def test_validate_authorized_only_with_owner_license_present(): + model = Infos.validate_authorized_only({"owner": "o", "license": "l"}) + assert model.owner == "o" + assert model.license == "l" def test_validate_authorized_only_rejects_unauthorized_key(): with pytest.raises(KeyError): - Infos.validate_authorized_only( - {"legal": {"owner": "o", "license": "l"}, "unknown": {}} - ) + Infos.validate_authorized_only({"owner": "o", "license": "l", "unknown": {}}) def test_validate_authorized_only_reraises_other_validation_errors(): @@ -130,22 +155,21 @@ def test_validate_authorized_only_reraises_other_validation_errors(): # validation error that is *not* of the unauthorized-key kind, so it # must be re-raised as ValidationError. with pytest.raises(ValidationError): - Infos.validate_authorized_only( - {"legal": {"owner": "o", "license": "l"}, "num_samples": "nope"} - ) + Infos.validate_authorized_only({"owner": "o", "license": "l", "num_samples": "nope"}) def test_validate_required_only_accepts_valid_mapping(): Infos.validate_required_only( { - "legal": {"owner": "o", "license": "l"}, + "owner": "o", + "license": "l", "num_samples": {"train": 1}, "storage_backend": "zarr", } ) -def test_validate_required_only_missing_legal(): +def test_validate_required_only_missing_owner_license(): with pytest.raises(ValidationError): Infos.validate_required_only({}) @@ -153,57 +177,48 @@ def test_validate_required_only_missing_legal(): def test_model_dump_returns_plain_mapping(): model = Infos.model_validate( { - "legal": {"owner": "o", "license": "l"}, + "owner": "o", + "license": "l", "num_samples": {}, "storage_backend": "zarr", } ) d = model.model_dump(exclude_none=True) - assert d["legal"] == {"owner": "o", "license": "l"} + assert d["owner"] == "o" + assert d["license"] == "l" assert d["storage_backend"] == "zarr" def test_attribute_access_returns_typed_values(): model = Infos.model_validate( { - "legal": {"owner": "o", "license": "l"}, + "owner": "o", + "license": "l", "num_samples": {}, "storage_backend": "zarr", } ) assert model.storage_backend == "zarr" - assert model.legal.owner == "o" - assert model.legal.license == "l" + assert model.owner == "o" + assert model.license == "l" def test_save_to_file_treats_suffixless_path_as_directory(tmp_path): target = tmp_path / "myinfos" - Infos( - legal=Legal(owner="o", license="l"), - num_samples={}, - storage_backend="zarr", - ).save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(target) # Suffix-less, non-existing paths are treated as directories that # will hold an ``infos.yaml``. assert (target / "infos.yaml").is_file() def test_save_to_file_into_existing_directory(tmp_path): - Infos( - legal=Legal(owner="o", license="l"), - num_samples={}, - storage_backend="zarr", - ).save_to_file(tmp_path) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(tmp_path) assert (tmp_path / "infos.yaml").is_file() def test_save_to_file_replaces_non_yaml_suffix(tmp_path): target = tmp_path / "weird.txt" - Infos( - legal=Legal(owner="o", license="l"), - num_samples={}, - storage_backend="zarr", - ).save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(target) # ``.txt`` suffix is replaced by ``.yaml``. assert (tmp_path / "weird.yaml").is_file() assert not target.exists() @@ -221,7 +236,8 @@ def test_save_to_file_preserves_unknown_future_keys(tmp_path, monkeypatch): monkeypatch.setattr(infos_mod, "_KEY_ORDER", ()) model = Infos.model_validate( { - "legal": {"owner": "o", "license": "l"}, + "owner": "o", + "license": "l", "num_samples": {}, "storage_backend": "zarr", } @@ -230,7 +246,8 @@ def test_save_to_file_preserves_unknown_future_keys(tmp_path, monkeypatch): model.save_to_file(target) text = target.read_text(encoding="utf-8") assert "storage_backend" in text - assert "legal" in text + assert "owner" in text + assert "license" in text def test_from_path_raises_file_not_found(tmp_path): From 9f9b40caa9f2fb5ca397f56d62912d2976eaf7c2 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 16:46:11 +0000 Subject: [PATCH 08/11] wip --- docs/source/concepts/dataset.md | 5 +++-- docs/source/concepts/infos.md | 18 ++++++++--------- docs/source/tutorials/storage.md | 6 ++---- examples/infos_example.py | 25 ++++++++++++------------ src/plaid/infos.py | 8 ++++++-- tests/cli/test_plaidcheck.py | 14 ++++++++----- tests/conftest.py | 3 ++- tests/containers/dataset_cgns/infos.yaml | 5 ++--- tests/containers/dataset_hf/infos.yaml | 5 ++--- tests/containers/test_utils.py | 7 +++---- tests/storage/test_cgns_init.py | 6 ++++-- tests/storage/test_hf_datasets_init.py | 3 ++- tests/storage/test_storage.py | 2 +- tests/storage/test_zarr_init.py | 3 ++- tests/utils/test_info.py | 16 +++++++++++---- 15 files changed, 72 insertions(+), 54 deletions(-) diff --git a/docs/source/concepts/dataset.md b/docs/source/concepts/dataset.md index 0cd9a317..17432e77 100644 --- a/docs/source/concepts/dataset.md +++ b/docs/source/concepts/dataset.md @@ -117,7 +117,7 @@ dictionary with the same schema: ```python from plaid import ProblemDefinition -from plaid.infos import DataDescription, Infos, Legal +from plaid.infos import DataDescription, Infos from plaid.storage import save_to_disk pb_def = ProblemDefinition( @@ -127,7 +127,8 @@ pb_def = ProblemDefinition( test_split={"test": "all"}, ) infos = Infos( - legal=Legal(owner="CompanyX", license="proprietary"), + owner="CompanyX", + license="proprietary", data_description=DataDescription(number_of_samples=3), num_samples={"train": 3}, ) diff --git a/docs/source/concepts/infos.md b/docs/source/concepts/infos.md index 333be889..da6dc815 100644 --- a/docs/source/concepts/infos.md +++ b/docs/source/concepts/infos.md @@ -9,7 +9,8 @@ root of a PLAID dataset. In the current API, infos stores: -- `legal`, with required `owner` and `license` entries +- `owner` and `license`, required string entries describing the dataset + ownership and licensing - `data_production`, for optional production context such as simulator, hardware, contact, or location - `data_description`, for optional dataset description entries such as the @@ -20,10 +21,11 @@ In the current API, infos stores: ## Basic usage ```python -from plaid.infos import DataProduction, Infos, Legal +from plaid.infos import DataProduction, Infos infos = Infos( - legal=Legal(owner="Safran", license="proprietary"), + owner="Safran", + license="proprietary", data_production=DataProduction( type="simulation", physics="fluid dynamics", @@ -38,10 +40,8 @@ Infos can also be built from a plain mapping, for instance after reading YAML: ```python infos = Infos.model_validate( { - "legal": { - "owner": "Safran", - "license": "proprietary", - }, + "owner": "Safran", + "license": "proprietary", } ) ``` @@ -85,14 +85,14 @@ directory. Pydantic serialization when a plain mapping is needed: ```python -owner = infos.legal.owner +owner = infos.owner backend = infos.storage_backend payload = infos.model_dump(exclude_none=True) ``` ## Notes -- `legal.owner` and `legal.license` are required when creating infos. +- `owner` and `license` are required when creating infos. - `num_samples` and `storage_backend` are required when loading persisted dataset infos. - `num_samples` and `storage_backend` are overwritten with the actual saved dataset values when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. - Unknown keys are rejected during validation. diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index 2fef6142..ac4dbe03 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -98,10 +98,8 @@ curated_test_ids = curated_test_ids[:10] infos = Infos.model_validate( { - "legal": { - "owner": "NeuralOperator (https://zenodo.org/records/13993629)", - "license": "cc-by-4.0", - }, + "owner": "NeuralOperator (https://zenodo.org/records/13993629)", + "license": "cc-by-4.0", "data_production": { "physics": "CFD", "type": "simulation", diff --git a/examples/infos_example.py b/examples/infos_example.py index fbcb790c..58015054 100644 --- a/examples/infos_example.py +++ b/examples/infos_example.py @@ -36,7 +36,7 @@ # %% # Import necessary libraries and classes -from plaid.infos import DataProduction, Infos, Legal +from plaid.infos import DataProduction, Infos # %% [markdown] # ## Section 1: Initializing Infos @@ -49,7 +49,8 @@ # %% print("#---# Infos") infos = Infos( - legal=Legal(owner="PLAID", license="MIT"), + owner="PLAID", + license="MIT", ) print(f"{infos = }") @@ -59,10 +60,8 @@ # %% infos_from_mapping = Infos.model_validate( { - "legal": { - "owner": "PLAID", - "license": "MIT", - }, + "owner": "PLAID", + "license": "MIT", "data_description": "Example metadata for a PLAID dataset.", } ) @@ -75,11 +74,13 @@ # metadata. # %% [markdown] -# ### Set legal metadata +# ### Set owner and license metadata # %% -infos.legal = Legal(owner="Safran", license="proprietary") -print(f"{infos.legal = }") +infos.owner = "Safran" +infos.license = "proprietary" +print(f"{infos.owner = }") +print(f"{infos.license = }") # %% [markdown] # ### Set data production metadata @@ -110,8 +111,8 @@ # ### Retrieve data with Pydantic attributes # %% -print(f"{infos.legal = }") -print(f"{infos.legal.owner = }") +print(f"{infos.owner = }") +print(f"{infos.license = }") print(f"{infos.storage_backend = }") print(f"{infos.model_dump(exclude_none=True) = }") @@ -146,4 +147,4 @@ # %% loaded_infos_from_dir = Infos.from_path(test_pth, require_persisted=False) -print(loaded_infos_from_dir) \ No newline at end of file +print(loaded_infos_from_dir) diff --git a/src/plaid/infos.py b/src/plaid/infos.py index 88f65c33..c92d4f88 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -66,7 +66,9 @@ def require_persisted(self) -> "Infos": if "num_samples" not in self.model_fields_set: raise ValueError("Missing required persisted infos field: 'num_samples'") if "storage_backend" not in self.model_fields_set or not self.storage_backend: - raise ValueError("Missing required persisted infos field: 'storage_backend'") + raise ValueError( + "Missing required persisted infos field: 'storage_backend'" + ) return self @classmethod @@ -128,7 +130,9 @@ def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: # ------------------------------------------------------------------ @classmethod - def from_path(cls, path: Union[str, Path], require_persisted: bool = True) -> "Infos": + def from_path( + cls, path: Union[str, Path], require_persisted: bool = True + ) -> "Infos": """Load and validate an :class:`Infos` from a YAML file. Args: diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index cfe8fe5a..e04f3bf6 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -24,7 +24,12 @@ def _infos(num_samples: dict[str, int], storage_backend: str = "zarr") -> Infos: - return Infos(owner="owner", license="license") + return Infos( + owner="owner", + license="license", + num_samples=num_samples, + storage_backend=storage_backend, + ) def _copy_reference_dataset(tmp_path: Path, name: str = "dataset_cgns") -> Path: @@ -541,7 +546,8 @@ def test_check_dataset_loader_failures_and_header_validations( plaidcheck, "load_infos_from_disk", lambda path: Infos.model_construct( # noqa: ARG005 - legal=Legal(owner="owner", license="license"), + owner="owner", + license="license", storage_backend=12, num_samples="bad", ), @@ -909,9 +915,7 @@ def test_check_dataset_problem_definition_read_error_names_yaml_file( pb_def_dir = dataset / "problem_definitions" pb_def_dir.mkdir() (pb_def_dir / "bad_definition.yaml").write_text( - "input_features: [in]\n" - "output_features: [out]\n" - "unexpected_key: value\n", + "input_features: [in]\noutput_features: [out]\nunexpected_key: value\n", encoding="utf-8", ) diff --git a/tests/conftest.py b/tests/conftest.py index fb9b2a94..fbab2be9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,7 +93,8 @@ def other_samples(nb_samples: int, zone_name: str, base_name: str) -> list[Sampl def infos(): return Infos.model_validate( { - "legal": {"owner": "PLAID2", "license": "BSD-3"}, + "owner": "PLAID2", + "license": "BSD-3", "data_production": {"type": "simulation", "simulator": "Z-set"}, "num_samples": {}, "storage_backend": "zarr", diff --git a/tests/containers/dataset_cgns/infos.yaml b/tests/containers/dataset_cgns/infos.yaml index 749ac15d..ffe70f8a 100644 --- a/tests/containers/dataset_cgns/infos.yaml +++ b/tests/containers/dataset_cgns/infos.yaml @@ -1,6 +1,5 @@ -legal: - owner: NeuralOperator (https://zenodo.org/records/13993629) - license: cc-by-4.0 +owner: NeuralOperator (https://zenodo.org/records/13993629) +license: cc-by-4.0 data_production: type: simulation physics: CFD diff --git a/tests/containers/dataset_hf/infos.yaml b/tests/containers/dataset_hf/infos.yaml index 45560724..2478f564 100644 --- a/tests/containers/dataset_hf/infos.yaml +++ b/tests/containers/dataset_hf/infos.yaml @@ -1,6 +1,5 @@ -legal: - owner: NeuralOperator (https://zenodo.org/records/13993629) - license: cc-by-4.0 +owner: NeuralOperator (https://zenodo.org/records/13993629) +license: cc-by-4.0 data_production: type: simulation physics: CFD diff --git a/tests/containers/test_utils.py b/tests/containers/test_utils.py index b119c9d4..f8b35c4b 100644 --- a/tests/containers/test_utils.py +++ b/tests/containers/test_utils.py @@ -144,16 +144,15 @@ def test_get_feature_details_from_path(self, url, expected): def test_validate_required_only(self): infos = { - "legal": {"owner": "Joh Doe", "license": "cc-by-sa-4.0"}, + "owner": "Joh Doe", + "license": "cc-by-sa-4.0", "num_samples": {"train": 1}, "storage_backend": "zarr", } Infos.validate_required_only(infos) infos_missing_license = { - "legal": { - "owner": "Joh Doe", - }, + "owner": "Joh Doe", "num_samples": {"train": 1}, "storage_backend": "zarr", } diff --git a/tests/storage/test_cgns_init.py b/tests/storage/test_cgns_init.py index d66746f8..41089793 100644 --- a/tests/storage/test_cgns_init.py +++ b/tests/storage/test_cgns_init.py @@ -171,7 +171,8 @@ def test_cgns_backend_configure_dataset_card_requires_local_dir(): repo_id="dummy/repo", infos=Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {}, "storage_backend": "cgns", } @@ -190,7 +191,8 @@ def fake_configure_dataset_card(**kwargs): infos = Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {}, "storage_backend": "cgns", } diff --git a/tests/storage/test_hf_datasets_init.py b/tests/storage/test_hf_datasets_init.py index 5dd2c816..50673630 100644 --- a/tests/storage/test_hf_datasets_init.py +++ b/tests/storage/test_hf_datasets_init.py @@ -212,7 +212,8 @@ def fake_configure_dataset_card( infos = Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {}, "storage_backend": "hf_datasets", } diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index f98bbadc..e9727767 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -265,7 +265,7 @@ def test_hf_datasets( overwrite=False, ) - with pytest.raises(TypeError, match="dict\[str, ProblemDefinition\]"): + with pytest.raises(TypeError, match=r"dict\[str, ProblemDefinition\]"): save_to_disk( output_folder=test_dir, sample_constructor=sample_constructor, diff --git a/tests/storage/test_zarr_init.py b/tests/storage/test_zarr_init.py index d94d9436..2715e305 100644 --- a/tests/storage/test_zarr_init.py +++ b/tests/storage/test_zarr_init.py @@ -178,7 +178,8 @@ def fake_configure_dataset_card( infos = Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {}, "storage_backend": "zarr", } diff --git a/tests/utils/test_info.py b/tests/utils/test_info.py index 02f6f983..9b2273aa 100644 --- a/tests/utils/test_info.py +++ b/tests/utils/test_info.py @@ -155,7 +155,9 @@ def test_validate_authorized_only_reraises_other_validation_errors(): # validation error that is *not* of the unauthorized-key kind, so it # must be re-raised as ValidationError. with pytest.raises(ValidationError): - Infos.validate_authorized_only({"owner": "o", "license": "l", "num_samples": "nope"}) + Infos.validate_authorized_only( + {"owner": "o", "license": "l", "num_samples": "nope"} + ) def test_validate_required_only_accepts_valid_mapping(): @@ -205,20 +207,26 @@ def test_attribute_access_returns_typed_values(): def test_save_to_file_treats_suffixless_path_as_directory(tmp_path): target = tmp_path / "myinfos" - Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + target + ) # Suffix-less, non-existing paths are treated as directories that # will hold an ``infos.yaml``. assert (target / "infos.yaml").is_file() def test_save_to_file_into_existing_directory(tmp_path): - Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(tmp_path) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + tmp_path + ) assert (tmp_path / "infos.yaml").is_file() def test_save_to_file_replaces_non_yaml_suffix(tmp_path): target = tmp_path / "weird.txt" - Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + target + ) # ``.txt`` suffix is replaced by ``.yaml``. assert (tmp_path / "weird.yaml").is_file() assert not target.exists() From 6c0eaae3bc79a2a86f6fec864cb94cc705a5ccf9 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 17:11:08 +0000 Subject: [PATCH 09/11] wip --- docs/source/tutorials/downloadable_example.md | 8 ++- examples/downloadable_example/__init__.py | 0 .../downloadable_example/sample_example.py | 49 -------------- tests/cli/test_plaidcheck.py | 49 ++++++++++++++ tests/storage/test_common_writer.py | 65 +++++++++++++++++++ 5 files changed, 121 insertions(+), 50 deletions(-) delete mode 100644 examples/downloadable_example/__init__.py delete mode 100644 examples/downloadable_example/sample_example.py create mode 100644 tests/storage/test_common_writer.py diff --git a/docs/source/tutorials/downloadable_example.md b/docs/source/tutorials/downloadable_example.md index 8adb9e2a..99fdb9d7 100644 --- a/docs/source/tutorials/downloadable_example.md +++ b/docs/source/tutorials/downloadable_example.md @@ -4,6 +4,8 @@ title: Downloadable samples # Downloadable samples +## First retrieval + Retrieving sample examples is as easy as: ```python @@ -13,4 +15,8 @@ print(AVAILABLE_EXAMPLES) print("samples.vki_ls59:", samples.vki_ls59) ``` -The first call to `samples.vki_ls59` triggers a download and takes a few seconds, whereas subsequent calls are instantaneous because they reuse the cached sample. +The first call to `samples.vki_ls59` triggers a download and takes a few seconds. + +## Cached retrieval + +Subsequent calls are instantaneous because they reuse the cached sample. \ No newline at end of file diff --git a/examples/downloadable_example/__init__.py b/examples/downloadable_example/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/downloadable_example/sample_example.py b/examples/downloadable_example/sample_example.py deleted file mode 100644 index f97cec82..00000000 --- a/examples/downloadable_example/sample_example.py +++ /dev/null @@ -1,49 +0,0 @@ -# --- -# jupyter: -# jupytext: -# formats: ipynb,py:percent -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.17.3 -# kernelspec: -# display_name: plaid-dev -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Downloadable samples -# -# This Jupyter Notebook show how to easily retrieve sample examples online. - -# %% -import warnings -warnings.filterwarnings("ignore", message=".*IProgress not found.*") - -# %% [markdown] -# Retrieving sample examples is as easy as: - -# %% -from plaid.downloadable_examples import AVAILABLE_EXAMPLES -import time - -print(AVAILABLE_EXAMPLES) - -# %% -from plaid.downloadable_examples import samples - -start = time.perf_counter() -print("samples.vki_ls59:", samples.vki_ls59) -end = time.perf_counter() - -print(f"First sample retrieval duration: {end - start:.6f} seconds") -assert(len(samples.vki_ls59.get_global_names())==8) - -# %% -start = time.perf_counter() -sample = samples.vki_ls59 -end = time.perf_counter() - -print(f"Second sample retrieval duration: {end - start:.6f} seconds (in cache)") diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index e04f3bf6..9f24c120 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -70,6 +70,26 @@ def test_check_dataset_missing_infos(tmp_path: Path, dataset_name: str) -> None: assert any(msg.code == "MISSING_PATH" for msg in report.messages) +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_check_dataset_missing_required_layout_after_valid_infos( + tmp_path: Path, dataset_name: str +) -> None: + """Missing layout file (other than infos.yaml) should short-circuit checks.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + if dataset_name == "dataset_cgns": + # CGNS backend only requires infos.yaml + data/. + shutil.rmtree(dataset_path / "data") + else: + (dataset_path / "variable_schema.yaml").unlink() + + report = check_dataset(dataset_path) + + assert report.has_errors() + assert any(msg.code == "MISSING_PATH" for msg in report.messages) + # The early return on missing layout means we never reach init-related codes. + assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + + @pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) def test_check_dataset_rejects_extra_infos_key( tmp_path: Path, dataset_name: str @@ -643,6 +663,35 @@ def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> N assert any(msg.code == "SAMPLE_CONVERSION_ERROR" for msg in report.messages) +def test_check_dataset_init_keyerror_reported_as_missing_split( + tmp_path: Path, monkeypatch +) -> None: + """KeyError raised by `init_from_disk` should map to NUM_SAMPLES_MISSING_SPLIT.""" + dataset = _make_minimal_layout(tmp_path) + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None), # noqa: ARG005 + ) + + def _raise_key_error(path): # noqa: ARG001 + raise KeyError("ghost_split") + + monkeypatch.setattr(plaidcheck, "init_from_disk", _raise_key_error) + + report = check_dataset(dataset) + + assert any(msg.code == "NUM_SAMPLES_MISSING_SPLIT" for msg in report.messages) + assert any("ghost_split" in msg.message for msg in report.messages) + assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + + def test_check_dataset_missing_num_samples_split_is_clear( tmp_path: Path, monkeypatch ) -> None: diff --git a/tests/storage/test_common_writer.py b/tests/storage/test_common_writer.py new file mode 100644 index 00000000..76eb9ac8 --- /dev/null +++ b/tests/storage/test_common_writer.py @@ -0,0 +1,65 @@ +"""Tests for `plaid.storage.common.writer` validation paths.""" + +from pathlib import Path + +import pytest + +from plaid.problem_definition import ProblemDefinition +from plaid.storage.common.writer import save_problem_definitions_to_disk + + +def _make_pb_def() -> ProblemDefinition: + return ProblemDefinition( + input_features=["Global/in"], + output_features=["Global/out"], + train_split={"train": [0]}, + test_split={"test": [0]}, + ) + + +def test_save_problem_definitions_to_disk_rejects_non_dict_non_pbdef( + tmp_path: Path, +) -> None: + """Passing a non-dict, non-ProblemDefinition value should raise TypeError.""" + with pytest.raises(TypeError, match=r"dict\[str, ProblemDefinition\]"): + save_problem_definitions_to_disk(tmp_path, [("name", _make_pb_def())]) # type: ignore[arg-type] + + +def test_save_problem_definitions_to_disk_rejects_non_string_identifier( + tmp_path: Path, +) -> None: + """Non-string / empty identifiers should raise TypeError.""" + pb_def = _make_pb_def() + with pytest.raises(TypeError, match="non-empty strings"): + save_problem_definitions_to_disk(tmp_path, {123: pb_def}) # type: ignore[dict-item] + with pytest.raises(TypeError, match="non-empty strings"): + save_problem_definitions_to_disk(tmp_path, {"": pb_def}) + + +def test_save_problem_definitions_to_disk_rejects_non_pbdef_value( + tmp_path: Path, +) -> None: + """Non-ProblemDefinition values should raise TypeError.""" + with pytest.raises(TypeError, match="ProblemDefinition instances"): + save_problem_definitions_to_disk(tmp_path, {"pb": "not a pb_def"}) # type: ignore[dict-item] + + +def test_save_problem_definitions_to_disk_rejects_bare_pbdef(tmp_path: Path) -> None: + """Passing a bare ProblemDefinition (not wrapped in a dict) should raise.""" + with pytest.raises(TypeError, match="use the dictionary key as the problem"): + save_problem_definitions_to_disk(tmp_path, _make_pb_def()) # type: ignore[arg-type] + + +def test_save_problem_definitions_to_disk_writes_each_definition( + tmp_path: Path, +) -> None: + """Happy path: each ProblemDefinition is delegated to its `save_to_file`.""" + pb_defs = {"pb_a": _make_pb_def(), "pb_b": _make_pb_def()} + + save_problem_definitions_to_disk(tmp_path, pb_defs) + + target_dir = tmp_path / "problem_definitions" + assert target_dir.is_dir() + # ProblemDefinition.save_to_file serialises each definition as a YAML file. + assert (target_dir / "pb_a.yaml").is_file() + assert (target_dir / "pb_b.yaml").is_file() From 0522ec1238426616db313c6a1eb034fbc3643b3b Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 17:17:56 +0000 Subject: [PATCH 10/11] wip --- src/plaid/cli/plaidcheck.py | 22 +++++++++------------- tests/cli/test_plaidcheck.py | 20 +++++++------------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index ae73efee..61712a83 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -331,8 +331,14 @@ def _check_problem_definition_sample_features( """Instantiate and validate one problem-definition sample view. The sample is instantiated with the exact feature subset requested by the - problem definition, then each requested feature is read back and checked for - invalid content (None, NaN, Inf, empty arrays, object arrays containing None). + problem definition, then each requested feature is read back to validate + that the requested feature paths can actually be resolved. + + Numeric content (NaN, Inf, None, empty arrays, ...) is intentionally not + re-checked here: the per-split loop in :func:`check_dataset` already walks + every sample's globals and fields and reports such issues with the + ``INVALID_DATA_VALUE A`` code. Re-checking them in this loop would only + produce duplicate warnings under a different code/location. Args: pb_name: Problem-definition name. @@ -358,7 +364,7 @@ def _check_problem_definition_sample_features( for feature in features: try: - value = sample.get_feature_by_path(feature) + sample.get_feature_by_path(feature) except Exception as exc: report.add( "error", @@ -366,16 +372,6 @@ def _check_problem_definition_sample_features( f"{location} {feature}", str(exc), ) - continue - - issue = _check_numeric_content(value) - if issue is not None: - report.add( - "warning", - "PB_DEF_INVALID_FEATURE_VALUE", - f"{location} {feature}", - issue, - ) def compute_checksum(sample: Any) -> str: diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index 9f24c120..6fb571e3 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -480,12 +480,10 @@ def test_check_problem_definition_sample_reports_feature_read_error_and_continue and msg.message == "cannot read Unreadable" for msg in report.messages ) - assert any( - msg.severity == "warning" - and msg.code == "PB_DEF_INVALID_FEATURE_VALUE" - and msg.location == "problem_definitions/pb/test_split/test[0] BadValue" - and msg.message == "contains NaN" - for msg in report.messages + # Numeric content is intentionally not re-checked here: it is already + # validated by the per-split loop in `check_dataset`. + assert not any( + msg.code == "PB_DEF_INVALID_FEATURE_VALUE" for msg in report.messages ) @@ -847,14 +845,10 @@ class _PBDef: assert train_converter.feature_requests[-1] == ["Input", "Output"] assert test_converter.feature_requests[-1] == ["Input"] - invalid = [ - msg for msg in report.messages if msg.code == "PB_DEF_INVALID_FEATURE_VALUE" - ] - assert any( - "train_split" in msg.location and "Output" in msg.location for msg in invalid - ) + # The pb-def loop no longer re-checks numeric content; it only verifies that + # the requested feature subset can be converted and read back. assert not any( - "test_split" in msg.location and "Output" in msg.location for msg in invalid + msg.code == "PB_DEF_INVALID_FEATURE_VALUE" for msg in report.messages ) From a769e41219ecc9339d22ade4e53695a31cdd5f46 Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Fri, 5 Jun 2026 17:43:17 +0000 Subject: [PATCH 11/11] wip --- docs/source/tutorials/storage.md | 280 +++++++++++++------------------ 1 file changed, 120 insertions(+), 160 deletions(-) diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index ac4dbe03..72bdc663 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -34,7 +34,6 @@ workflow with minimal changes to your PLAID code. ```python import time from pathlib import Path -import shutil import numpy as np @@ -50,16 +49,15 @@ import Muscat.MeshContainers.ElementsDescription as ED N_PROC = 6 # number of parallel processes (set to 1 for sequential execution) -# raw data dowloaded from https://zenodo.org/records/13993629 +# raw data downloaded from https://zenodo.org/records/13993629 # set the folder where the raw data has been downloaded: BASE_RAW_DATA_FOLDER = "/path/to/raw" # TO UPDATE # set the folder where the data converted to plaid will be saved locally BASE_GENERATED_DATA_FOLDER = "/path/to/generated" # TO UPDATE -# set the Huggging Face's repo_id where the datasets will be uploaded +# set the Hugging Face's repo_id where the datasets will be uploaded BASE_REPO_ID = "channel/ShapeNetCar" # TO UPDATE - - -all_backends = ["hf_datasets", "cgns", "zarr"] +# set the folder where the downloaded data will be saved locally +BASE_DOWNLOADED_DATA_FOLDER = "/path/to/downloaded" # TO UPDATE #--------------------------------------------------------------- # define some functions to handle ShapeNetCar data @@ -80,35 +78,27 @@ tri_folders = [p for p in base_dir.iterdir() if p.is_dir()] curated_train_ids = [] curated_test_ids = [] -count = 0 -for folder in tri_folders: +for count, folder in enumerate(tri_folders): id_ = int(folder.name) if id_ in train_ids: curated_train_ids.append(count) else: curated_test_ids.append(count) - count+=1 # we can reduced the number of samples in each split for faster execution curated_train_ids = curated_train_ids[:10] curated_test_ids = curated_test_ids[:10] #--------------------------------------------------------------- -# infos and problem definition must be define to correctly populate the dataset's metadata - -infos = Infos.model_validate( - { - "owner": "NeuralOperator (https://zenodo.org/records/13993629)", - "license": "cc-by-4.0", - "data_production": { - "physics": "CFD", - "type": "simulation", - "script": "Converted to PLAID format for standardized access", - }, - "data_description": "No changes to data content from original dataset", - } +# infos and problem definition can be defined to correctly populate the dataset's metadata (they are not mandatory) + +infos = Infos( + owner="NeuralOperator (https://zenodo.org/records/13993629)", + license="cc-by-4.0", + data_description="No changes to data content from original dataset", ) + input_features = [ "Base_2_3/Zone/Elements_TRI_3/ElementConnectivity", "Base_2_3/Zone/GridCoordinates/CoordinateX", @@ -120,7 +110,6 @@ output_features = [ "Base_2_3/Zone/VertexFields/pressure", ] - pb_def = ProblemDefinition( input_features=input_features, output_features=output_features, @@ -162,40 +151,31 @@ def sample_constructor(i): ids = {"train": curated_train_ids, "test": curated_test_ids} -for backend in all_backends: - - print("--------------------------------------") - print(f"Backend: {backend}, N_PROC: {N_PROC}") - - repo_id = f"{BASE_REPO_ID}_{backend}" - local_folder = f"{BASE_GENERATED_DATA_FOLDER}/{backend}_dataset" - - # DISK - start = time.time() - save_to_disk(output_folder=local_folder, - sample_constructor=sample_constructor, - ids=ids, - backend=backend, - infos=infos, - pb_defs={"regression_1": pb_def}, - num_proc=N_PROC, - overwrite=True, - verbose=True) - print(f"duration generate with num_proc={N_PROC} is {time.time()-start} s") - - # HUB - start = time.time() - push_to_hub(repo_id=repo_id, - local_dir=local_folder, - num_workers=N_PROC, - viewer=backend == "hf_datasets", - illustration_urls=["https://i.ibb.co/3mGHsHMk/Shape-Net-Car-samples.png"]) - print(f"duration push to hub N_PROC={N_PROC} is {time.time()-start} s") - -# Note: for maximal compatibility, you may need to call the `save_to_disk` and `push_to_hub` under the `if __name__ == "__main__":` context; - -if Path(tmp_cache_dir).exists(): - shutil.rmtree(Path(tmp_cache_dir)) +local_folder = f"{BASE_GENERATED_DATA_FOLDER}/hf_dataset" + +# DISK +start = time.time() +save_to_disk(output_folder=local_folder, + sample_constructor=sample_constructor, + ids=ids, + backend="hf_datasets", + infos=infos, + pb_defs={"regression_1": pb_def}, + num_proc=N_PROC, + overwrite=True, + verbose=True) +print(f"duration generate with num_proc={N_PROC} is {time.time()-start} s") + +# HUB +start = time.time() +push_to_hub(repo_id=BASE_REPO_ID, + local_dir=local_folder, + num_workers=N_PROC, + viewer=True, + illustration_urls=["https://i.ibb.co/3mGHsHMk/Shape-Net-Car-samples.png"]) +print(f"duration push to hub N_PROC={N_PROC} is {time.time()-start} s") + +# Note: for maximal compatibility, you may need to call `save_to_disk` and `push_to_hub` under an `if __name__ == "__main__":` guard. ``` ## How to read data from disk/hub @@ -212,90 +192,75 @@ from plaid.storage import init_from_disk, download_from_hub, init_streaming_from from plaid.storage import load_problem_definitions_from_disk -# set the Huggging Face's repo_id from which the datasets will be downloaded -BASE_REPO_ID = "channel/ShapeNetCar" # TO UPDATE -# set the folder where the downloaded data will be saved locally -BASE_DOWNLOADED_DATA_FOLDER = "/mnt/e/converted_datasets/ShapeNet-Car" # TO UPDATE - -all_backends = ["hf_datasets", "cgns", "zarr"] split = "train" -# Load problem definitions and define features as all the input and output features -pb_defs = load_problem_definitions_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/{all_backends[0]}_dataset") -pb_def = next(iter(pb_defs.values())) -features = pb_def.input_features + pb_def.output_features - print("----------------------------------------------------") print("-- Download datasets -------------------------------") print("----------------------------------------------------") -# download datasets -for backend in all_backends: - repo_id = f"{BASE_REPO_ID}_{backend}" - download_folder = f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_{backend}_dataset" +# download dataset +download_folder = f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_hf_dataset" + +# depending on the backends, one can download a subset of the samples and features. We keep them all here +split_ids_ = None +features_ = None - # depending on the backends, one can download a subset of the samples and features. We keep them all here - split_ids_ = None - features_ = None +download_from_hub(BASE_REPO_ID, download_folder, split_ids=split_ids_, features=features_, overwrite=True) - download_from_hub(repo_id, download_folder, split_ids = split_ids_, features = features_, overwrite = True) +# Load problem definitions and define features as all the input and output features +pb_defs = load_problem_definitions_from_disk(download_folder) +pb_def = next(iter(pb_defs.values())) +features = pb_def.input_features + pb_def.output_features print("-------------------------------------------------------") print("-- Dataset local read and plaid sample instantiation --") print("-------------------------------------------------------") -for backend in all_backends: - - datasetdict, converterdict = init_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_{backend}_dataset") +datasetdict, converterdict = init_from_disk(download_folder) - # specify one dataset/converter pair for one split - dataset = datasetdict[split] - converter = converterdict[split] +# specify one dataset/converter pair for one split +dataset = datasetdict[split] +converter = converterdict[split] - print("backend: ", converter.backend) +# generic way to instantiate all the samples +start = time.time() +for i in range(len(dataset)): + plaid_sample = converter.to_plaid(dataset, i) +print(f"duration {time.time()-start}") - # generic way to instantiate all the samples - start = time.time() - for i in range(len(dataset)): - plaid_sample = converter.to_plaid(dataset, i) - print(f"duration {time.time()-start}") - - # Optional: extract only selected indices inside specific variable features - # (currently supported for hf_datasets and zarr backends). - field_path = "Base_2_3/Zone/VertexFields/pressure" - selected_idx = [0, 10, 20, 30] - plaid_sample_sub = converter.to_plaid( - dataset, - 0, - features=[field_path], - indexers={field_path: selected_idx}, - ) - - # instantiate the first sample, depends on the backend - sample = dataset[0] - # alternative way instantiate a plaid sample (much slower for hf_datasets) - plaid_sample = converter.sample_to_plaid(dataset[0]) +# Optional: extract only selected indices inside specific variable features +# (currently supported for hf_datasets and zarr backends). +field_path = "Base_2_3/Zone/VertexFields/pressure" +selected_idx = [0, 10, 20, 30] +plaid_sample_sub = converter.to_plaid( + dataset, + 0, + features=[field_path], + indexers={field_path: selected_idx}, +) - # save a plaid sample in a CGNS that can be opened in paraview - plaid_sample.save_to_dir(f"{BASE_DOWNLOADED_DATA_FOLDER}/sample_0_{backend}", overwrite = True) +# raw backend record for the first sample (format is backend-specific, no PLAID instantiation) +sample = dataset[0] +# alternative way to instantiate a plaid sample (much slower for hf_datasets) +plaid_sample = converter.sample_to_plaid(dataset[0]) - # generic way to read all features for all time steps - for t in plaid_sample.get_all_time_values(): - for path in pb_def.input_features: - plaid_sample.get_feature_by_path(path=path, time=t) - for path in pb_def.output_features: - plaid_sample.get_feature_by_path(path=path, time=t) +# save a plaid sample in a CGNS that can be opened in paraview +plaid_sample.save_to_dir(f"{BASE_DOWNLOADED_DATA_FOLDER}/sample_0_hf", overwrite = True) - # generic way to return the data as a dict containing all constant and variable features - if backend != "cgns": - sample_dict = converter.to_dict(dataset, 0) - sample_dict = converter.sample_to_dict(dataset[0]) +# generic way to access all features for all time steps (values are returned but not stored here) +for t in plaid_sample.get_all_time_values(): + for path in pb_def.input_features: + _ = plaid_sample.get_feature_by_path(path=path, time=t) + for path in pb_def.output_features: + _ = plaid_sample.get_feature_by_path(path=path, time=t) - # alternative way to return the data as a dict containing all constant and variable features from a plaid sample - sample_dict = converter.plaid_to_dict(plaid_sample) +# generic way to return the data as a dict containing all constant and variable features +sample_dict = converter.to_dict(dataset, 0) +sample_dict = converter.sample_to_dict(dataset[0]) - print("----------") +# alternative way to return the data as a dict containing all constant and variable features from a plaid sample +sample_dict = converter.plaid_to_dict(plaid_sample) print("----------------------------------------------------") @@ -313,37 +278,34 @@ class IndexDataset(torch.utils.data.Dataset): def __getitem__(self, idx): return idx -for backend in all_backends: - datasetdict, converterdict = init_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/{backend}_dataset") - dataset = datasetdict[split] - converter = converterdict[split] - - # define a torch dataloader directly from this IndexDataset class - loader = DataLoader( - IndexDataset(len(dataset)), - batch_size=10, - shuffle=False, - num_workers=12, - pin_memory=True, - persistent_workers=True - ) - print("backend: ", converter.backend) - start = time.time() - for batch in loader: - for idx in batch: - # efficient plaid sample reconstruction - plaid_sample = converter.to_plaid(dataset, idx) - # generic way of retrieving features and send them to GPU - for time_ in plaid_sample.get_all_time_values(): - torch_sample = {} - for path in features: - value = plaid_sample.get_feature_by_path(path=path, time=time_) - if value is not None: - if not value.flags.writeable: - value = value.copy() - torch_sample[path] = torch.as_tensor(value).to("cuda", non_blocking=True) - print(f"duration {time.time()-start}") - print("----------") +datasetdict, converterdict = init_from_disk(download_folder) +dataset = datasetdict[split] +converter = converterdict[split] + +# define a torch dataloader directly from this IndexDataset class +loader = DataLoader( + IndexDataset(len(dataset)), + batch_size=10, + shuffle=False, + num_workers=N_PROC, + pin_memory=True, + persistent_workers=True +) +start = time.time() +for batch in loader: + for idx in batch: + # efficient plaid sample reconstruction + plaid_sample = converter.to_plaid(dataset, idx) + # generic way of retrieving features and send them to GPU + for time_ in plaid_sample.get_all_time_values(): + torch_sample = {} + for path in features: + value = plaid_sample.get_feature_by_path(path=path, time=time_) + if value is not None: + if not value.flags.writeable: + value = value.copy() + torch_sample[path] = torch.as_tensor(value).to("cuda", non_blocking=True) +print(f"duration {time.time()-start}") ``` @@ -408,17 +370,15 @@ print("-- Streaming test ----------------------------------") print("----------------------------------------------------") -for backend in all_backends: - - datasetdict, converterdict = init_streaming_from_hub(f"{BASE_REPO_ID}_{backend}") +datasetdict, converterdict = init_streaming_from_hub(BASE_REPO_ID) - dataset = datasetdict[split] - converter = converterdict[split] +dataset = datasetdict[split] +converter = converterdict[split] - # dataset here is an IterableDataset, retrieving one sample and converting it to plaid - raw_sample = next(iter(dataset)) - plaid_sample = converter.sample_to_plaid(raw_sample) +# dataset here is an IterableDataset, retrieving one sample and converting it to plaid +raw_sample = next(iter(dataset)) +plaid_sample = converter.sample_to_plaid(raw_sample) - # utility to print a summary of the CGNS tree from the plaid sample - show_cgns_tree(plaid_sample.get_tree(0.)) +# utility to print a summary of the CGNS tree from the plaid sample +show_cgns_tree(plaid_sample.get_tree(0.)) ```