diff --git a/docs/source/concepts/check.md b/docs/source/concepts/check.md new file mode 100644 index 00000000..d88c9514 --- /dev/null +++ b/docs/source/concepts/check.md @@ -0,0 +1,63 @@ +--- +title: Dataset check +--- + +# Dataset check + +`plaid-check` validates the integrity of a local PLAID dataset. + +It checks: + +- required on-disk files and directories; +- `infos.yaml`, metadata and split sample counts; +- sample conversion through the declared storage backend; +- invalid numeric values such as `None`, empty arrays, `NaN` and `Inf`; +- duplicated samples; +- optional `problem_definitions/` feature names, splits and indices. + +## Basic usage + +```bash +plaid-check /path/to/plaid_dataset +``` + +A valid dataset prints an `[OK]` line and returns exit code `0`. + +## Options + +Check only selected splits: + +```bash +plaid-check /path/to/plaid_dataset --split train --split test +``` + +Check only selected problem definitions: + +```bash +plaid-check /path/to/plaid_dataset --problem-definition regression_500 +``` + +Emit a machine-readable report: + +```bash +plaid-check /path/to/plaid_dataset --json +``` + +Make warnings fail the command: + +```bash +plaid-check /path/to/plaid_dataset --strict +``` + +## Report format + +Messages are reported with a severity, a stable code, a location and a short +description. Errors return exit code `1`; warnings return exit code `2` only in +strict mode. + +## Notes + +- For CGNS datasets, only `infos.yaml` and `data/` are required at the root. +- For other backends, metadata files and `constants/` are checked as well. +- Without `--problem-definition`, all discovered problem definitions are checked. +- In JSON mode, progress bars are disabled to keep the output parseable. 
\ No newline at end of file diff --git a/docs/source/concepts/dataset.md b/docs/source/concepts/dataset.md index a455ac51..17432e77 100644 --- a/docs/source/concepts/dataset.md +++ b/docs/source/concepts/dataset.md @@ -117,12 +117,18 @@ dictionary with the same schema: ```python from plaid import ProblemDefinition -from plaid.infos import DataDescription, Infos, Legal +from plaid.infos import DataDescription, Infos from plaid.storage import save_to_disk -pb_def = ProblemDefinition(name="regression_1") +pb_def = ProblemDefinition( + input_features=["Global/input"], + output_features=["Base/Zone/VertexFields/pressure"], + train_split={"train": [0, 1, 2]}, + test_split={"test": "all"}, +) infos = Infos( - legal=Legal(owner="CompanyX", license="proprietary"), + owner="CompanyX", + license="proprietary", data_description=DataDescription(number_of_samples=3), num_samples={"train": 3}, ) @@ -132,7 +138,7 @@ save_to_disk( sample_constructor=sample_constructor, ids={"train": [0, 1, 2]}, infos=infos, - pb_defs=pb_def, + pb_defs={"regression_1": pb_def}, ) ``` diff --git a/docs/source/concepts/infos.md b/docs/source/concepts/infos.md index 06019e2e..da6dc815 100644 --- a/docs/source/concepts/infos.md +++ b/docs/source/concepts/infos.md @@ -9,21 +9,23 @@ root of a PLAID dataset. 
In the current API, infos stores: -- `legal`, with required `owner` and `license` entries +- `owner` and `license`, required string entries describing the dataset + ownership and licensing - `data_production`, for optional production context such as simulator, hardware, contact, or location - `data_description`, for optional dataset description entries such as the number of samples, DOE, inputs, and outputs -- `num_samples`, as a dictionary keyed by split name -- `storage_backend` +- `num_samples`, as a dictionary keyed by split name, populated by storage writers +- `storage_backend`, as a storage backend identifier, populated by storage writers ## Basic usage ```python -from plaid.infos import DataProduction, Infos, Legal +from plaid.infos import DataProduction, Infos infos = Infos( - legal=Legal(owner="Safran", license="proprietary"), + owner="Safran", + license="proprietary", data_production=DataProduction( type="simulation", physics="fluid dynamics", @@ -36,26 +38,35 @@ infos = Infos( Infos can also be built from a plain mapping, for instance after reading YAML: ```python -infos = Infos.from_mapping( +infos = Infos.model_validate( { - "legal": { - "owner": "Safran", - "license": "proprietary", - }, + "owner": "Safran", + "license": "proprietary", } ) ``` +`num_samples` and `storage_backend` are derived from the chosen storage backend +and the saved split contents. They can be omitted when creating an `Infos` +object that will later be passed to `save_to_disk(...)`; PLAID fills them before +writing `infos.yaml`. + ## Loading from disk -Load infos from a dataset path or directly from an `infos.yaml` file: +Load infos from a complete dataset path or directly from an `infos.yaml` file: ```python infos = Infos.from_path("/path/to/plaid_dataset") ``` When a directory is provided, `Infos.from_path(...)` looks for `infos.yaml` -inside that directory. +inside that directory. 
By default, loading from disk requires the persisted +storage metadata (`num_samples` and `storage_backend`) to be present. To load a +draft infos file that has not been produced by `save_to_disk(...)`, use: + +```python +infos = Infos.from_path("/path/to/draft/infos.yaml", require_persisted=False) +``` ## Saving @@ -68,20 +79,21 @@ infos.save_to_file("/path/to/plaid_dataset/infos.yaml") If a directory path is provided, the file is saved as `infos.yaml` inside that directory. -## Mapping-style access +## Typed access and serialization -`Infos` provides read-only mapping-style helpers for compatibility with code -expecting a YAML-like dictionary: +`Infos` is a Pydantic model. Access metadata through typed attributes and use +Pydantic serialization when a plain mapping is needed: ```python -owner = infos["legal"]["owner"] -backend = infos.get("storage_backend") -payload = infos.to_dict() +owner = infos.owner +backend = infos.storage_backend +payload = infos.model_dump(exclude_none=True) ``` ## Notes -- `legal.owner` and `legal.license` are required when validating complete infos. -- `num_samples` and `storage_backend` are automatically filled when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. +- `owner` and `license` are required when creating infos. +- `num_samples` and `storage_backend` are required when loading persisted dataset infos. +- `num_samples` and `storage_backend` are overwritten with the actual saved dataset values when `save_to_disk(..., infos=...)` is called before writing `infos.yaml`. - Unknown keys are rejected during validation. - `save_to_file(...)` writes YAML using the standard infos key order. 
diff --git a/docs/source/concepts/problem_definition.md b/docs/source/concepts/problem_definition.md index 239b384b..86fd9a6d 100644 --- a/docs/source/concepts/problem_definition.md +++ b/docs/source/concepts/problem_definition.md @@ -8,42 +8,59 @@ title: Problem definition In the current API, a problem definition stores: -- `name` -- `input_features` (`list[str]`) -- `output_features` (`list[str]`) -- `train_split` and `test_split` +- `input_features` (`list[str]`, required and non-empty) +- `output_features` (`list[str]`, required and non-empty) +- `train_split` and `test_split` (required) + +The problem identifier is not stored in the model. On disk, it is the YAML +filename stem; in memory, it is the dictionary key used for the definition. ## Basic usage ```python from plaid import ProblemDefinition -pb = ProblemDefinition(name="regression_1") - -pb.add_input_features([ - "Base/Zone/GridCoordinates/CoordinateX", - "Base/Zone/GridCoordinates/CoordinateY", -]) -pb.add_output_features([ - "Base/Zone/VertexFields/pressure", -]) - -pb.train_split = {"train": [0, 1, 2]} -pb.test_split = {"test": [3, 4]} +pb = ProblemDefinition( + input_features=[ + "Base/Zone/GridCoordinates/CoordinateX", + "Base/Zone/GridCoordinates/CoordinateY", + ], + output_features=[ + "Base/Zone/VertexFields/pressure", + ], + train_split={"train": [0, 1, 2]}, + test_split={"test": [3, 4]}, +) ``` Feature lists are normalized by the model: entries are converted to strings, -sorted, and checked for duplicates. +sorted, checked for duplicates, and rejected if empty. 
+ +Problem definitions can also be validated from a plain mapping, for instance +after reading YAML: + +```python +pb = ProblemDefinition.model_validate( + { + "input_features": ["Base/Zone/GridCoordinates/CoordinateX"], + "output_features": ["Base/Zone/VertexFields/pressure"], + "train_split": {"train": [0, 1, 2]}, + "test_split": {"test": [3, 4]}, + } +) +``` ## Loading from disk Load a definition from a dataset path: ```python -pb = ProblemDefinition.from_path("/path/to/plaid_dataset", name="regression_1") +pb = ProblemDefinition.from_path( + "/path/to/plaid_dataset/problem_definitions/regression_1.yaml" +) ``` -At storage level, problem definitions are loaded as a dictionary keyed by name: +At storage level, problem definitions are loaded as a dictionary keyed by YAML filename stem: ```python from plaid.storage import load_problem_definitions_from_disk @@ -60,10 +77,13 @@ Save to YAML: pb.save_to_file("problem_definitions/regression_1.yaml") ``` +This writes no `name:` key; `regression_1` is inferred from the filename by the +storage loader. + ## Notes -- Input/output features are plain strings correspond to CGNS paths. -- Splits are represented by `train_split` and `test_split` dictionaries. +- Input/output features are plain strings corresponding to CGNS paths. +- Splits are represented by `train_split` and `test_split` dictionaries and are accessed directly as model attributes. - Split values can be explicit index sequences or the string `"all"`. - `add_input_features(...)` and `add_output_features(...)` accept either a - single string or a sequence of strings. \ No newline at end of file + single string or a sequence of strings after initialization. 
\ No newline at end of file diff --git a/docs/source/examples_tutorials.md b/docs/source/examples_tutorials.md index cfcb048a..095e2048 100644 --- a/docs/source/examples_tutorials.md +++ b/docs/source/examples_tutorials.md @@ -12,8 +12,8 @@ You can find here detailed examples for different parts of plaid, explained in J * [Sample](notebooks/containers/sample_example.md) * [Problem definition](notebooks/problem_definition_example.md) * [Infos](notebooks/infos_example.md) -* [Downloadable samples](notebooks/downloadable_example/sample_example.md) ## Tutorials +* [Downloadable samples](tutorials/downloadable_example.md) * [Conversion tutorial](tutorials/storage.md) diff --git a/docs/source/tutorials/downloadable_example.md b/docs/source/tutorials/downloadable_example.md new file mode 100644 index 00000000..99fdb9d7 --- /dev/null +++ b/docs/source/tutorials/downloadable_example.md @@ -0,0 +1,22 @@ +--- +title: Downloadable samples +--- + +# Downloadable samples + +## First retrieval + +Retrieving sample examples is as easy as: + +```python +from plaid.downloadable_examples import AVAILABLE_EXAMPLES, samples + +print(AVAILABLE_EXAMPLES) +print("samples.vki_ls59:", samples.vki_ls59) +``` + +The first call to `samples.vki_ls59` triggers a download and takes a few seconds. + +## Cached retrieval + +Subsequent calls are instantaneous because they reuse the cached sample. \ No newline at end of file diff --git a/docs/source/tutorials/storage.md b/docs/source/tutorials/storage.md index b7ef5500..72bdc663 100644 --- a/docs/source/tutorials/storage.md +++ b/docs/source/tutorials/storage.md @@ -34,7 +34,6 @@ workflow with minimal changes to your PLAID code. 
```python import time from pathlib import Path -import shutil import numpy as np @@ -50,16 +49,15 @@ import Muscat.MeshContainers.ElementsDescription as ED N_PROC = 6 # number of parallel processes (set to 1 for sequential execution) -# raw data dowloaded from https://zenodo.org/records/13993629 +# raw data downloaded from https://zenodo.org/records/13993629 # set the folder where the raw data has been downloaded: BASE_RAW_DATA_FOLDER = "/path/to/raw" # TO UPDATE # set the folder where the data converted to plaid will be saved locally BASE_GENERATED_DATA_FOLDER = "/path/to/generated" # TO UPDATE -# set the Huggging Face's repo_id where the datasets will be uploaded +# set the Hugging Face's repo_id where the datasets will be uploaded BASE_REPO_ID = "channel/ShapeNetCar" # TO UPDATE - - -all_backends = ["hf_datasets", "cgns", "zarr"] +# set the folder where the downloaded data will be saved locally +BASE_DOWNLOADED_DATA_FOLDER = "/path/to/downloaded" # TO UPDATE #--------------------------------------------------------------- # define some functions to handle ShapeNetCar data @@ -80,37 +78,27 @@ tri_folders = [p for p in base_dir.iterdir() if p.is_dir()] curated_train_ids = [] curated_test_ids = [] -count = 0 -for folder in tri_folders: +for count, folder in enumerate(tri_folders): id_ = int(folder.name) if id_ in train_ids: curated_train_ids.append(count) else: curated_test_ids.append(count) - count+=1 # we can reduced the number of samples in each split for faster execution curated_train_ids = curated_train_ids[:10] curated_test_ids = curated_test_ids[:10] #--------------------------------------------------------------- -# infos and problem definition must be define to correctly populate the dataset's metadata - -infos = Infos.from_mapping( - { - "legal": { - "owner": "NeuralOperator (https://zenodo.org/records/13993629)", - "license": "cc-by-4.0", - }, - "data_production": { - "physics": "CFD", - "type": "simulation", - "script": "Converted to PLAID format for 
standardized access", - }, - "data_description": "No changes to data content from original dataset", - } +# infos and problem definition can be defined to correctly populate the dataset's metadata (they are not mandatory) + +infos = Infos( + owner="NeuralOperator (https://zenodo.org/records/13993629)", + license="cc-by-4.0", + data_description="No changes to data content from original dataset", ) + input_features = [ "Base_2_3/Zone/Elements_TRI_3/ElementConnectivity", "Base_2_3/Zone/GridCoordinates/CoordinateX", @@ -122,12 +110,12 @@ output_features = [ "Base_2_3/Zone/VertexFields/pressure", ] - -pb_def = ProblemDefinition(name="regression_1") -pb_def.add_input_features(input_features) -pb_def.add_output_features(output_features) -pb_def.train_split = {"train":"all"} -pb_def.test_split = {"test":"all"} +pb_def = ProblemDefinition( + input_features=input_features, + output_features=output_features, + train_split={"train": "all"}, + test_split={"test": "all"}, +) #--------------------------------------------------------------- # Define a simple function that takes a single identifier and returns a Sample. 
@@ -163,40 +151,31 @@ def sample_constructor(i): ids = {"train": curated_train_ids, "test": curated_test_ids} -for backend in all_backends: - - print("--------------------------------------") - print(f"Backend: {backend}, N_PROC: {N_PROC}") - - repo_id = f"{BASE_REPO_ID}_{backend}" - local_folder = f"{BASE_GENERATED_DATA_FOLDER}/{backend}_dataset" - - # DISK - start = time.time() - save_to_disk(output_folder=local_folder, - sample_constructor=sample_constructor, - ids=ids, - backend=backend, - infos=infos, - pb_defs=pb_def, - num_proc=N_PROC, - overwrite=True, - verbose=True) - print(f"duration generate with num_proc={N_PROC} is {time.time()-start} s") - - # HUB - start = time.time() - push_to_hub(repo_id=repo_id, - local_dir=local_folder, - num_workers=N_PROC, - viewer=backend == "hf_datasets", - illustration_urls=["https://i.ibb.co/3mGHsHMk/Shape-Net-Car-samples.png"]) - print(f"duration push to hub N_PROC={N_PROC} is {time.time()-start} s") - -# Note: for maximal compatibility, you may need to call the `save_to_disk` and `push_to_hub` under the `if __name__ == "__main__":` context; - -if Path(tmp_cache_dir).exists(): - shutil.rmtree(Path(tmp_cache_dir)) +local_folder = f"{BASE_GENERATED_DATA_FOLDER}/hf_dataset" + +# DISK +start = time.time() +save_to_disk(output_folder=local_folder, + sample_constructor=sample_constructor, + ids=ids, + backend="hf_datasets", + infos=infos, + pb_defs={"regression_1": pb_def}, + num_proc=N_PROC, + overwrite=True, + verbose=True) +print(f"duration generate with num_proc={N_PROC} is {time.time()-start} s") + +# HUB +start = time.time() +push_to_hub(repo_id=BASE_REPO_ID, + local_dir=local_folder, + num_workers=N_PROC, + viewer=True, + illustration_urls=["https://i.ibb.co/3mGHsHMk/Shape-Net-Car-samples.png"]) +print(f"duration push to hub N_PROC={N_PROC} is {time.time()-start} s") + +# Note: for maximal compatibility, you may need to call `save_to_disk` and `push_to_hub` under an `if __name__ == "__main__":` guard. 
``` ## How to read data from disk/hub @@ -213,90 +192,75 @@ from plaid.storage import init_from_disk, download_from_hub, init_streaming_from from plaid.storage import load_problem_definitions_from_disk -# set the Huggging Face's repo_id from which the datasets will be downloaded -BASE_REPO_ID = "channel/ShapeNetCar" # TO UPDATE -# set the folder where the downloaded data will be saved locally -BASE_DOWNLOADED_DATA_FOLDER = "/mnt/e/converted_datasets/ShapeNet-Car" # TO UPDATE - -all_backends = ["hf_datasets", "cgns", "zarr"] split = "train" -# Load problem definitions and define features as all the input and output features -pb_defs = load_problem_definitions_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/{all_backends[0]}_dataset") -pb_def = next(iter(pb_defs.values())) -features = pb_def.input_features + pb_def.output_features - print("----------------------------------------------------") print("-- Download datasets -------------------------------") print("----------------------------------------------------") -# download datasets -for backend in all_backends: - repo_id = f"{BASE_REPO_ID}_{backend}" - download_folder = f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_{backend}_dataset" +# download dataset +download_folder = f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_hf_dataset" + +# depending on the backends, one can download a subset of the samples and features. We keep them all here +split_ids_ = None +features_ = None - # depending on the backends, one can download a subset of the samples and features. 
We keep them all here - split_ids_ = None - features_ = None +download_from_hub(BASE_REPO_ID, download_folder, split_ids=split_ids_, features=features_, overwrite=True) - download_from_hub(repo_id, download_folder, split_ids = split_ids_, features = features_, overwrite = True) +# Load problem definitions and define features as all the input and output features +pb_defs = load_problem_definitions_from_disk(download_folder) +pb_def = next(iter(pb_defs.values())) +features = pb_def.input_features + pb_def.output_features print("-------------------------------------------------------") print("-- Dataset local read and plaid sample instantiation --") print("-------------------------------------------------------") -for backend in all_backends: - - datasetdict, converterdict = init_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/downloaded_{backend}_dataset") +datasetdict, converterdict = init_from_disk(download_folder) - # specify one dataset/converter pair for one split - dataset = datasetdict[split] - converter = converterdict[split] +# specify one dataset/converter pair for one split +dataset = datasetdict[split] +converter = converterdict[split] - print("backend: ", converter.backend) +# generic way to instantiate all the samples +start = time.time() +for i in range(len(dataset)): + plaid_sample = converter.to_plaid(dataset, i) +print(f"duration {time.time()-start}") - # generic way to instantiate all the samples - start = time.time() - for i in range(len(dataset)): - plaid_sample = converter.to_plaid(dataset, i) - print(f"duration {time.time()-start}") - - # Optional: extract only selected indices inside specific variable features - # (currently supported for hf_datasets and zarr backends). 
- field_path = "Base_2_3/Zone/VertexFields/pressure" - selected_idx = [0, 10, 20, 30] - plaid_sample_sub = converter.to_plaid( - dataset, - 0, - features=[field_path], - indexers={field_path: selected_idx}, - ) - - # instantiate the first sample, depends on the backend - sample = dataset[0] - # alternative way instantiate a plaid sample (much slower for hf_datasets) - plaid_sample = converter.sample_to_plaid(dataset[0]) +# Optional: extract only selected indices inside specific variable features +# (currently supported for hf_datasets and zarr backends). +field_path = "Base_2_3/Zone/VertexFields/pressure" +selected_idx = [0, 10, 20, 30] +plaid_sample_sub = converter.to_plaid( + dataset, + 0, + features=[field_path], + indexers={field_path: selected_idx}, +) - # save a plaid sample in a CGNS that can be opened in paraview - plaid_sample.save_to_dir(f"{BASE_DOWNLOADED_DATA_FOLDER}/sample_0_{backend}", overwrite = True) +# raw backend record for the first sample (format is backend-specific, no PLAID instantiation) +sample = dataset[0] +# alternative way to instantiate a plaid sample (much slower for hf_datasets) +plaid_sample = converter.sample_to_plaid(dataset[0]) - # generic way to read all features for all time steps - for t in plaid_sample.get_all_time_values(): - for path in pb_def.input_features: - plaid_sample.get_feature_by_path(path=path, time=t) - for path in pb_def.output_features: - plaid_sample.get_feature_by_path(path=path, time=t) +# save a plaid sample in a CGNS that can be opened in paraview +plaid_sample.save_to_dir(f"{BASE_DOWNLOADED_DATA_FOLDER}/sample_0_hf", overwrite = True) - # generic way to return the data as a dict containing all constant and variable features - if backend != "cgns": - sample_dict = converter.to_dict(dataset, 0) - sample_dict = converter.sample_to_dict(dataset[0]) +# generic way to access all features for all time steps (values are returned but not stored here) +for t in plaid_sample.get_all_time_values(): + for path in 
pb_def.input_features: + _ = plaid_sample.get_feature_by_path(path=path, time=t) + for path in pb_def.output_features: + _ = plaid_sample.get_feature_by_path(path=path, time=t) - # alternative way to return the data as a dict containing all constant and variable features from a plaid sample - sample_dict = converter.plaid_to_dict(plaid_sample) +# generic way to return the data as a dict containing all constant and variable features +sample_dict = converter.to_dict(dataset, 0) +sample_dict = converter.sample_to_dict(dataset[0]) - print("----------") +# alternative way to return the data as a dict containing all constant and variable features from a plaid sample +sample_dict = converter.plaid_to_dict(plaid_sample) print("----------------------------------------------------") @@ -314,37 +278,34 @@ class IndexDataset(torch.utils.data.Dataset): def __getitem__(self, idx): return idx -for backend in all_backends: - datasetdict, converterdict = init_from_disk(f"{BASE_DOWNLOADED_DATA_FOLDER}/{backend}_dataset") - dataset = datasetdict[split] - converter = converterdict[split] - - # define a torch dataloader directly from this IndexDataset class - loader = DataLoader( - IndexDataset(len(dataset)), - batch_size=10, - shuffle=False, - num_workers=12, - pin_memory=True, - persistent_workers=True - ) - print("backend: ", converter.backend) - start = time.time() - for batch in loader: - for idx in batch: - # efficient plaid sample reconstruction - plaid_sample = converter.to_plaid(dataset, idx) - # generic way of retrieving features and send them to GPU - for time_ in plaid_sample.get_all_time_values(): - torch_sample = {} - for path in features: - value = plaid_sample.get_feature_by_path(path=path, time=time_) - if value is not None: - if not value.flags.writeable: - value = value.copy() - torch_sample[path] = torch.as_tensor(value).to("cuda", non_blocking=True) - print(f"duration {time.time()-start}") - print("----------") +datasetdict, converterdict = 
init_from_disk(download_folder) +dataset = datasetdict[split] +converter = converterdict[split] + +# define a torch dataloader directly from this IndexDataset class +loader = DataLoader( + IndexDataset(len(dataset)), + batch_size=10, + shuffle=False, + num_workers=N_PROC, + pin_memory=True, + persistent_workers=True +) +start = time.time() +for batch in loader: + for idx in batch: + # efficient plaid sample reconstruction + plaid_sample = converter.to_plaid(dataset, idx) + # generic way of retrieving features and sending them to the GPU + for time_ in plaid_sample.get_all_time_values(): + torch_sample = {} + for path in features: + value = plaid_sample.get_feature_by_path(path=path, time=time_) + if value is not None: + if not value.flags.writeable: + value = value.copy() + torch_sample[path] = torch.as_tensor(value).to("cuda", non_blocking=True) +print(f"duration {time.time()-start}") ``` @@ -409,17 +370,15 @@ print("-- Streaming test ----------------------------------") print("----------------------------------------------------") -for backend in all_backends: - - datasetdict, converterdict = init_streaming_from_hub(f"{BASE_REPO_ID}_{backend}") +datasetdict, converterdict = init_streaming_from_hub(BASE_REPO_ID) - dataset = datasetdict[split] - converter = converterdict[split] +dataset = datasetdict[split] +converter = converterdict[split] - # dataset here is an IterableDataset, retrieving one sample and converting it to plaid - raw_sample = next(iter(dataset)) - plaid_sample = converter.sample_to_plaid(raw_sample) +# dataset here is an IterableDataset, retrieving one sample and converting it to plaid +raw_sample = next(iter(dataset)) +plaid_sample = converter.sample_to_plaid(raw_sample) - # utility to print a summary of the CGNS tree from the plaid sample - show_cgns_tree(plaid_sample.get_tree(0.)) +# utility to print a summary of the CGNS tree from the plaid sample +show_cgns_tree(plaid_sample.get_tree(0.)) ``` diff --git a/docs/zensical.toml b/docs/zensical.toml index 
068031ee..c4ff7ef7 100644 --- a/docs/zensical.toml +++ b/docs/zensical.toml @@ -52,10 +52,11 @@ nav = [ { "Sample" = "notebooks/containers/sample_example.md" }, { "Problem definition" = "notebooks/problem_definition_example.md" }, { "Infos" = "notebooks/infos_example.md" }, - { "Downloadable samples" = "notebooks/downloadable_example/sample_example.md" }, + { "Downloadable samples" = "tutorials/downloadable_example.md" }, { "Conversion tutorial" = "tutorials/storage.md" }, ] }, - { "Viewer" = "concepts/viewer.md" }, + { "plaid-viewer" = "concepts/viewer.md" }, + { "plaid-check" = "concepts/check.md" }, # >>> AUTO-GENERATED API REFERENCE START # The block below is overwritten by `python docs/generate_api_stubs.py`. # Edit that script (or the markers) instead of changing this section by hand. diff --git a/examples/downloadable_example/__init__.py b/examples/downloadable_example/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/downloadable_example/sample_example.py b/examples/downloadable_example/sample_example.py deleted file mode 100644 index f97cec82..00000000 --- a/examples/downloadable_example/sample_example.py +++ /dev/null @@ -1,49 +0,0 @@ -# --- -# jupyter: -# jupytext: -# formats: ipynb,py:percent -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.17.3 -# kernelspec: -# display_name: plaid-dev -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Downloadable samples -# -# This Jupyter Notebook show how to easily retrieve sample examples online. 
- -# %% -import warnings -warnings.filterwarnings("ignore", message=".*IProgress not found.*") - -# %% [markdown] -# Retrieving sample examples is as easy as: - -# %% -from plaid.downloadable_examples import AVAILABLE_EXAMPLES -import time - -print(AVAILABLE_EXAMPLES) - -# %% -from plaid.downloadable_examples import samples - -start = time.perf_counter() -print("samples.vki_ls59:", samples.vki_ls59) -end = time.perf_counter() - -print(f"First sample retrieval duration: {end - start:.6f} seconds") -assert(len(samples.vki_ls59.get_global_names())==8) - -# %% -start = time.perf_counter() -sample = samples.vki_ls59 -end = time.perf_counter() - -print(f"Second sample retrieval duration: {end - start:.6f} seconds (in cache)") diff --git a/examples/infos_example.py b/examples/infos_example.py index a1dc2214..58015054 100644 --- a/examples/infos_example.py +++ b/examples/infos_example.py @@ -24,7 +24,7 @@ # 3. Saving and loading infos # # This notebook provides examples of using the Infos class to define dataset -# metadata, access entries with mapping-style helpers, and save/load infos. +# metadata, access entries with typed attributes, and save/load infos. # # **Each section is documented and explained.** @@ -36,7 +36,7 @@ # %% # Import necessary libraries and classes -from plaid.infos import DataProduction, Infos, Legal +from plaid.infos import DataProduction, Infos # %% [markdown] # ## Section 1: Initializing Infos @@ -49,7 +49,8 @@ # %% print("#---# Infos") infos = Infos( - legal=Legal(owner="PLAID", license="MIT"), + owner="PLAID", + license="MIT", ) print(f"{infos = }") @@ -57,12 +58,10 @@ # ### Initialize Infos from a plain mapping # %% -infos_from_mapping = Infos.from_mapping( +infos_from_mapping = Infos.model_validate( { - "legal": { - "owner": "PLAID", - "license": "MIT", - }, + "owner": "PLAID", + "license": "MIT", "data_description": "Example metadata for a PLAID dataset.", } ) @@ -75,11 +74,13 @@ # metadata. 
# %% [markdown] -# ### Set legal metadata +# ### Set owner and license metadata # %% -infos.legal = Legal(owner="Safran", license="proprietary") -print(f"{infos.legal = }") +infos.owner = "Safran" +infos.license = "proprietary" +print(f"{infos.owner = }") +print(f"{infos.license = }") # %% [markdown] # ### Set data production metadata @@ -97,24 +98,23 @@ print(f"{infos.data_production = }") # %% [markdown] -# ### Set data description, sample counts, and storage backend +# ### Set data description # %% infos.data_description = "Example dataset generated for the Infos example." -infos.num_samples = {"train": 2, "test": 2} -infos.storage_backend = "cgns" print(f"{infos.data_description = }") -print(f"{infos.num_samples = }") -print(f"{infos.storage_backend = }") +print(f"{infos.num_samples = }") # Populated by save_to_disk for saved datasets. +print(f"{infos.storage_backend = }") # Populated by save_to_disk for saved datasets. # %% [markdown] -# ### Retrieve data with mapping-style helpers +# ### Retrieve data with Pydantic attributes # %% -print(f"{infos['legal'] = }") -print(f"{infos.get('storage_backend') = }") -print(f"{infos.to_dict() = }") +print(f"{infos.owner = }") +print(f"{infos.license = }") +print(f"{infos.storage_backend = }") +print(f"{infos.model_dump(exclude_none=True) = }") # %% [markdown] # ## Section 3: Saving and Loading Infos @@ -139,12 +139,12 @@ # ### Load Infos from a YAML file # %% -loaded_infos = Infos.from_path(infos_save_fname) +loaded_infos = Infos.from_path(infos_save_fname, require_persisted=False) print(loaded_infos) # %% [markdown] # ### Load Infos from a directory containing infos.yaml # %% -loaded_infos_from_dir = Infos.from_path(test_pth) -print(loaded_infos_from_dir) \ No newline at end of file +loaded_infos_from_dir = Infos.from_path(test_pth, require_persisted=False) +print(loaded_infos_from_dir) diff --git a/examples/problem_definition_example.py b/examples/problem_definition_example.py index 42a5bfef..9dc27db0 100644 --- 
a/examples/problem_definition_example.py +++ b/examples/problem_definition_example.py @@ -18,7 +18,7 @@ # # This Jupyter Notebook demonstrates the usage of the ProblemDefinition class for defining machine learning problems using the PLAID library. It includes examples of: # -# 1. Initializing an empty ProblemDefinition +# 1. Initializing a complete ProblemDefinition # 2. Configuring problem characteristics and retrieve data # 3. Saving and loading problem definitions # @@ -37,38 +37,42 @@ from plaid import ProblemDefinition # %% [markdown] -# ## Section 1: Initializing an Empty ProblemDefinition +# ## Section 1: Initializing a ProblemDefinition # # This section demonstrates how to initialize a ProblemDefinition and add # input/output feature identifiers with the current API. # %% [markdown] -# ### Initialize and print ProblemDefinition +# ### Initialize feature identifiers # %% -print("#---# Empty ProblemDefinition") -problem = ProblemDefinition() -print(f"{problem = }") - -# %% -# ### Initialize some feature identifiers scalar_1_feat_id = "Global/scalar_1" scalar_2_feat_id = "Global/scalar_2" scalar_3_feat_id = "Global/scalar_3" field_1_feat_id = "Base_2_2/Zone/CellCenterFields/field_1" field_2_feat_id = "Base_2_2/Zone/VertexFields/field_2" +# %% +print("#---# ProblemDefinition") +problem = ProblemDefinition( + input_features=[scalar_3_feat_id, field_1_feat_id], + output_features=[field_2_feat_id], + train_split={"train": [0, 1]}, + test_split={"test": [2, 3]}, +) +print(f"{problem = }") + # %% [markdown] # ### Add inputs / outputs to a Problem Definition # %% -# Add unique input and output feature identifiers +# Add unique input and output feature identifiers after initialization #problem.add_input_features(scalar_1_feat_id) #problem.add_output_features(scalar_2_feat_id) # or Add list of input and output feature identifiers -problem.add_input_features([scalar_3_feat_id, field_1_feat_id]) -problem.add_output_features([field_2_feat_id]) 
+problem.add_input_features([scalar_1_feat_id]) +problem.add_output_features([scalar_2_feat_id]) print(f"{problem.input_features = }") print( @@ -81,28 +85,25 @@ # This section demonstrates how to handle and configure ProblemDefinition objects and access data. # %% [markdown] -# ### Set Problem Definition name - -# %% -problem.name = "my_problem_definition" -print(f"{problem.name = }") +# ### Problem Definition identifier +# +# ProblemDefinition has no stored name. When saved in a dataset, its identifier +# is the YAML filename stem or the key in a `dict[str, ProblemDefinition]`. # %% [markdown] # ### Set Problem Definition split # %% -# Current API uses `train_split` and `test_split` fields. +# Current API uses required `train_split` and `test_split` fields. # Note: each split field currently expects a dictionary with a single entry. -problem.train_split = {"train": [0, 1]} -problem.test_split = {"test": [2, 3]} print(f"{problem.train_split = }") print(f"{problem.test_split = }") # %% -split_names = [problem.get_train_split_name(), problem.get_test_split_name()] +split_names = [next(iter(problem.train_split)), next(iter(problem.test_split))] split_indices = { - problem.get_train_split_name(): problem.get_train_split_indices(), - problem.get_test_split_name(): problem.get_test_split_indices(), + next(iter(problem.train_split)): next(iter(problem.train_split.values())), + next(iter(problem.test_split)): next(iter(problem.test_split.values())), } print(f"{split_names = }") print(f"{split_indices = }") @@ -136,6 +137,5 @@ # ### Load a ProblemDefinition from a YAML file # %% -problem = ProblemDefinition() -problem._load_from_file_(pb_def_save_fname) +problem = ProblemDefinition.from_path(pb_def_save_fname) print(problem) diff --git a/src/plaid/cli/plaidcheck.py b/src/plaid/cli/plaidcheck.py index a2f33f3e..61712a83 100644 --- a/src/plaid/cli/plaidcheck.py +++ b/src/plaid/cli/plaidcheck.py @@ -8,16 +8,22 @@ import CGNS.PAT.cgnsutils as CGU import numpy as np +from tqdm 
import tqdm from plaid.constants import CGNS_FIELD_LOCATIONS +from plaid.infos import Infos from plaid.storage import init_from_disk from plaid.storage.common.reader import ( - load_infos_from_disk, load_metadata_from_disk, load_problem_definitions_from_disk, ) +def load_infos_from_disk(path: Path) -> Infos: + """Load infos for checker diagnostics without persisted-field enforcement.""" + return Infos.from_path(path, require_persisted=False) + + @dataclass class CheckMessage: """One integrity check message. @@ -158,6 +164,81 @@ def _check_numeric_content(value: Any) -> Optional[str]: return None +def _format_missing_split_message(split: object) -> str: + """Return an actionable message for missing split declarations. + + Args: + split: Split name/key reported by a low-level ``KeyError``. + + Returns: + Human-readable explanation suitable for a checker error. + """ + return ( + f"Split {split!r} exists in the stored dataset or metadata but is missing " + "from infos.yaml > num_samples. Add this split to num_samples, or remove " + "the corresponding split data/metadata from the dataset." + ) + + +def _check_num_samples_declares_splits( + num_samples: dict[str, Any], + split_names: set[str], + report: CheckReport, +) -> None: + """Validate that ``infos.yaml > num_samples`` declares known disk splits. + + Args: + num_samples: Mapping loaded from ``infos.yaml``. + split_names: Split names discovered from storage files/metadata. + report: Report updated with missing declaration errors. + + Returns: + None. + """ + declared_splits = {str(split) for split in num_samples} + for split in sorted(split_names - declared_splits): + report.add( + "error", + "NUM_SAMPLES_MISSING_SPLIT", + "infos.yaml", + _format_missing_split_message(split), + ) + + +def _discover_split_names_from_disk( + path: Path, + backend: Optional[str], + flat_cst: dict[str, Any], + constant_schema: dict[str, Any], +) -> set[str]: + """Discover split names from files/metadata without building converters. 
+ + Args: + path: Dataset root. + backend: Declared storage backend, if valid. + flat_cst: Flattened constants keyed by split for non-CGNS backends. + constant_schema: Constant schema keyed by split for non-CGNS backends. + + Returns: + Split names discovered from the on-disk dataset structure and metadata. + """ + split_names: set[str] = set() + data_path = path / "data" + if data_path.exists(): + if backend in {"zarr", "cgns"}: + split_names.update(p.name for p in data_path.iterdir() if p.is_dir()) + elif backend == "hf_datasets": + split_names.update(flat_cst.keys()) + split_names.update(constant_schema.keys()) + if backend != "cgns": + constants_path = path / "constants" + if constants_path.exists(): + split_names.update(p.name for p in constants_path.iterdir() if p.is_dir()) + split_names.update(flat_cst.keys()) + split_names.update(constant_schema.keys()) + return {str(split) for split in split_names} + + def _is_branch_without_data(sample: Any, path: str) -> bool: """Return True when `path` points to a branch node with no direct value. @@ -201,6 +282,98 @@ def _is_branch_without_data_in_mapping( return any(name.startswith(prefix) for name in feat_map) +def _progress( + iterable: Any, *, desc: str, show_progress: bool, total: int | None = None +): + """Wrap an iterable in a tqdm progress bar when requested. + + Args: + iterable: Iterable to wrap. + desc: Progress bar description. + show_progress: Whether the progress bar is enabled. + total: Optional total length. + + Returns: + Iterable, possibly wrapped by :class:`tqdm.tqdm`. + """ + return tqdm(iterable, desc=desc, total=total, disable=not show_progress) + + +def _resolve_problem_split_indices( + split_ids: Any, + split_len: int, +) -> list[int]: + """Resolve a problem-definition split declaration into concrete indices. + + Args: + split_ids: Either the special all-samples marker or an iterable of indices. + split_len: Number of samples available in the referenced split. 
+ + Returns: + Concrete sample indices to instantiate. + """ + if split_ids == "all": + return list(range(split_len)) + return list(split_ids) + + +def _check_problem_definition_sample_features( + *, + pb_name: str, + split_dict_name: str, + split_name: str, + idx: int, + dataset: Any, + converter: Any, + features: list[str], + report: CheckReport, +) -> None: + """Instantiate and validate one problem-definition sample view. + + The sample is instantiated with the exact feature subset requested by the + problem definition, then each requested feature is read back to validate + that the requested feature paths can actually be resolved. + + Numeric content (NaN, Inf, None, empty arrays, ...) is intentionally not + re-checked here: the per-split loop in :func:`check_dataset` already walks + every sample's globals and fields and reports such issues with the + ``INVALID_DATA_VALUE`` code. Re-checking them in this loop would only + produce duplicate warnings under a different code/location. + + Args: + pb_name: Problem-definition name. + split_dict_name: ``train_split`` or ``test_split``. + split_name: Dataset split name. + idx: Sample index. + dataset: Backend dataset object. + converter: Storage converter exposing ``to_plaid``. + features: Feature paths to request and validate. + report: Report updated with detected errors/warnings. + """ + location = f"problem_definitions/{pb_name}/{split_dict_name}/{split_name}[{idx}]" + try: + sample = converter.to_plaid(dataset, idx, features=features) + except Exception as exc: + report.add( + "error", + "PB_DEF_SAMPLE_CONVERSION_ERROR", + location, + str(exc), + ) + return + + for feature in features: + try: + sample.get_feature_by_path(feature) + except Exception as exc: + report.add( + "error", + "PB_DEF_FEATURE_READ_ERROR", + f"{location} {feature}", + str(exc), + ) + + def compute_checksum(sample: Any) -> str: """Compute a SHA-256 checksum for a converted sample representation. 
@@ -221,6 +394,8 @@ def compute_checksum(sample: Any) -> str: def check_dataset( path: Path, splits: Optional[list[str]] = None, + show_progress: bool = True, + problem_definitions: Optional[list[str]] = None, ) -> CheckReport: """Run integrity checks on a local PLAID dataset. @@ -241,6 +416,9 @@ def check_dataset( Args: path: Dataset directory. splits: Optional selected split names. + show_progress: Whether to display tqdm progress bars for expensive checks. + problem_definitions: Optional selected problem-definition names. When + omitted, all discovered problem definitions are checked. Returns: A populated :class:`CheckReport`. @@ -262,7 +440,7 @@ def check_dataset( report.add("error", "INFOS_READ_ERROR", "infos.yaml", str(exc)) return report - declared_backend_for_layout = infos.get("storage_backend") + declared_backend_for_layout = infos.storage_backend if not isinstance(declared_backend_for_layout, str): declared_backend_for_layout = None @@ -273,6 +451,25 @@ def check_dataset( if report.has_errors(): return report + # Validate top-level dataset declarations from infos.yaml before calling + # init_from_disk(), because storage initialization indexes num_samples by + # split and otherwise reports missing entries as opaque KeyError messages. + declared_backend = infos.storage_backend + if not isinstance(declared_backend, str): + report.add( + "error", + "BACKEND_MISSING", + "infos.yaml", + "Missing or invalid 'storage_backend' in infos.yaml", + ) + + num_samples = infos.num_samples + if not isinstance(num_samples, dict): + report.add( + "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict" + ) + num_samples = {} + # Load metadata when the backend defines it. The CGNS backend stores # self-contained samples and intentionally writes no derived metadata. 
if declared_backend_for_layout == "cgns": @@ -288,28 +485,29 @@ def check_dataset( report.add("error", "METADATA_READ_ERROR", str(path), str(exc)) return report - try: - datasetdict, converterdict = init_from_disk(path) - except Exception as exc: - report.add("error", "DATASET_INIT_ERROR", str(path), str(exc)) + discovered_splits = _discover_split_names_from_disk( + path, + declared_backend_for_layout, + flat_cst, + constant_schema, + ) + _check_num_samples_declares_splits(num_samples, discovered_splits, report) + if report.has_errors(): return report - # Validate top-level dataset declarations from infos.yaml. - declared_backend = infos.get("storage_backend") - if not isinstance(declared_backend, str): + try: + datasetdict, converterdict = init_from_disk(path) + except KeyError as exc: report.add( "error", - "BACKEND_MISSING", + "NUM_SAMPLES_MISSING_SPLIT", "infos.yaml", - "Missing or invalid 'storage_backend' in infos.yaml", + _format_missing_split_message(exc.args[0] if exc.args else str(exc)), ) - - num_samples = infos.get("num_samples", {}) - if not isinstance(num_samples, dict): - report.add( - "error", "NUM_SAMPLES_INVALID", "infos.yaml", "'num_samples' must be a dict" - ) - num_samples = {} + return report + except Exception as exc: + report.add("error", "DATASET_INIT_ERROR", str(path), str(exc)) + return report # Resolve the user-requested splits against the splits actually available. dataset_splits = set(datasetdict.keys()) @@ -359,7 +557,12 @@ def check_dataset( ) # Deep-check to validate content and detect non valide data in fields (nan inf) - for idx in range(actual_n): + for idx in _progress( + range(actual_n), + desc=f"Checking split {split}", + show_progress=show_progress, + total=actual_n, + ): try: sample = converter.to_plaid(dataset, idx) except Exception as exc: @@ -464,7 +667,23 @@ def check_dataset( # still run). 
validate_pb_def_features = declared_backend_for_layout != "cgns" + target_pb_names = ( + set(problem_definitions) if problem_definitions else set(pb_defs) + ) + unknown_pb_names = target_pb_names - set(pb_defs) + for pb_name in sorted(unknown_pb_names): + available = " and ".join(f'"{x}"' for x in sorted(pb_defs)) + report.add( + "error", + "PB_DEF_UNKNOWN", + f"problem_definitions/{pb_name}", + f"Problem definition not found, available are {available}", + ) + target_pb_names = target_pb_names & set(pb_defs) + for pb_name, pb_def in pb_defs.items(): + if pb_name not in target_pb_names: + continue if validate_pb_def_features: for feat in pb_def.input_features: if feat not in all_features: @@ -507,9 +726,8 @@ def check_dataset( f"Unknown split in {split_dict_name}: {split_name}", ) continue - if split_ids == "all": - continue - ids_list = list(split_ids) + split_len = len(datasetdict[split_name]) + ids_list = _resolve_problem_split_indices(split_ids, split_len) if len(ids_list) != len(set(ids_list)): report.add( "error", @@ -517,7 +735,6 @@ def check_dataset( f"problem_definitions/{pb_name}", f"Duplicated indices in {split_dict_name}", ) - split_len = len(datasetdict[split_name]) bad = [i for i in ids_list if i < 0 or i >= split_len] if bad: report.add( @@ -526,10 +743,32 @@ def check_dataset( f"problem_definitions/{pb_name}", f"Out-of-range indices in {split_dict_name} (first 10): {bad[:10]}", ) + continue + + if split_dict_name == "train_split": + features = list(pb_def.input_features) + list( + pb_def.output_features + ) + else: + features = list(pb_def.input_features) + + for idx in _progress( + ids_list, + desc=f"Checking problem {pb_name} {split_dict_name}", + show_progress=show_progress, + total=len(ids_list), + ): + _check_problem_definition_sample_features( + pb_name=pb_name, + split_dict_name=split_dict_name, + split_name=split_name, + idx=idx, + dataset=datasetdict[split_name], + converter=converterdict[split_name], + features=features, + report=report, + 
) - # Emit an explicit success message when no errors or warnings were found. - if not report.messages: - report.add("info", "OK", str(path), "No issue detected") return report @@ -557,6 +796,12 @@ def _build_parser() -> argparse.ArgumentParser: action="store_true", help="Treat warnings as failure", ) + parser.add_argument( + "--problem-definition", + action="append", + default=None, + help="Problem definition to check (can be provided multiple times)", + ) return parser @@ -572,13 +817,23 @@ def main(argv: Optional[list[str]] = None) -> int: parser = _build_parser() args = parser.parse_args(argv) - report = check_dataset(path=args.path, splits=args.split) + report = check_dataset( + path=args.path, + splits=args.split, + show_progress=not args.json, + problem_definitions=args.problem_definition, + ) if args.json: print(report.to_json()) else: - for msg in report.messages: - print(f"[{msg.severity.upper()}] {msg.code} {msg.location}: {msg.message}") + if not report.messages: + print(f"[OK] {args.path}: No issue detected") + else: + for msg in report.messages: + print( + f"[{msg.severity.upper()}] {msg.code} {msg.location}: {msg.message}" + ) counts = report.counts() print( f"Summary: {counts['error']} error(s), " diff --git a/src/plaid/containers/sample.py b/src/plaid/containers/sample.py index 2ede2cc8..2a707e52 100644 --- a/src/plaid/containers/sample.py +++ b/src/plaid/containers/sample.py @@ -1211,7 +1211,7 @@ def get_zone_type( if zone_node is None: raise KeyError( - f"there is no base/zone <{base}/{zone}>, you should first create one with `Sample.init_zone({zone=},{base=})`" + f"There is no base/zone <{base}/{zone}>, you should first create one with `Sample.init_zone({zone=},{base=})`." 
) return CGU.getValueByPath(zone_node, "ZoneType").tobytes().decode() diff --git a/src/plaid/infos.py b/src/plaid/infos.py index a4958057..c92d4f88 100644 --- a/src/plaid/infos.py +++ b/src/plaid/infos.py @@ -2,7 +2,6 @@ from __future__ import annotations -import copy import logging from pathlib import Path from typing import Any, Union @@ -18,14 +17,6 @@ ) -@dataclass(config=_PD_CONFIG) -class Legal: - """Legal ownership and licensing metadata.""" - - owner: str - license: str - - @dataclass(config=_PD_CONFIG) class DataProduction: """Dataset production context metadata.""" @@ -43,7 +34,8 @@ class DataProduction: # Order used when serializing to YAML. _KEY_ORDER = ( - "legal", + "owner", + "license", "data_production", "data_description", "num_samples", @@ -56,29 +48,45 @@ class Infos(BaseModel): model_config = _PD_CONFIG - legal: Legal + owner: str + license: str data_production: DataProduction | None = None data_description: str | None = None num_samples: dict[str, int] = Field(default_factory=dict) storage_backend: str | None = None - @classmethod - def _normalize_top_level(cls, infos: dict[str, Any]) -> dict[str, Any]: - # Drop legacy/unsupported top-level sections silently before validation. - normalized = copy.deepcopy(infos) - normalized.pop("plaid", None) - return normalized + def require_persisted(self) -> "Infos": + """Validate fields that must exist in persisted dataset infos. + + ``num_samples`` and ``storage_backend`` are derived by storage writers + when a dataset is saved, so they are optional while users prepare an + ``Infos`` object. Once infos are loaded from disk or the Hub, however, + readers need both fields to select the backend and split sizes. 
+ """ + if "num_samples" not in self.model_fields_set: + raise ValueError("Missing required persisted infos field: 'num_samples'") + if "storage_backend" not in self.model_fields_set or not self.storage_backend: + raise ValueError( + "Missing required persisted infos field: 'storage_backend'" + ) + return self @classmethod def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": """Validate schema/authorized keys without enforcing required sections.""" - normalized = cls._normalize_top_level(infos) - had_legal = "legal" in normalized - if not had_legal: - normalized["legal"] = { - "owner": "__placeholder__", - "license": "__placeholder__", - } + normalized = dict(infos) + had_owner = "owner" in normalized + had_license = "license" in normalized + had_num_samples = "num_samples" in normalized + had_storage_backend = "storage_backend" in normalized + if not had_owner: + normalized["owner"] = "__placeholder__" + if not had_license: + normalized["license"] = "__placeholder__" + if not had_num_samples: + normalized["num_samples"] = {} + if not had_storage_backend: + normalized["storage_backend"] = "__placeholder__" try: model = cls.model_validate(normalized) except ValidationError as exc: @@ -91,21 +99,30 @@ def validate_authorized_only(cls, infos: dict[str, Any]) -> "Infos": raise KeyError(f"Unauthorized infos key: {loc!r}") from exc raise - if not had_legal: - model.legal = Legal(owner="", license="") + if not had_owner: + model.owner = "" + if not had_license: + model.license = "" + if not had_num_samples: + model.num_samples = {} + if not had_storage_backend: + model.storage_backend = "" return model @classmethod def validate_required_only(cls, infos: dict[str, Any]) -> None: - """Validate required entries using pydantic-required fields.""" - normalized = cls._normalize_top_level(infos) - cls.model_validate(normalized) + """Validate entries required for persisted dataset infos.""" + cls.model_validate(infos).require_persisted() + + @classmethod + def 
validate_persisted(cls, infos: dict[str, Any]) -> "Infos": + """Validate and return complete infos loaded from persisted storage.""" + return cls.model_validate(infos).require_persisted() @classmethod def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: """Validate and return a normalized deep copy of infos.""" - normalized = cls._normalize_top_level(infos) - model = cls.model_validate(normalized) + model = cls.model_validate(infos) return model.model_dump(exclude_none=True) # ------------------------------------------------------------------ @@ -113,17 +130,16 @@ def normalize_mapping(cls, infos: dict[str, Any]) -> dict[str, Any]: # ------------------------------------------------------------------ @classmethod - def from_mapping(cls, infos: dict[str, Any]) -> "Infos": - """Build a validated :class:`Infos` from a plain mapping.""" - return cls.model_validate(cls._normalize_top_level(infos)) - - @classmethod - def from_path(cls, path: Union[str, Path]) -> "Infos": + def from_path( + cls, path: Union[str, Path], require_persisted: bool = True + ) -> "Infos": """Load and validate an :class:`Infos` from a YAML file. Args: path: Path to the YAML file (typically ``infos.yaml``) or to a directory containing it. + require_persisted: When True, require storage-derived metadata + fields expected in a complete on-disk dataset. Returns: Validated :class:`Infos` instance. 
@@ -140,44 +156,10 @@ def from_path(cls, path: Union[str, Path]) -> "Infos": with path.open("r", encoding="utf-8") as file: data = yaml.safe_load(file) or {} - return cls.from_mapping(data) - - # ------------------------------------------------------------------ - # Mapping-like accessors (read-only convenience) - # ------------------------------------------------------------------ - - def to_dict(self) -> dict[str, Any]: - """Return a plain ``dict`` representation of these infos.""" - return self.model_dump(exclude_none=True) - - def __getitem__(self, key: str) -> Any: - """Return the value associated with ``key`` using mapping-style access.""" - if not hasattr(self, key): - raise KeyError(key) - value = getattr(self, key) - # Unwrap nested dataclasses to plain dicts when accessed by key, so that - # callers expecting a YAML-like mapping continue to work transparently. - if hasattr(value, "__pydantic_fields__"): - return { - f: getattr(value, f) - for f in value.__pydantic_fields__ - if getattr(value, f) is not None - } - return value - - def __contains__(self, key: object) -> bool: - """Return whether ``key`` is a known field with a non-``None`` value.""" - if not isinstance(key, str): - return False - if not hasattr(self, key): - return False - return getattr(self, key) is not None - - def get(self, key: str, default: Any = None) -> Any: - """Return ``self[key]`` when present, otherwise ``default``.""" - if key in self: - return self[key] - return default + infos = cls.model_validate(data) + if require_persisted: + infos.require_persisted() + return infos def save_to_file(self, path: Union[str, Path]) -> None: """Save infos to ``path`` as a YAML file. 
@@ -197,7 +179,7 @@ def save_to_file(self, path: Union[str, Path]) -> None: path.parent.mkdir(parents=True, exist_ok=True) - data = self.model_dump(exclude_none=True) + data = self.model_dump(exclude_none=True, exclude_unset=True) ordered_data = {key: data[key] for key in _KEY_ORDER if key in data} # Preserve any future fields. for key, value in data.items(): diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index 4585b09d..cb83c943 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -12,12 +12,10 @@ import logging from pathlib import Path -from typing import Any, Literal, Optional, Sequence, Union, cast +from typing import Any, Literal, Sequence, Union import yaml -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from .types import IndexArrayType +from pydantic import BaseModel, ConfigDict, field_validator # %% Globals @@ -35,66 +33,42 @@ class ProblemDefinition(BaseModel): revalidate_instances="always", validate_assignment=True, extra="forbid" ) - name: Optional[str] = Field(default=None) - input_features: list[str] = Field(default_factory=list) - output_features: list[str] = Field(default_factory=list) - train_split: Optional[dict[str, Sequence[int] | Literal["all"]]] = Field( - default=None - ) - test_split: Optional[dict[str, Sequence[int] | Literal["all"]]] = Field( - default=None - ) - - @staticmethod - def from_path( - path: str | Path, name: str | None = None, **overrides: Any - ) -> "ProblemDefinition": - """Load a problem definition from a YAML file located at the specified path. + input_features: list[str] + output_features: list[str] + train_split: dict[str, Sequence[int] | Literal["all"]] + test_split: dict[str, Sequence[int] | Literal["all"]] - The YAML file should contain one or more problem definitions, and the desired definition can be selected by its name. 
+ @classmethod + def from_path(cls, path: str | Path) -> "ProblemDefinition": + """Load and validate one problem definition from a YAML file. Args: - path (str | Path): The file path to the YAML file containing problem definitions. - name (str | None, optional): The name of the problem definition to load. If None, it will attempt to load the - only problem definition available in the file. Defaults to None. - **overrides: Additional keyword arguments to override specific fields in the loaded problem definition. - - Raises: - ValueError: If the specified name is not found in the YAML file. - RuntimeError: If multiple problem definitions are present without a specified name. + path: Path to the problem-definition YAML file. If the suffix is not + ``.yaml`` (or is missing), it is replaced with ``.yaml``. Returns: - ProblemDefinition: The loaded problem definition. + Validated problem definition instance. + + Raises: + FileNotFoundError: If the resolved YAML file does not exist. """ - from plaid.storage import load_problem_definitions_from_disk - - all_pb_def = load_problem_definitions_from_disk(path=Path(path)) - available = ", ".join(sorted(all_pb_def)) - if name is not None: - if name not in all_pb_def: - raise ValueError( - f"Problem definition '{name}' not found in {path}. " - f"Available definitions: {available}" - ) - data2 = all_pb_def[name].model_dump() - data2.update(overrides) - data = data2 - else: - if len(all_pb_def) > 1: - raise RuntimeError( - f"Non name specified, but more than one Problem definition. Available definitions: {available}" - ) - else: - data2 = next(iter(all_pb_def.values())).model_dump() - data2.update(overrides) - data = data2 + path = Path(path) + if path.suffix != ".yaml": + path = path.with_suffix(".yaml") + if not path.exists(): + raise FileNotFoundError(f'File "{path}" does not exist. 
Abort') + + with path.open("r", encoding="utf-8") as file: + data = yaml.safe_load(file) or {} - return ProblemDefinition(**data) + return cls.model_validate(data) @field_validator("input_features", mode="before") @classmethod def normalize_input_features(cls, v): """Normalize input features identifiers by ensuring they are unique and sorted.""" + if not v: + raise ValueError("input_features must not be empty") if len(set(v)) != len(v): raise ValueError("duplicated values in input_features") return _normalize_list(v) @@ -103,22 +77,14 @@ def normalize_input_features(cls, v): @classmethod def normalize_output_features(cls, v): """Normalize output features identifiers by ensuring they are unique and sorted.""" + if not v: + raise ValueError("output_features must not be empty") if len(set(v)) != len(v): raise ValueError("duplicated values in output_features") return _normalize_list(v) def __setattr__(self, name: str, value: Any) -> None: - """Override the default attribute setting behavior to enforce immutability for certain fields and log warnings for others.""" - # to set the name, task, score_function only once and oly once - if name in ["name"]: - current_value = getattr(self, name, None) - if ( - current_value is not None - and value is not None - and current_value != value - ): - raise AttributeError(f"'{name}' is already set and cannot be changed.") - # warning if set + """Override attribute setting to log warnings when split fields are replaced.""" if name in ["train_split", "test_split"]: current_value = getattr(self, name, None) if ( @@ -130,49 +96,6 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) - # # -------------------------------------------------------------------------# - def get_train_split_name(self) -> str: - """Return the name of the train split.""" - if self.train_split is None: - raise ValueError("train_split is not defined.") - return list(self.train_split.keys())[0] - - def 
get_train_split_indices(self) -> IndexArrayType | Literal["all"]: - """Return the indices associated with the train split. - - Raises: - ValueError: If `train_split` is not defined. - - Returns: - IndexArrayType | Literal["all"]: The indices associated with the train split. - """ - if self.train_split is None: - raise ValueError("train_split is not defined.") - return cast( - IndexArrayType | Literal["all"], next(iter(self.train_split.values())) - ) - - def get_test_split_name(self) -> str: - """Return the name of the test split.""" - if self.test_split is None: - raise ValueError("test_split is not defined.") - return list(self.test_split.keys())[0] - - def get_test_split_indices(self) -> IndexArrayType | Literal["all"]: - """Return the indices associated with the test split. - - Raises: - ValueError: If `test_split` is not defined. - - Returns: - IndexArrayType | Literal["all"]: The indices associated with the test split. - """ - if self.test_split is None: - raise ValueError("test_split is not defined.") - return cast( - IndexArrayType | Literal["all"], next(iter(self.test_split.values())) - ) - def add_input_features(self, inputs: Union[str, Sequence[str]]) -> None: """Add input features identifiers to the problem. @@ -186,7 +109,12 @@ def add_input_features(self, inputs: Union[str, Sequence[str]]) -> None: .. code-block:: python from plaid.problem_definition import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) input_features = ['omega', 'pressure'] problem.add_input_features(input_features) @@ -222,7 +150,12 @@ def add_output_features(self, outputs: Union[str, Sequence[str]]) -> None: .. 
code-block:: python from plaid.problem_definition import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) output_features = ['omega', 'pressure'] problem.add_output_features(output_features) @@ -255,7 +188,12 @@ def save_to_file(self, path: Union[str, Path]) -> None: .. code-block:: python from plaid import ProblemDefinition - problem = ProblemDefinition() + problem = ProblemDefinition( + input_features=["angle"], + output_features=["pressure"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) problem.save_to_file("/path/to/save_file") """ path = Path(path) @@ -267,7 +205,6 @@ def save_to_file(self, path: Union[str, Path]) -> None: data = self.model_dump() key_order = [ - "name", "input_features", "output_features", "train_split", @@ -285,37 +222,3 @@ def save_to_file(self, path: Union[str, Path]) -> None: sort_keys=False, allow_unicode=True, ) - - def _load_from_file_(self, path: Union[str, Path]) -> None: - """Load problem information, inputs, outputs, and split from the specified file in YAML format. - - Args: - path (Union[str,Path]): The filepath from which to load the problem information. - - Raises: - FileNotFoundError: Triggered if the provided file does not exist. - - Example: - .. code-block:: python - - from plaid import ProblemDefinition - problem = ProblemDefinition() - problem._load_from_file_("/path/to/load_file") - """ - path = Path(path) - - if path.suffix != ".yaml": - path = path.with_suffix(".yaml") - - if not path.exists(): - raise FileNotFoundError(f'File "{path}" does not exist. Abort') - - with path.open("r") as file: - data = yaml.safe_load(file) - - model_fields = type(self).model_fields.keys() - for key, value in data.items(): - if key in model_fields: - setattr(self, key, value) - else: - logger.warning(f" Data ignored! 
: {key}: {value}") diff --git a/src/plaid/storage/cgns/reader.py b/src/plaid/storage/cgns/reader.py index 4241a055..13ab5481 100644 --- a/src/plaid/storage/cgns/reader.py +++ b/src/plaid/storage/cgns/reader.py @@ -264,7 +264,7 @@ def init_datasetdict_streaming_from_hub( else: infos = load_infos_from_hub(repo_id=repo_id) selected_ids = { - split: range(n_samples) for split, n_samples in infos["num_samples"].items() + split: range(n_samples) for split, n_samples in infos.num_samples.items() } return { diff --git a/src/plaid/storage/cgns/writer.py b/src/plaid/storage/cgns/writer.py index 656b584c..5f98d0e3 100644 --- a/src/plaid/storage/cgns/writer.py +++ b/src/plaid/storage/cgns/writer.py @@ -185,7 +185,7 @@ def configure_dataset_card( None: This function does not return a value; it pushes the dataset card directly to Hugging Face Hub. """ - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) dataset_card_str = """--- task_categories: - graph-ml @@ -224,7 +224,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 6 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 lines.insert(count, "viewer: false") count += 1 diff --git a/src/plaid/storage/common/reader.py b/src/plaid/storage/common/reader.py index 3a99fcc5..dbca8407 100644 --- a/src/plaid/storage/common/reader.py +++ b/src/plaid/storage/common/reader.py @@ -61,8 +61,8 @@ def load_problem_definitions_from_disk( ``problem_definitions/`` subdirectory under ``path`` and reconstructs them into ``ProblemDefinition`` objects. - Each file is loaded using ``ProblemDefinition._load_from_file_`` and inserted - into a dictionary keyed by the problem definition name. + Each file is loaded using ``ProblemDefinition.from_path`` and inserted into + a dictionary keyed by the YAML filename stem. 
Expected local layout: / @@ -77,7 +77,7 @@ def load_problem_definitions_from_disk( Returns: dict[str, ProblemDefinition]: - Mapping from problem definition names to loaded ``ProblemDefinition`` + Mapping from problem definition filename stems to loaded ``ProblemDefinition`` objects. Raises: @@ -92,10 +92,15 @@ def load_problem_definitions_from_disk( pb_defs = {} for p in pb_def_dir.iterdir(): if p.is_file(): - pb_def = ProblemDefinition() - pb_def._load_from_file_(pb_def_dir / Path(p.name)) - pb_name = pb_def.name if isinstance(pb_def.name, str) else p.stem - pb_defs[pb_name] = pb_def + pb_def_path = pb_def_dir / Path(p.name) + try: + pb_def = ProblemDefinition.from_path(pb_def_path) + except Exception as exc: + raise ValueError( + f"Failed to load problem definition file " + f"'{pb_def_path.name}': {exc}" + ) from exc + pb_defs[p.stem] = pb_def return pb_defs else: raise ValueError( diff --git a/src/plaid/storage/common/writer.py b/src/plaid/storage/common/writer.py index 41231824..fcc03232 100644 --- a/src/plaid/storage/common/writer.py +++ b/src/plaid/storage/common/writer.py @@ -39,25 +39,30 @@ def save_infos_to_disk(path: Union[str, Path], infos: Infos) -> None: def save_problem_definitions_to_disk( path: Union[str, Path], - pb_defs: Union[dict[str, ProblemDefinition], ProblemDefinition], + pb_defs: dict[str, ProblemDefinition], ) -> None: """Save ProblemDefinitions to disk. Args: path (Union[str, Path]): The directory path for saving. - pb_defs (Union[dict[str, ProblemDefinition], ProblemDefinition]): The problem definitions to save. + pb_defs (dict[str, ProblemDefinition]): Mapping from problem definition identifiers to definitions. """ if isinstance(pb_defs, ProblemDefinition): - pb_defs = {pb_defs.name: pb_defs} + raise TypeError( + "pb_defs must be a dict[str, ProblemDefinition]; " + "use the dictionary key as the problem identifier." 
+ ) + if not isinstance(pb_defs, dict): + raise TypeError("pb_defs must be a dict[str, ProblemDefinition]") target_dir = Path(path) / "problem_definitions" target_dir.mkdir(parents=True, exist_ok=True) for name, pb_def in pb_defs.items(): - if name is None: - raise ValueError( - "At least one of the provided pb_defs has no initialized name." - ) + if not isinstance(name, str) or not name: + raise TypeError("Problem definition identifiers must be non-empty strings") + if not isinstance(pb_def, ProblemDefinition): + raise TypeError("pb_defs values must be ProblemDefinition instances") pb_def.save_to_file(target_dir / name) diff --git a/src/plaid/storage/hf_datasets/reader.py b/src/plaid/storage/hf_datasets/reader.py index 956ae24c..6650dc87 100644 --- a/src/plaid/storage/hf_datasets/reader.py +++ b/src/plaid/storage/hf_datasets/reader.py @@ -94,7 +94,7 @@ def download_datasetdict_from_hub( local_dir=tmp_dir, ) infos = load_infos_from_hub(repo_id=repo_id) - split_names = list(infos["num_samples"].keys()) + split_names = list(infos.num_samples.keys()) base = Path(tmp_dir) / "data" data_files = {sn: str(base / f"{sn}*.parquet") for sn in split_names} datasetdict = load_dataset("parquet", data_files=data_files, cache_dir=tmp_dir) diff --git a/src/plaid/storage/hf_datasets/writer.py b/src/plaid/storage/hf_datasets/writer.py index e39295b1..3a98147b 100644 --- a/src/plaid/storage/hf_datasets/writer.py +++ b/src/plaid/storage/hf_datasets/writer.py @@ -221,7 +221,7 @@ def configure_dataset_card( None: This function does not return a value; it updates the dataset card directly on Hugging Face Hub. 
""" - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) readme_path = hf_hub_download( repo_id=repo_id, filename="README.md", repo_type="dataset" ) @@ -240,7 +240,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 1 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 if viewer is False: lines.insert(count, "viewer: false") diff --git a/src/plaid/storage/reader.py b/src/plaid/storage/reader.py index 3ce5bb56..0b265142 100644 --- a/src/plaid/storage/reader.py +++ b/src/plaid/storage/reader.py @@ -244,8 +244,8 @@ def init_from_disk( """ infos = load_infos_from_disk(local_dir) - backend = infos["storage_backend"] - num_samples = infos["num_samples"] + backend = infos.storage_backend + num_samples = infos.num_samples datasetdict = get_backend(backend).init_from_disk(path=local_dir) @@ -299,7 +299,7 @@ def download_from_hub( infos = load_infos_from_hub(repo_id) pb_defs = load_problem_definitions_from_hub(repo_id) - backend = infos["storage_backend"] + backend = infos.storage_backend backend_spec = get_backend(backend) backend_spec.download_from_hub(repo_id, local_dir, split_ids, features, overwrite) @@ -340,8 +340,8 @@ def init_streaming_from_hub( ) infos = load_infos_from_hub(repo_id) - backend = infos["storage_backend"] - num_samples = infos["num_samples"] + backend = infos.storage_backend + num_samples = infos.num_samples backend_spec = get_backend(backend) datasetdict = backend_spec.init_datasetdict_streaming_from_hub( diff --git a/src/plaid/storage/writer.py b/src/plaid/storage/writer.py index a0acfdd9..d5603c7c 100644 --- a/src/plaid/storage/writer.py +++ b/src/plaid/storage/writer.py @@ -27,7 +27,7 @@ from plaid.storage.registry import available_backends, get_backend from ..containers.sample import Sample -from ..infos import Infos, Legal +from ..infos import Infos from ..problem_definition import ProblemDefinition from .common.preprocessor import 
preprocess from .common.reader import ( @@ -120,7 +120,7 @@ def save_to_disk( ids: Mapping[str, Any], infos: Optional[Infos] = None, backend: str = "hf_datasets", - pb_defs: Optional[Union[dict[str, ProblemDefinition], ProblemDefinition]] = None, + pb_defs: Optional[dict[str, ProblemDefinition]] = None, num_proc: int = 1, verbose: bool = False, overwrite: bool = False, @@ -155,7 +155,8 @@ def sample_constructor(file_path): "test": test_file_paths, }, infos=Infos( - legal={"owner": "owner", "license": "license"}, + owner="owner", + license="license", ), num_proc=6, ) @@ -173,8 +174,8 @@ def sample_constructor(file_path): ``'zarr'``). infos: Dataset information to save with the dataset. If ``None``, a placeholder :class:`~plaid.Infos` is created with - ``legal=Legal(owner="unknown", license="unknown")``. - pb_defs: Optional problem definitions to save. + ``owner="unknown", license="unknown"``. + pb_defs: Optional mapping from problem definition identifiers to definitions. num_proc: Number of processes to use for parallel writing. When ``num_proc > 1`` PLAID automatically shards the identifier sequences and distributes work across workers. @@ -236,11 +237,14 @@ def sample_constructor(file_path): # written ``infos.yaml`` always reflects how the dataset was saved, # overriding any inherited values from the input ``infos``. 
if infos is None: - infos = Infos(legal=Legal(owner="unknown", license="unknown")) - infos_data = infos.to_dict() + infos = Infos( + owner="unknown", + license="unknown", + ) + infos_data = infos.model_dump(exclude_none=True) infos_data["num_samples"] = num_samples infos_data["storage_backend"] = backend - infos = Infos.from_mapping(infos_data) + infos = Infos.validate_persisted(infos_data) save_infos_to_disk(output_folder, infos) @@ -291,7 +295,9 @@ def push_to_hub( """ infos = load_infos_from_disk(local_dir) - backend = infos["storage_backend"] + backend = infos.storage_backend + if backend is None: + raise ValueError("Missing 'storage_backend' in persisted infos") backend_spec = get_backend(backend) backend_spec.push_local_to_hub(repo_id, local_dir, num_workers=num_workers) diff --git a/src/plaid/storage/zarr/reader.py b/src/plaid/storage/zarr/reader.py index 95bf6853..557840a1 100644 --- a/src/plaid/storage/zarr/reader.py +++ b/src/plaid/storage/zarr/reader.py @@ -321,7 +321,7 @@ def init_datasetdict_streaming_from_hub( else: infos = load_infos_from_hub(repo_id=repo_id) selected_ids = { - split: range(n_samples) for split, n_samples in infos["num_samples"].items() + split: range(n_samples) for split, n_samples in infos.num_samples.items() } return { diff --git a/src/plaid/storage/zarr/writer.py b/src/plaid/storage/zarr/writer.py index e50c10e7..71408523 100644 --- a/src/plaid/storage/zarr/writer.py +++ b/src/plaid/storage/zarr/writer.py @@ -321,7 +321,7 @@ def configure_dataset_card( None: This function does not return a value; it pushes the dataset card directly to Hugging Face Hub. 
""" - infos_dict = infos.to_dict() + infos_dict = infos.model_dump(exclude_none=True) dataset_card_str = """--- task_categories: - graph-ml @@ -360,7 +360,7 @@ def configure_dataset_card( lines = lines[: indices[1] + 1] count = 6 - lines.insert(count, f"license: {infos.legal.license}") + lines.insert(count, f"license: {infos.license}") count += 1 lines.insert(count, "viewer: false") count += 1 diff --git a/src/plaid/viewer/services/plaid_dataset_service.py b/src/plaid/viewer/services/plaid_dataset_service.py index 1c291def..6c5d973f 100644 --- a/src/plaid/viewer/services/plaid_dataset_service.py +++ b/src/plaid/viewer/services/plaid_dataset_service.py @@ -455,15 +455,16 @@ def _load_feature_metadata(self, dataset_id: str) -> tuple[list[str], list[str]] # writes no derived constant/variable schema metadata. Detect it # from ``infos.yaml`` and return empty feature catalogues so the # viewer falls back to inspecting samples directly. + infos = None try: if self._is_hub_dataset(dataset_id): infos = load_infos_from_hub(dataset_id) else: infos = load_infos_from_disk(str(self._dataset_dir(dataset_id))) except Exception: # pragma: no cover - defensive - infos = {} + pass if ( - infos.get("storage_backend") == "cgns" + infos is not None and infos.storage_backend == "cgns" ): # pragma: no cover - exercised via integration tests metadata = ([], []) self._feature_metadata[dataset_id] = metadata diff --git a/tests/cli/test_plaidcheck.py b/tests/cli/test_plaidcheck.py index 840a3730..6fb571e3 100644 --- a/tests/cli/test_plaidcheck.py +++ b/tests/cli/test_plaidcheck.py @@ -12,6 +12,7 @@ from plaid.cli.plaidcheck import ( CheckReport, _check_numeric_content, + _check_problem_definition_sample_features, _is_branch_without_data, _is_branch_without_data_in_mapping, check_dataset, @@ -22,6 +23,15 @@ _REFERENCE_DATASETS = ("dataset_cgns", "dataset_hf") +def _infos(num_samples: dict[str, int], storage_backend: str = "zarr") -> Infos: + return Infos( + owner="owner", + license="license", 
+ num_samples=num_samples, + storage_backend=storage_backend, + ) + + def _copy_reference_dataset(tmp_path: Path, name: str = "dataset_cgns") -> Path: """Copy a reference dataset (CGNS or HF) used by container tests. @@ -60,6 +70,46 @@ def test_check_dataset_missing_infos(tmp_path: Path, dataset_name: str) -> None: assert any(msg.code == "MISSING_PATH" for msg in report.messages) +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_check_dataset_missing_required_layout_after_valid_infos( + tmp_path: Path, dataset_name: str +) -> None: + """Missing layout file (other than infos.yaml) should short-circuit checks.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + if dataset_name == "dataset_cgns": + # CGNS backend only requires infos.yaml + data/. + shutil.rmtree(dataset_path / "data") + else: + (dataset_path / "variable_schema.yaml").unlink() + + report = check_dataset(dataset_path) + + assert report.has_errors() + assert any(msg.code == "MISSING_PATH" for msg in report.messages) + # The early return on missing layout means we never reach init-related codes. 
+ assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + + +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_check_dataset_rejects_extra_infos_key( + tmp_path: Path, dataset_name: str +) -> None: + """Extra infos.yaml keys should be reported through infos validation.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + infos_path = dataset_path / "infos.yaml" + original = infos_path.read_text(encoding="utf-8") + infos_path.write_text( + f"{original}\nplaid:\n version: 0.1.13.dev36+g21db6656e.d20260302\n", + encoding="utf-8", + ) + + report = check_dataset(dataset_path) + + assert report.has_errors() + assert any(msg.code == "INFOS_READ_ERROR" for msg in report.messages) + assert any("plaid" in msg.message for msg in report.messages) + + @pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) def test_check_dataset_num_samples_mismatch(tmp_path: Path, dataset_name: str) -> None: """Tampering with num_samples should raise split mismatch errors.""" @@ -95,6 +145,21 @@ def test_main_json_output_and_exit_code( assert "messages" in payload +@pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) +def test_main_text_success_does_not_count_ok_as_info( + tmp_path: Path, capsys, dataset_name: str +) -> None: + """Successful text output should show OK without adding an info count.""" + dataset_path = _copy_reference_dataset(tmp_path, dataset_name) + + code = main([str(dataset_path)]) + out = capsys.readouterr().out + + assert code == 0 + assert f"[OK] {dataset_path}: No issue detected" in out + assert "Summary: 0 error(s), 0 warning(s), 0 info message(s)" in out + + @pytest.mark.parametrize("dataset_name", _REFERENCE_DATASETS) def test_main_strict_fails_on_warning(tmp_path: Path, dataset_name: str) -> None: """In strict mode, warnings should make the command fail.""" @@ -199,12 +264,14 @@ def __init__( global_names: list[str] | None = None, tree: Any = None, checksum: str = "same", + features: 
dict[str, Any] | None = None, ) -> None: self._global_value = global_value self._field_value = field_value self._global_names = ["G"] if global_names is None else global_names self._tree = tree self._checksum = checksum + self._features = {} if features is None else features def get_zone_names(self, base: str, time: float) -> list[str]: # noqa: ARG002 """Return deterministic zone names for checker traversal. @@ -231,6 +298,8 @@ def get_feature_by_path(self, path: str) -> Any: # noqa: ARG002 Returns: Global value payload. """ + if path in self._features: + return self._features[path] return self._global_value def get_tree(self): @@ -300,8 +369,14 @@ def __init__( ) -> None: self._samples = samples self._fail_indices = set() if fail_indices is None else fail_indices + self.feature_requests: list[list[str] | None] = [] - def to_plaid(self, dataset: _FakeDataset, idx: int) -> Any: # noqa: ARG002 + def to_plaid( + self, + dataset: _FakeDataset, # noqa: ARG002 + idx: int, + features: list[str] | None = None, + ) -> Any: """Return fake sample or raise conversion error. Args: @@ -311,11 +386,36 @@ def to_plaid(self, dataset: _FakeDataset, idx: int) -> Any: # noqa: ARG002 Returns: Fake sample instance. """ + self.feature_requests.append(features) if idx in self._fail_indices: raise RuntimeError("boom") return self._samples[idx] +class _FakeSampleWithFeatureFailure: + """Sample-like object that fails for selected feature paths.""" + + def __init__(self, values: dict[str, Any], failing_features: set[str]) -> None: + self._values = values + self._failing_features = failing_features + + def get_feature_by_path(self, path: str) -> Any: + """Return configured values or raise for configured paths. + + Args: + path: Feature path to read. + + Returns: + Configured feature value. + + Raises: + RuntimeError: If the path is configured as failing. 
+ """ + if path in self._failing_features: + raise RuntimeError(f"cannot read {path}") + return self._values[path] + + def test_check_numeric_content_all_remaining_branches() -> None: """Numeric checker should report all remaining invalid content cases.""" assert _check_numeric_content([]) == "value is empty" @@ -327,6 +427,66 @@ def test_check_numeric_content_all_remaining_branches() -> None: ) +def test_check_problem_definition_sample_reports_conversion_error() -> None: + """Problem-definition sample conversion failures should be reported.""" + report = CheckReport(messages=[]) + converter = _FakeConverter([_FakeSampleForCheck()], fail_indices={0}) + + _check_problem_definition_sample_features( + pb_name="pb", + split_dict_name="train_split", + split_name="train", + idx=0, + dataset=_FakeDataset(1), + converter=converter, + features=["Input"], + report=report, + ) + + assert len(report.messages) == 1 + msg = report.messages[0] + assert msg.severity == "error" + assert msg.code == "PB_DEF_SAMPLE_CONVERSION_ERROR" + assert msg.location == "problem_definitions/pb/train_split/train[0]" + assert msg.message == "boom" + + +def test_check_problem_definition_sample_reports_feature_read_error_and_continues() -> ( + None +): + """Feature read failures should be reported without stopping later checks.""" + report = CheckReport(messages=[]) + sample = _FakeSampleWithFeatureFailure( + values={"Good": np.array([1.0]), "BadValue": np.array([np.nan])}, + failing_features={"Unreadable"}, + ) + converter = _FakeConverter([sample]) + + _check_problem_definition_sample_features( + pb_name="pb", + split_dict_name="test_split", + split_name="test", + idx=0, + dataset=_FakeDataset(1), + converter=converter, + features=["Unreadable", "Good", "BadValue"], + report=report, + ) + + assert any( + msg.severity == "error" + and msg.code == "PB_DEF_FEATURE_READ_ERROR" + and msg.location == "problem_definitions/pb/test_split/test[0] Unreadable" + and msg.message == "cannot read Unreadable" + 
for msg in report.messages + ) + # Numeric content is intentionally not re-checked here: it is already + # validated by the per-split loop in `check_dataset`. + assert not any( + msg.code == "PB_DEF_INVALID_FEATURE_VALUE" for msg in report.messages + ) + + def test_is_branch_without_data_false_variants(monkeypatch) -> None: """Branch helper should return False for missing tree/node/children layout.""" @@ -372,7 +532,7 @@ def test_check_dataset_loader_failures_and_header_validations( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + lambda path: _infos({"train": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -403,7 +563,12 @@ def test_check_dataset_loader_failures_and_header_validations( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": 12, "num_samples": "bad"}, # noqa: ARG005 + lambda path: Infos.model_construct( # noqa: ARG005 + owner="owner", + license="license", + storage_backend=12, + num_samples="bad", + ), ) report_header = check_dataset(dataset) assert any(msg.code == "BACKEND_MISSING" for msg in report_header.messages) @@ -419,7 +584,7 @@ def test_check_dataset_split_and_data_warnings_and_duplicates( monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 3}}, # noqa: ARG005 + lambda path: _infos({"train": 3}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, "load_metadata_from_disk", lambda _path: ({}, {"Var": {}}, {}, None) @@ -475,7 +640,7 @@ def test_check_dataset_sample_conversion_error(tmp_path: Path, monkeypatch) -> N monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 1}}, # noqa: ARG005 + lambda path: _infos({"train": 1}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -496,6 +661,60 @@ def test_check_dataset_sample_conversion_error(tmp_path: Path, 
monkeypatch) -> N assert any(msg.code == "SAMPLE_CONVERSION_ERROR" for msg in report.messages) +def test_check_dataset_init_keyerror_reported_as_missing_split( + tmp_path: Path, monkeypatch +) -> None: + """KeyError raised by `init_from_disk` should map to NUM_SAMPLES_MISSING_SPLIT.""" + dataset = _make_minimal_layout(tmp_path) + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None), # noqa: ARG005 + ) + + def _raise_key_error(path): # noqa: ARG001 + raise KeyError("ghost_split") + + monkeypatch.setattr(plaidcheck, "init_from_disk", _raise_key_error) + + report = check_dataset(dataset) + + assert any(msg.code == "NUM_SAMPLES_MISSING_SPLIT" for msg in report.messages) + assert any("ghost_split" in msg.message for msg in report.messages) + assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + + +def test_check_dataset_missing_num_samples_split_is_clear( + tmp_path: Path, monkeypatch +) -> None: + """Missing split declarations should not be reported as opaque KeyErrors.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "data" / "OOD").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Var": {}}, {"train": {}}, None), # noqa: ARG005 + ) + + report = check_dataset(dataset) + + assert any(msg.code == "NUM_SAMPLES_MISSING_SPLIT" for msg in report.messages) + assert not any(msg.code == "DATASET_INIT_ERROR" for msg in report.messages) + assert any("OOD" in msg.message for msg in report.messages) + + def test_check_dataset_problem_definition_validation_paths( tmp_path: Path, monkeypatch ) -> None: @@ -506,7 +725,7 @@ def test_check_dataset_problem_definition_validation_paths( 
monkeypatch.setattr( plaidcheck, "load_infos_from_disk", - lambda path: {"storage_backend": "zarr", "num_samples": {"train": 2}}, # noqa: ARG005 + lambda path: _infos({"train": 2}), # noqa: ARG005 ) monkeypatch.setattr( plaidcheck, @@ -564,6 +783,208 @@ def __init__(self, train_split, test_split): assert any(msg.code == "PB_DEF_OUT_OF_RANGE_INDICES" for msg in report_pb.messages) +def test_check_dataset_problem_definition_instantiates_filtered_features( + tmp_path: Path, monkeypatch +) -> None: + """Problem definitions should instantiate exact train/test feature subsets.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1, "test": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ( # noqa: ARG005 + {"train": {}, "test": {}}, + {"Input": {}, "Output": {}}, + {"train": {}, "test": {}}, + None, + ), + ) + + train_converter = _FakeConverter( + [ + _FakeSampleForCheck( + features={"Input": np.array([1.0]), "Output": np.array([np.nan])} + ) + ] + ) + test_converter = _FakeConverter( + [ + _FakeSampleForCheck( + features={"Input": np.array([2.0]), "Output": np.array([np.nan])} + ) + ] + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ( # noqa: ARG005 + {"train": _FakeDataset(1), "test": _FakeDataset(1)}, + {"train": train_converter, "test": test_converter}, + ), + ) + + class _PBDef: + input_features = ["Input"] + output_features = ["Output"] + train_split = {"train": [0]} + test_split = {"test": [0]} + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: {"pb": _PBDef()}, # noqa: ARG005 + ) + + report = check_dataset(dataset, show_progress=False) + + assert train_converter.feature_requests[-1] == ["Input", "Output"] + assert test_converter.feature_requests[-1] == ["Input"] + # The pb-def loop no longer re-checks 
numeric content; it only verifies that + # the requested feature subset can be converted and read back. + assert not any( + msg.code == "PB_DEF_INVALID_FEATURE_VALUE" for msg in report.messages + ) + + +def test_check_dataset_filters_problem_definitions(tmp_path: Path, monkeypatch) -> None: + """Selected problem definitions should be checked without checking others.""" + dataset = _make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 1}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Input": {}, "Output": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ( # noqa: ARG005 + {"train": _FakeDataset(1)}, + {"train": _FakeConverter([_FakeSampleForCheck()])}, + ), + ) + + class _PBDef: + def __init__(self, input_features, output_features, train_split): + self.input_features = input_features + self.output_features = output_features + self.train_split = train_split + self.test_split = None + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: { # noqa: ARG005 + "selected": _PBDef(["Input"], ["Output"], {"train": [0]}), + "skipped": _PBDef(["UnknownInput"], ["UnknownOutput"], {"ghost": [0]}), + }, + ) + + report = check_dataset( + dataset, + show_progress=False, + problem_definitions=["selected"], + ) + + assert not any( + "problem_definitions/skipped" in msg.location for msg in report.messages + ) + assert not any(msg.code == "PB_DEF_UNKNOWN_INPUT" for msg in report.messages) + assert not any(msg.code == "PB_DEF_UNKNOWN_SPLIT" for msg in report.messages) + + +def test_check_dataset_reports_unknown_requested_problem_definition( + tmp_path: Path, monkeypatch +) -> None: + """Unknown requested problem definitions should be reported explicitly.""" + dataset = 
_make_minimal_layout(tmp_path) + (dataset / "problem_definitions").mkdir() + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 0}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Input": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(0)}, {"train": _FakeConverter([])}), # noqa: ARG005 + ) + + class _PBDef: + input_features = ["Input"] + output_features = [] + train_split = {"train": []} + test_split = None + + monkeypatch.setattr( + plaidcheck, + "load_problem_definitions_from_disk", + lambda path: {"known": _PBDef()}, # noqa: ARG005 + ) + + report = check_dataset( + dataset, + show_progress=False, + problem_definitions=["ghost"], + ) + + assert any(msg.code == "PB_DEF_UNKNOWN" for msg in report.messages) + assert any("known" in msg.message for msg in report.messages) + + +def test_check_dataset_problem_definition_read_error_names_yaml_file( + tmp_path: Path, monkeypatch +) -> None: + """Problem-definition read errors should identify the offending YAML file.""" + dataset = _make_minimal_layout(tmp_path) + pb_def_dir = dataset / "problem_definitions" + pb_def_dir.mkdir() + (pb_def_dir / "bad_definition.yaml").write_text( + "input_features: [in]\noutput_features: [out]\nunexpected_key: value\n", + encoding="utf-8", + ) + + monkeypatch.setattr( + plaidcheck, + "load_infos_from_disk", + lambda path: _infos({"train": 0}), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "load_metadata_from_disk", + lambda path: ({"train": {}}, {"Known": {}}, {"train": {}}, None), # noqa: ARG005 + ) + monkeypatch.setattr( + plaidcheck, + "init_from_disk", + lambda path: ({"train": _FakeDataset(0)}, {"train": _FakeConverter([])}), # noqa: ARG005 + ) + + report = check_dataset(dataset) + + assert any(msg.code == "PB_DEF_READ_ERROR" for msg in report.messages) + assert 
any("bad_definition.yaml" in msg.message for msg in report.messages) + assert any("extra_forbidden" in msg.message for msg in report.messages) + + def test_main_strict_returns_warning_exit_code( monkeypatch, tmp_path: Path, capsys ) -> None: @@ -571,8 +992,38 @@ def test_main_strict_returns_warning_exit_code( report = CheckReport(messages=[]) report.add("warning", "W", "loc", "msg") - monkeypatch.setattr(plaidcheck, "check_dataset", lambda path, splits=None: report) # noqa: ARG005 + monkeypatch.setattr( + plaidcheck, + "check_dataset", + lambda path, splits=None, show_progress=True, problem_definitions=None: report, # noqa: ARG005 + ) code = main([str(tmp_path), "--strict"]) _ = capsys.readouterr().out assert code == 2 + + +def test_main_json_disables_progress(monkeypatch, tmp_path: Path, capsys) -> None: + """JSON mode should disable progress bars and forward CLI filters.""" + seen: dict[str, Any] = {} + report = CheckReport(messages=[]) + + def _fake_check_dataset( + path, # noqa: ARG001 + splits=None, # noqa: ARG001 + show_progress=True, + problem_definitions=None, + ): + seen["show_progress"] = show_progress + seen["problem_definitions"] = problem_definitions + return report + + monkeypatch.setattr(plaidcheck, "check_dataset", _fake_check_dataset) + + code = main([str(tmp_path), "--json", "--problem-definition", "pb"]) + payload = json.loads(capsys.readouterr().out) + + assert code == 0 + assert payload["counts"] == {"error": 0, "warning": 0, "info": 0} + assert seen["show_progress"] is False + assert seen["problem_definitions"] == ["pb"] diff --git a/tests/conftest.py b/tests/conftest.py index fa3ee858..fbab2be9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,10 +91,13 @@ def other_samples(nb_samples: int, zone_name: str, base_name: str) -> list[Sampl @pytest.fixture() def infos(): - return Infos.from_mapping( + return Infos.model_validate( { - "legal": {"owner": "PLAID2", "license": "BSD-3"}, + "owner": "PLAID2", + "license": "BSD-3", 
"data_production": {"type": "simulation", "simulator": "Z-set"}, + "num_samples": {}, + "storage_backend": "zarr", } ) diff --git a/tests/containers/dataset_cgns/infos.yaml b/tests/containers/dataset_cgns/infos.yaml index 749ac15d..ffe70f8a 100644 --- a/tests/containers/dataset_cgns/infos.yaml +++ b/tests/containers/dataset_cgns/infos.yaml @@ -1,6 +1,5 @@ -legal: - owner: NeuralOperator (https://zenodo.org/records/13993629) - license: cc-by-4.0 +owner: NeuralOperator (https://zenodo.org/records/13993629) +license: cc-by-4.0 data_production: type: simulation physics: CFD diff --git a/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml b/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml index 7a7e1680..137a7af0 100644 --- a/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml +++ b/tests/containers/dataset_cgns/problem_definitions/regression_1.yaml @@ -1,4 +1,3 @@ -name: regression_1 input_features: - Base_2_3/Zone/Elements_TRI_3/ElementConnectivity - Base_2_3/Zone/GridCoordinates/CoordinateX diff --git a/tests/containers/dataset_hf/infos.yaml b/tests/containers/dataset_hf/infos.yaml index 45560724..2478f564 100644 --- a/tests/containers/dataset_hf/infos.yaml +++ b/tests/containers/dataset_hf/infos.yaml @@ -1,6 +1,5 @@ -legal: - owner: NeuralOperator (https://zenodo.org/records/13993629) - license: cc-by-4.0 +owner: NeuralOperator (https://zenodo.org/records/13993629) +license: cc-by-4.0 data_production: type: simulation physics: CFD diff --git a/tests/containers/dataset_hf/problem_definitions/regression_1.yaml b/tests/containers/dataset_hf/problem_definitions/regression_1.yaml index 7a7e1680..137a7af0 100644 --- a/tests/containers/dataset_hf/problem_definitions/regression_1.yaml +++ b/tests/containers/dataset_hf/problem_definitions/regression_1.yaml @@ -1,4 +1,3 @@ -name: regression_1 input_features: - Base_2_3/Zone/Elements_TRI_3/ElementConnectivity - Base_2_3/Zone/GridCoordinates/CoordinateX diff --git 
a/tests/containers/test_utils.py b/tests/containers/test_utils.py index a7cbd352..f8b35c4b 100644 --- a/tests/containers/test_utils.py +++ b/tests/containers/test_utils.py @@ -144,14 +144,17 @@ def test_get_feature_details_from_path(self, url, expected): def test_validate_required_only(self): infos = { - "legal": {"owner": "Joh Doe", "license": "cc-by-sa-4.0"}, + "owner": "Joh Doe", + "license": "cc-by-sa-4.0", + "num_samples": {"train": 1}, + "storage_backend": "zarr", } Infos.validate_required_only(infos) infos_missing_license = { - "legal": { - "owner": "Joh Doe", - }, + "owner": "Joh Doe", + "num_samples": {"train": 1}, + "storage_backend": "zarr", } with pytest.raises(ValueError): Infos.validate_required_only(infos_missing_license) diff --git a/tests/storage/test_cgns_init.py b/tests/storage/test_cgns_init.py index 292e9f92..41089793 100644 --- a/tests/storage/test_cgns_init.py +++ b/tests/storage/test_cgns_init.py @@ -169,8 +169,13 @@ def test_cgns_backend_configure_dataset_card_requires_local_dir(): with pytest.raises(ValueError, match="local_dir must be provided for cgns backend"): CgnsBackend.configure_dataset_card( repo_id="dummy/repo", - infos=Infos.from_mapping( - {"legal": {"owner": "owner", "license": "cc-by-4.0"}} + infos=Infos.model_validate( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + "storage_backend": "cgns", + } ), ) @@ -184,7 +189,14 @@ def fake_configure_dataset_card(**kwargs): monkeypatch.setattr(cgns, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.model_validate( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + "storage_backend": "cgns", + } + ) CgnsBackend.configure_dataset_card( repo_id="dummy/repo", infos=infos, diff --git a/tests/storage/test_common_writer.py b/tests/storage/test_common_writer.py new file mode 100644 index 00000000..76eb9ac8 --- /dev/null +++ 
b/tests/storage/test_common_writer.py @@ -0,0 +1,65 @@ +"""Tests for `plaid.storage.common.writer` validation paths.""" + +from pathlib import Path + +import pytest + +from plaid.problem_definition import ProblemDefinition +from plaid.storage.common.writer import save_problem_definitions_to_disk + + +def _make_pb_def() -> ProblemDefinition: + return ProblemDefinition( + input_features=["Global/in"], + output_features=["Global/out"], + train_split={"train": [0]}, + test_split={"test": [0]}, + ) + + +def test_save_problem_definitions_to_disk_rejects_non_dict_non_pbdef( + tmp_path: Path, +) -> None: + """Passing a non-dict, non-ProblemDefinition value should raise TypeError.""" + with pytest.raises(TypeError, match=r"dict\[str, ProblemDefinition\]"): + save_problem_definitions_to_disk(tmp_path, [("name", _make_pb_def())]) # type: ignore[arg-type] + + +def test_save_problem_definitions_to_disk_rejects_non_string_identifier( + tmp_path: Path, +) -> None: + """Non-string / empty identifiers should raise TypeError.""" + pb_def = _make_pb_def() + with pytest.raises(TypeError, match="non-empty strings"): + save_problem_definitions_to_disk(tmp_path, {123: pb_def}) # type: ignore[dict-item] + with pytest.raises(TypeError, match="non-empty strings"): + save_problem_definitions_to_disk(tmp_path, {"": pb_def}) + + +def test_save_problem_definitions_to_disk_rejects_non_pbdef_value( + tmp_path: Path, +) -> None: + """Non-ProblemDefinition values should raise TypeError.""" + with pytest.raises(TypeError, match="ProblemDefinition instances"): + save_problem_definitions_to_disk(tmp_path, {"pb": "not a pb_def"}) # type: ignore[dict-item] + + +def test_save_problem_definitions_to_disk_rejects_bare_pbdef(tmp_path: Path) -> None: + """Passing a bare ProblemDefinition (not wrapped in a dict) should raise.""" + with pytest.raises(TypeError, match="use the dictionary key as the problem"): + save_problem_definitions_to_disk(tmp_path, _make_pb_def()) # type: ignore[arg-type] + + +def 
test_save_problem_definitions_to_disk_writes_each_definition( + tmp_path: Path, +) -> None: + """Happy path: each ProblemDefinition is delegated to its `save_to_file`.""" + pb_defs = {"pb_a": _make_pb_def(), "pb_b": _make_pb_def()} + + save_problem_definitions_to_disk(tmp_path, pb_defs) + + target_dir = tmp_path / "problem_definitions" + assert target_dir.is_dir() + # ProblemDefinition.save_to_file serialises each definition as a YAML file. + assert (target_dir / "pb_a.yaml").is_file() + assert (target_dir / "pb_b.yaml").is_file() diff --git a/tests/storage/test_hf_datasets_init.py b/tests/storage/test_hf_datasets_init.py index 5f92b315..50673630 100644 --- a/tests/storage/test_hf_datasets_init.py +++ b/tests/storage/test_hf_datasets_init.py @@ -210,7 +210,14 @@ def fake_configure_dataset_card( hf_datasets, "configure_dataset_card", fake_configure_dataset_card ) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.model_validate( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + "storage_backend": "hf_datasets", + } + ) HFBackend.configure_dataset_card("dummy/repo", infos) assert call == {"repo_id": "dummy/repo", "infos": infos, "local_dir": None} diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index b9f2f449..e9727767 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -163,11 +163,12 @@ def main_splits() -> dict: @pytest.fixture() def problem_definition(main_splits) -> ProblemDefinition: - problem_definition = ProblemDefinition() - problem_definition.add_input_features(["feature_name_1", "feature_name_2"]) - problem_definition.train_split = {"train": main_splits["train"]} - problem_definition.test_split = {"test": main_splits["test"]} - return problem_definition + return ProblemDefinition( + input_features=["feature_name_1", "feature_name_2"], + output_features=["feature_name_1"], + train_split={"train": main_splits["train"]}, + 
test_split={"test": main_splits["test"]}, + ) @pytest.fixture() @@ -183,8 +184,8 @@ def _sample_constructor(id): @pytest.fixture() def split_ids(problem_definition) -> dict: return { - "train": problem_definition.get_train_split_indices(), - "test": problem_definition.get_test_split_indices(), + "train": problem_definition.train_split["train"], + "test": problem_definition.test_split["test"], } @@ -264,8 +265,7 @@ def test_hf_datasets( overwrite=False, ) - with pytest.raises(ValueError): - problem_definition.name = None + with pytest.raises(TypeError, match=r"dict\[str, ProblemDefinition\]"): save_to_disk( output_folder=test_dir, sample_constructor=sample_constructor, @@ -287,7 +287,8 @@ def test_hf_datasets( verbose=True, ) - load_problem_definitions_from_disk(test_dir) + loaded_pb_defs = load_problem_definitions_from_disk(test_dir) + assert set(loaded_pb_defs) == {"pb_def"} with pytest.raises(ValueError): load_problem_definitions_from_disk("dummy") diff --git a/tests/storage/test_zarr_init.py b/tests/storage/test_zarr_init.py index ea8c320b..2715e305 100644 --- a/tests/storage/test_zarr_init.py +++ b/tests/storage/test_zarr_init.py @@ -176,7 +176,14 @@ def fake_configure_dataset_card( monkeypatch.setattr(zarr, "configure_dataset_card", fake_configure_dataset_card) - infos = Infos.from_mapping({"legal": {"owner": "owner", "license": "cc-by-4.0"}}) + infos = Infos.model_validate( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + "storage_backend": "zarr", + } + ) result = ZarrBackend.configure_dataset_card("dummy/repo", infos, "/tmp/local") assert result == "configured" diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index 38b23842..64dd0809 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -14,13 +14,16 @@ @pytest.fixture() def problem_definition() -> ProblemDefinition: - return ProblemDefinition() + return ProblemDefinition.model_construct( + input_features=[], + 
output_features=[], + train_split=None, + test_split=None, + ) @pytest.fixture() def problem_definition_full(problem_definition: ProblemDefinition) -> ProblemDefinition: - problem_definition.name = "regression_1" - # ---- feature_identifier = "Global/feature" predict_feature_identifier = "Global/predict_feature" @@ -77,158 +80,136 @@ class Test_ProblemDefinition: def test__init__(self, problem_definition): print(problem_definition) - def test__init__both_path_and_directory_path(self, current_directory): - d_path = current_directory / "problem_definition" - with pytest.raises(ValueError): - ProblemDefinition(path=d_path, directory_path=d_path) + def test_required_fields(self): + with pytest.raises(ValidationError, match="Field required"): + ProblemDefinition() + + def test_feature_lists_must_not_be_empty(self): + base = { + "train_split": {"train": "all"}, + "test_split": {"test": "all"}, + } + with pytest.raises(ValidationError, match="input_features must not be empty"): + ProblemDefinition( + **base, + input_features=[], + output_features=["out"], + ) + with pytest.raises(ValidationError, match="output_features must not be empty"): + ProblemDefinition( + **base, + input_features=["in"], + output_features=[], + ) # -------------------------------------------------------------------------# - def test_from_path_single_definition(self, monkeypatch, tmp_path): - expected = ProblemDefinition( - name="pb_single", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - - def fake_loader(path): - assert path == tmp_path - return {"pb_single": expected} - - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", fake_loader - ) - - loaded = ProblemDefinition.from_path(tmp_path) - assert loaded.name == "pb_single" - assert loaded.input_features == ["in_a"] - assert loaded.output_features == ["out_a"] - assert loaded.get_train_split_name() == "train_0" - assert 
loaded.get_test_split_name() == "test_0" - assert loaded.get_train_split_indices() == [0, 1] - assert loaded.get_test_split_indices() == [2] - - def test_from_path_single_definition_with_override(self, monkeypatch, tmp_path): - expected = ProblemDefinition( - name="pb_single", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, + def test_from_mapping_validates_and_normalizes(self): + loaded = ProblemDefinition.model_validate( + { + "input_features": ["in_b", "in_a"], + "output_features": ["out_b", "out_a"], + "train_split": {"train_0": [0, 1]}, + "test_split": {"test_0": [2]}, + } ) - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"pb_single": expected}, # noqa: ARG005 - ) + assert loaded.input_features == ["in_a", "in_b"] + assert loaded.output_features == ["out_a", "out_b"] + assert loaded.train_split == {"train_0": [0, 1]} + assert loaded.test_split == {"test_0": [2]} - loaded = ProblemDefinition.from_path(tmp_path) - assert loaded.name == "pb_single" - - def test_from_path_named_definition_and_override(self, monkeypatch, tmp_path): - pb_1 = ProblemDefinition( - name="pb_1", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - pb_2 = ProblemDefinition( - name="pb_2", - input_features=["in_b"], - output_features=["out_b"], - train_split={"train_1": [3, 4]}, - test_split={"test_1": [5]}, + def test_from_path_loads_single_yaml_file(self, tmp_path: Path): + file_path = tmp_path / "problem.yaml" + file_path.write_text( + "input_features:\n" + " - in_1\n" + "output_features:\n" + " - out_1\n" + "train_split:\n" + " train: [0]\n" + "test_split:\n" + " test: [1]\n", + encoding="utf-8", ) - def fake_loader(path): - assert path == tmp_path - return {"pb_1": pb_1, "pb_2": pb_2} + loaded = ProblemDefinition.from_path(file_path) - monkeypatch.setattr( - 
"plaid.storage.load_problem_definitions_from_disk", fake_loader - ) + assert loaded.train_split == {"train": [0]} + assert loaded.test_split == {"test": [1]} - loaded = ProblemDefinition.from_path( - tmp_path, - name="pb_2", - ) - assert loaded.name == "pb_2" - - def test_from_path_unknown_name_raises(self, monkeypatch, tmp_path): - pb = ProblemDefinition( - name="existing", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"existing": pb}, # noqa: ARG005 + def test_from_path_adds_yaml_suffix(self, tmp_path: Path): + file_path = tmp_path / "problem.yaml" + file_path.write_text( + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n", + encoding="utf-8", ) - with pytest.raises(ValueError, match="Problem definition 'missing' not found"): - ProblemDefinition.from_path(tmp_path, name="missing") + loaded = ProblemDefinition.from_path(tmp_path / "problem") - def test_from_path_requires_name_when_multiple(self, monkeypatch, tmp_path): - pb_1 = ProblemDefinition( - name="pb_1", - input_features=["in_a"], - output_features=["out_a"], - train_split={"train_0": [0, 1]}, - test_split={"test_0": [2]}, - ) - pb_2 = ProblemDefinition( - name="pb_2", - input_features=["in_b"], - output_features=["out_b"], - train_split={"train_1": [3, 4]}, - test_split={"test_1": [5]}, - ) + assert loaded.input_features == ["in_1"] - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"pb_1": pb_1, "pb_2": pb_2}, # noqa: ARG005 + def test_from_path_rejects_old_name_key(self, tmp_path: Path): + file_path = tmp_path / "problem_with_name.yaml" + file_path.write_text( + "name: pb\n" + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n", + 
encoding="utf-8", ) - with pytest.raises(RuntimeError, match="more than one Problem definition"): - ProblemDefinition.from_path(tmp_path) + with pytest.raises(ValidationError, match="extra_forbidden"): + ProblemDefinition.from_path(file_path) - def test_from_path_error_lists_sorted_available_names(self, monkeypatch, tmp_path): - pb_a = ProblemDefinition( - name="a", input_features=["in"], output_features=["out"] - ) - pb_b = ProblemDefinition( - name="b", input_features=["in"], output_features=["out"] + def test_from_path_unknown_key_raises(self, tmp_path: Path): + file_path = tmp_path / "problem_with_unknown.yaml" + file_path.write_text( + "input_features: [in_1]\n" + "output_features: [out_1]\n" + "train_split:\n" + " train: all\n" + "test_split:\n" + " test: all\n" + "unknown_key: value\n", + encoding="utf-8", ) - monkeypatch.setattr( - "plaid.storage.load_problem_definitions_from_disk", - lambda path: {"b": pb_b, "a": pb_a}, # noqa: ARG005 - ) + with pytest.raises(ValidationError, match="extra_forbidden"): + ProblemDefinition.from_path(file_path) - with pytest.raises(ValueError, match="Available definitions: a, b"): - ProblemDefinition.from_path(tmp_path, name="missing") + def test_from_path_non_existing_file(self): + with pytest.raises(FileNotFoundError): + ProblemDefinition.from_path(Path("non_existing_path")) def test_feature_validators_reject_duplicates(self): with pytest.raises( ValidationError, match="duplicated values in input_features" ): - ProblemDefinition(input_features=["a", "a"]) + ProblemDefinition( + input_features=["a", "a"], + output_features=["out"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) with pytest.raises( ValidationError, match="duplicated values in output_features" ): - ProblemDefinition(output_features=["a", "a"]) - - def test_non_overwritable_attributes_raise(self, problem_definition): - problem_definition.name = "problem_a" - with pytest.raises(AttributeError, match="'name' is already set"): - 
problem_definition.name = "problem_b" + ProblemDefinition( + input_features=["in"], + output_features=["a", "a"], + train_split={"train": "all"}, + test_split={"test": "all"}, + ) def test_split_replacement_logs_warning(self, problem_definition, caplog): problem_definition.train_split = {"train_0": [0, 1]} @@ -237,24 +218,12 @@ def test_split_replacement_logs_warning(self, problem_definition, caplog): assert "already exists -> data will be replaced" in caplog.text - def test_get_split_paths(self, problem_definition): + def test_split_fields_are_plain_dictionaries(self, problem_definition): problem_definition.train_split = {"train_0": [0, 1, 2]} problem_definition.test_split = {"test_0": [3, 4]} - assert problem_definition.get_train_split_name() == "train_0" - assert problem_definition.get_test_split_name() == "test_0" - assert problem_definition.get_train_split_indices() == [0, 1, 2] - assert problem_definition.get_test_split_indices() == [3, 4] - - def test_get_split_paths_raise_when_not_defined(self, problem_definition): - with pytest.raises(ValueError, match="train_split is not defined"): - problem_definition.get_train_split_name() - with pytest.raises(ValueError, match="train_split is not defined"): - problem_definition.get_train_split_indices() - with pytest.raises(ValueError, match="test_split is not defined"): - problem_definition.get_test_split_name() - with pytest.raises(ValueError, match="test_split is not defined"): - problem_definition.get_test_split_indices() + assert problem_definition.train_split == {"train_0": [0, 1, 2]} + assert problem_definition.test_split == {"test_0": [3, 4]} def test_add_feature_identifiers_duplicate_checks(self, problem_definition): problem_definition.add_input_features(["in_1", "in_2"]) @@ -280,37 +249,8 @@ def test_split(self, problem_definition): assert problem_definition.train_split == {"train_0": [0, 1, 2]} assert problem_definition.test_split == {"test-1": [3, 4]} - def test__load_from_file_( - self, 
problem_definition_full: ProblemDefinition, tmp_path: Path - ): - - path = tmp_path / "pb_def" - problem_definition_full.save_to_file(path) - problem = ProblemDefinition() - problem._load_from_file_(path) - assert set(problem.input_features) == set( - [ - "Base_2_2/Zone/PointData/sig12", - "Base_2_2/Zone/PointData/U1", - "Base_2_2/Zone/PointData/U2", - "Global/predict_feature", - "Global/test_feature", - "Global/feature", - ] - ) - assert set(problem.output_features) == set( - [ - "Global/predict_feature", - "Base_2_2/Zone/PointData/sig12", - "Global/feature", - "Base_2_2/Zone/PointData/U2", - "Global/test_feature", - ] - ) - def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): problem = ProblemDefinition( - name="pb", input_features=["in_1"], output_features=["out_1"], train_split={"train": [0]}, @@ -319,36 +259,9 @@ def test_save_and_load_keep_yaml_suffix(self, tmp_path: Path): file_path = tmp_path / "problem.yaml" problem.save_to_file(file_path) - loaded = ProblemDefinition() - loaded._load_from_file_(file_path) - - assert loaded.name == "pb" - assert loaded.get_train_split_name() == "train" - assert loaded.get_test_split_name() == "test" - - def test__load_from_file__unknown_field_warns_and_raises( - self, tmp_path: Path, caplog - ): - file_path = tmp_path / "problem_with_unknown.yaml" - file_path.write_text( - "name: pb\n" - "task: regression\n" - "input_features:\n" - " - in_1\n" - "output_features:\n" - " - out_1\n" - "unknown_key: value\n", - encoding="utf-8", - ) + loaded = ProblemDefinition.from_path(file_path) - problem = ProblemDefinition() - with caplog.at_level("WARNING"): - problem._load_from_file_(file_path) - - assert "Data ignored! 
: unknown_key: value" in caplog.text - - def test__load_from_file__non_existing_file(self): - problem = ProblemDefinition() - non_existing_path = Path("non_existing_path") - with pytest.raises(FileNotFoundError): - problem._load_from_file_(non_existing_path) + saved_text = file_path.read_text(encoding="utf-8") + assert "name:" not in saved_text + assert loaded.train_split == {"train": [0]} + assert loaded.test_split == {"test": [1]} diff --git a/tests/utils/test_info.py b/tests/utils/test_info.py index 3dea8f13..9b2273aa 100644 --- a/tests/utils/test_info.py +++ b/tests/utils/test_info.py @@ -1,16 +1,31 @@ import pytest from pydantic import ValidationError -from plaid.infos import Infos, Legal +from plaid.infos import Infos -def test_verify_info_accepts_special_internal_keys(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, +def _valid_infos(**overrides): + data = { + "owner": "owner", + "license": "cc-by-4.0", "num_samples": {"train": 10}, "storage_backend": "zarr", } - Infos(**infos) + data.update(overrides) + return data + + +def test_verify_info_accepts_flat_owner_license_keys(): + Infos(**_valid_infos()) + + +def test_infos_allows_draft_without_storage_derived_fields(): + model = Infos(owner="owner", license="cc-by-4.0") + + assert model.owner == "owner" + assert model.license == "cc-by-4.0" + assert model.num_samples == {} + assert model.storage_backend is None def test_verify_info_rejects_unknown_category(): @@ -18,26 +33,37 @@ def test_verify_info_rejects_unknown_category(): Infos(**{"unknown": {"x": "y"}}) -def test_verify_info_rejects_unknown_key(): - with pytest.raises(ValidationError): - Infos(**{"legal": {"unknown_key": "v"}}) +def test_verify_info_rejects_legacy_legal_key(): + with pytest.raises(ValidationError, match="extra_forbidden"): + Infos(**{"legal": {"owner": "owner", "license": "cc-by-4.0"}}) def test_validate_required_only_missing_required_key(): with pytest.raises(ValueError): - Infos(**{"legal": {"owner": "someone"}}) + 
Infos(**{"owner": "someone"}) -def test_normalize_infos_strips_legacy_plaid_section_and_copies(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "plaid": {"version": "x"}, - } - normalized = Infos.normalize_mapping(infos) +def test_validate_required_only_requires_persisted_fields(): + with pytest.raises(ValueError, match="num_samples"): + Infos.validate_required_only({"owner": "owner", "license": "cc-by-4.0"}) - # The legacy ``plaid`` section is dropped from the validated payload. - assert "plaid" not in normalized - # And the input mapping is not mutated. + with pytest.raises(ValueError, match="storage_backend"): + Infos.validate_required_only( + { + "owner": "owner", + "license": "cc-by-4.0", + "num_samples": {}, + } + ) + + +def test_normalize_infos_rejects_legacy_plaid_section_and_copies(): + infos = {"owner": "owner", "license": "cc-by-4.0", "plaid": {"version": "x"}} + + with pytest.raises(ValidationError, match="extra_forbidden"): + Infos.normalize_mapping(infos) + + # The input mapping is not mutated before validation raises. 
assert "plaid" in infos @@ -45,82 +71,83 @@ def test_model_validate_rejects_plaid_section(): with pytest.raises(ValidationError): Infos.model_validate( { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, + "owner": "owner", + "license": "cc-by-4.0", "plaid": {"version": "x"}, } ) def test_dataset_info_model_validate_success(): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "num_samples": {"train": 10}, - "storage_backend": "zarr", - } + model = Infos.model_validate(_valid_infos()) - model = Infos.model_validate(infos) - - assert model.legal.owner == "owner" + assert model.owner == "owner" + assert model.license == "cc-by-4.0" assert model.storage_backend == "zarr" def test_dataset_info_model_validate_rejects_extra_top_level_key(): with pytest.raises(ValueError): - Infos.model_validate( - { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "unknown": {}, - } - ) + Infos.model_validate({"owner": "owner", "license": "cc-by-4.0", "unknown": {}}) def test_infos_save_and_load_roundtrip(tmp_path): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - "num_samples": {"train": 10}, - "storage_backend": "zarr", - } - model = Infos.from_mapping(infos) + model = Infos.model_validate(_valid_infos()) target = tmp_path / "infos.yaml" model.save_to_file(target) assert target.is_file() reloaded = Infos.from_path(target) - assert reloaded.legal.owner == "owner" + assert reloaded.owner == "owner" + assert reloaded.license == "cc-by-4.0" assert reloaded.storage_backend == "zarr" assert reloaded.num_samples == {"train": 10} def test_infos_from_path_directory(tmp_path): - infos = { - "legal": {"owner": "owner", "license": "cc-by-4.0"}, - } - Infos.from_mapping(infos).save_to_file(tmp_path / "infos.yaml") + Infos.model_validate(_valid_infos()).save_to_file(tmp_path / "infos.yaml") reloaded = Infos.from_path(tmp_path) - assert reloaded.legal.license == "cc-by-4.0" + assert reloaded.license == "cc-by-4.0" + + +def 
test_infos_from_path_requires_persisted_fields_by_default(tmp_path): + Infos(owner="o", license="l").save_to_file(tmp_path) + + with pytest.raises(ValueError, match="num_samples"): + Infos.from_path(tmp_path) -def test_validate_authorized_only_allows_missing_legal(): +def test_infos_from_path_can_load_draft_infos(tmp_path): + Infos(owner="o", license="l").save_to_file(tmp_path) + + reloaded = Infos.from_path(tmp_path, require_persisted=False) + + assert reloaded.owner == "o" + assert reloaded.license == "l" + assert reloaded.num_samples == {} + assert reloaded.storage_backend is None + + +def test_validate_authorized_only_allows_missing_owner_license(): model = Infos.validate_authorized_only( {"num_samples": {"train": 1}, "storage_backend": "zarr"} ) - # Missing legal must be filled with empty placeholder values. - assert model.legal.owner == "" - assert model.legal.license == "" + # Missing user-authored required fields are filled with empty placeholder values. + assert model.owner == "" + assert model.license == "" assert model.storage_backend == "zarr" -def test_validate_authorized_only_with_legal_present(): - model = Infos.validate_authorized_only({"legal": {"owner": "o", "license": "l"}}) - assert model.legal.owner == "o" +def test_validate_authorized_only_with_owner_license_present(): + model = Infos.validate_authorized_only({"owner": "o", "license": "l"}) + assert model.owner == "o" + assert model.license == "l" def test_validate_authorized_only_rejects_unauthorized_key(): with pytest.raises(KeyError): - Infos.validate_authorized_only( - {"legal": {"owner": "o", "license": "l"}, "unknown": {}} - ) + Infos.validate_authorized_only({"owner": "o", "license": "l", "unknown": {}}) def test_validate_authorized_only_reraises_other_validation_errors(): @@ -129,84 +156,77 @@ def test_validate_authorized_only_reraises_other_validation_errors(): # must be re-raised as ValidationError. 
with pytest.raises(ValidationError): Infos.validate_authorized_only( - {"legal": {"owner": "o", "license": "l"}, "num_samples": "nope"} + {"owner": "o", "license": "l", "num_samples": "nope"} ) def test_validate_required_only_accepts_valid_mapping(): - Infos.validate_required_only({"legal": {"owner": "o", "license": "l"}}) + Infos.validate_required_only( + { + "owner": "o", + "license": "l", + "num_samples": {"train": 1}, + "storage_backend": "zarr", + } + ) -def test_validate_required_only_missing_legal(): +def test_validate_required_only_missing_owner_license(): with pytest.raises(ValidationError): Infos.validate_required_only({}) -def test_to_dict_returns_plain_mapping(): - model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} +def test_model_dump_returns_plain_mapping(): + model = Infos.model_validate( + { + "owner": "o", + "license": "l", + "num_samples": {}, + "storage_backend": "zarr", + } ) - d = model.to_dict() - assert d["legal"] == {"owner": "o", "license": "l"} + d = model.model_dump(exclude_none=True) + assert d["owner"] == "o" + assert d["license"] == "l" assert d["storage_backend"] == "zarr" -def test_getitem_returns_plain_value_and_unwraps_nested_dataclasses(): - model = Infos.from_mapping( +def test_attribute_access_returns_typed_values(): + model = Infos.model_validate( { - "legal": {"owner": "o", "license": "l"}, + "owner": "o", + "license": "l", + "num_samples": {}, "storage_backend": "zarr", } ) - # Plain field. - assert model["storage_backend"] == "zarr" - # Nested dataclass is returned as a dict, with None entries dropped. 
- legal_dict = model["legal"] - assert legal_dict == {"owner": "o", "license": "l"} - - -def test_getitem_raises_key_error_for_unknown_field(): - model = Infos.from_mapping({"legal": {"owner": "o", "license": "l"}}) - with pytest.raises(KeyError): - model["does_not_exist"] - - -def test_contains_handles_strings_and_non_strings(): - model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} - ) - assert "legal" in model - assert "storage_backend" in model - # Field exists but is None -> not "in" the model. - assert "data_production" not in model - # Unknown attribute name. - assert "nope" not in model - # Non-string keys are never in the model. - assert (123 in model) is False - - -def test_get_returns_default_when_missing(): - model = Infos.from_mapping({"legal": {"owner": "o", "license": "l"}}) - assert model.get("storage_backend", "fallback") == "fallback" - assert model.get("legal")["owner"] == "o" + assert model.storage_backend == "zarr" + assert model.owner == "o" + assert model.license == "l" def test_save_to_file_treats_suffixless_path_as_directory(tmp_path): target = tmp_path / "myinfos" - Infos(legal=Legal(owner="o", license="l")).save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + target + ) # Suffix-less, non-existing paths are treated as directories that # will hold an ``infos.yaml``. 
assert (target / "infos.yaml").is_file() def test_save_to_file_into_existing_directory(tmp_path): - Infos(legal=Legal(owner="o", license="l")).save_to_file(tmp_path) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + tmp_path + ) assert (tmp_path / "infos.yaml").is_file() def test_save_to_file_replaces_non_yaml_suffix(tmp_path): target = tmp_path / "weird.txt" - Infos(legal=Legal(owner="o", license="l")).save_to_file(target) + Infos(owner="o", license="l", num_samples={}, storage_backend="zarr").save_to_file( + target + ) # ``.txt`` suffix is replaced by ``.yaml``. assert (tmp_path / "weird.yaml").is_file() assert not target.exists() @@ -222,14 +242,20 @@ def test_save_to_file_preserves_unknown_future_keys(tmp_path, monkeypatch): from plaid import infos as infos_mod monkeypatch.setattr(infos_mod, "_KEY_ORDER", ()) - model = Infos.from_mapping( - {"legal": {"owner": "o", "license": "l"}, "storage_backend": "zarr"} + model = Infos.model_validate( + { + "owner": "o", + "license": "l", + "num_samples": {}, + "storage_backend": "zarr", + } ) target = tmp_path / "out.yaml" model.save_to_file(target) text = target.read_text(encoding="utf-8") assert "storage_backend" in text - assert "legal" in text + assert "owner" in text + assert "license" in text def test_from_path_raises_file_not_found(tmp_path):