Skip to content

Commit 8a254db

Browse files
♻️ simplify Infos class (#444)
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
1 parent 0bda35e commit 8a254db

8 files changed

Lines changed: 197 additions & 117 deletions

File tree

docs/source/concepts/infos.md

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,24 @@ In the current API, infos stores:
2121
## Basic usage
2222

2323
```python
24-
from plaid.infos import DataProduction, Infos
24+
from plaid import Infos
2525

2626
infos = Infos(
2727
owner="Safran",
2828
license="proprietary",
29-
data_production=DataProduction(
30-
type="simulation",
31-
physics="fluid dynamics",
32-
simulator="ExampleSolver",
33-
),
29+
data_production={
30+
"type": "simulation",
31+
"physics": "fluid dynamics",
32+
"simulator": "ExampleSolver",
33+
},
3434
data_description="ExampleDescription",
3535
)
3636
```
3737

38-
Infos can also be built from a plain mapping, for instance after reading YAML:
38+
To inspect the public constructor fields accepted by `Infos`, use:
3939

4040
```python
41-
infos = Infos.model_validate(
42-
{
43-
"owner": "Safran",
44-
"license": "proprietary",
45-
}
46-
)
41+
Infos.print_available_fields()
4742
```
4843

4944
`num_samples` and `storage_backend` are derived from the chosen storage backend
@@ -60,13 +55,7 @@ infos = Infos.from_path("/path/to/plaid_dataset")
6055
```
6156

6257
When a directory is provided, `Infos.from_path(...)` looks for `infos.yaml`
63-
inside that directory. By default, loading from disk requires the persisted
64-
storage metadata (`num_samples` and `storage_backend`) to be present. To load a
65-
draft infos file that has not been produced by `save_to_disk(...)`, use:
66-
67-
```python
68-
infos = Infos.from_path("/path/to/draft/infos.yaml", require_persisted=False)
69-
```
58+
inside that directory.
7059

7160
## Saving
7261

@@ -77,7 +66,10 @@ infos.save_to_file("/path/to/plaid_dataset/infos.yaml")
7766
```
7867

7968
If a directory path is provided, the file is saved as `infos.yaml` inside that
80-
directory.
69+
directory. Direct YAML writing requires complete persisted metadata: `owner`,
70+
`license`, `num_samples`, and `storage_backend`. When using
71+
`save_to_disk(..., infos=...)`, PLAID fills `num_samples` and `storage_backend`
72+
automatically before writing `infos.yaml`.
8173

8274
## Typed access and serialization
8375

docs/source/tutorials/storage.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ infos = Infos(
9696
owner="NeuralOperator (https://zenodo.org/records/13993629)",
9797
license="cc-by-4.0",
9898
data_description="No changes to data content from original dataset",
99+
type="simulation",
100+
physics="phase-field fracture models for brittle fracture",
101+
script="Subset 'res-SENS' of the initial dataset, 1/5th time steps, converted to PLAID format for standardized access; no changes to data content."
99102
)
100103

101104

examples/infos_example.py

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
# %%
3838
# Import necessary libraries and classes
39-
from plaid.infos import DataProduction, Infos
39+
from plaid.infos import Infos
4040

4141
# %% [markdown]
4242
# ## Section 1: Initializing Infos
@@ -51,51 +51,25 @@
5151
infos = Infos(
5252
owner="PLAID",
5353
license="MIT",
54+
data_production={
55+
"type": "simulation",
56+
"physics": "fluid dynamics",
57+
"simulator": "ExampleSolver",
58+
},
59+
data_description="ExampleDescription",
5460
)
5561
print(f"{infos = }")
5662

5763
# %% [markdown]
58-
# ### Initialize Infos from a plain mapping
64+
# ### Print available Infos fields
5965

6066
# %%
61-
infos_from_mapping = Infos.model_validate(
62-
{
63-
"owner": "PLAID",
64-
"license": "MIT",
65-
"data_description": "Example metadata for a PLAID dataset.",
66-
}
67-
)
68-
print(f"{infos_from_mapping = }")
67+
Infos.print_available_fields()
6968

7069
# %% [markdown]
71-
# ## Section 2: Configuring Infos and retrieve data
70+
# ## Section 2: Modifying Infos and retrieving data
7271
#
73-
# This section demonstrates how to handle and configure Infos objects and access
74-
# metadata.
75-
76-
# %% [markdown]
77-
# ### Set owner and license metadata
78-
79-
# %%
80-
infos.owner = "Safran"
81-
infos.license = "proprietary"
82-
print(f"{infos.owner = }")
83-
print(f"{infos.license = }")
84-
85-
# %% [markdown]
86-
# ### Set data production metadata
87-
88-
# %%
89-
infos.data_production = DataProduction(
90-
type="simulation",
91-
physics="fluid dynamics",
92-
simulator="ExampleSolver",
93-
hardware="ExampleCluster",
94-
computation_duration="1 hour",
95-
script="run_simulation.py",
96-
contact="contact@example.com",
97-
)
98-
print(f"{infos.data_production = }")
72+
# This section demonstrates how to handle Infos objects and access metadata.
9973

10074
# %% [markdown]
10175
# ### Set data description
@@ -119,8 +93,7 @@
11993
# %% [markdown]
12094
# ## Section 3: Saving and Loading Infos
12195
#
122-
# This section demonstrates how to save and load Infos from a directory or YAML
123-
# file.
96+
# This section demonstrates how to save and load Infos from a YAML file.
12497

12598
# %% [markdown]
12699
# ### Save Infos to a YAML file
@@ -131,20 +104,22 @@
131104
)
132105
infos_save_fname = test_pth / "infos.yaml"
133106
test_pth.mkdir(parents=True, exist_ok=True)
134-
print(f"saving path: {infos_save_fname}")
135107

108+
print(f"saving path: {infos_save_fname}")
109+
infos.num_samples = {"train": 0}
110+
infos.storage_backend = "zarr"
136111
infos.save_to_file(infos_save_fname)
137112

138113
# %% [markdown]
139114
# ### Load Infos from a YAML file
140115

141116
# %%
142-
loaded_infos = Infos.from_path(infos_save_fname, require_persisted=False)
117+
loaded_infos = Infos.from_path(infos_save_fname)
143118
print(loaded_infos)
144119

145120
# %% [markdown]
146-
# ### Load Infos from a directory containing infos.yaml
121+
# ### Load Infos from an explicit infos.yaml path
147122

148123
# %%
149-
loaded_infos_from_dir = Infos.from_path(test_pth, require_persisted=False)
150-
print(loaded_infos_from_dir)
124+
loaded_infos_from_explicit_path = Infos.from_path(test_pth / "infos.yaml")
125+
print(loaded_infos_from_explicit_path)

src/plaid/cli/plaidcheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def check_dataset(
435435
)
436436
return report
437437
try:
438-
infos = load_infos_from_disk(path)
438+
infos = load_infos_from_disk(path / "infos.yaml")
439439
except Exception as exc:
440440
report.add("error", "INFOS_READ_ERROR", "infos.yaml", str(exc))
441441
return report

src/plaid/infos.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ class DataProduction(
2121
):
2222
"""Dataset production context metadata."""
2323

24-
owner: str | None = None
25-
license: str | None = None
2624
type: str | None = None
2725
physics: str | None = None
2826
simulator: str | None = None
@@ -52,15 +50,26 @@ class Infos(
5250
):
5351
"""Structured representation of a PLAID dataset ``infos`` payload."""
5452

55-
# model_config = _PD_CONFIG
56-
5753
owner: str
5854
license: str
5955
data_production: DataProduction | None = None
6056
data_description: str | None = None
6157
num_samples: dict[str, int] = Field(default_factory=dict)
6258
storage_backend: str | None = None
6359

60+
@classmethod
61+
def print_available_fields(cls) -> None:
62+
"""Print the public constructor fields accepted by :class:`Infos`."""
63+
print("Infos fields:")
64+
for field_name in cls.model_fields:
65+
print(f" - {field_name}")
66+
if field_name in {"num_samples", "storage_backend"}:
67+
print(" note: automatically filled when calling save_to_disk")
68+
if field_name == "data_production":
69+
print(" subfields:")
70+
for subfield_name in DataProduction.model_fields:
71+
print(f" - {subfield_name}")
72+
6473
def require_persisted(self) -> "Infos":
6574
"""Validate fields that must exist in persisted dataset infos.
6675
@@ -142,8 +151,8 @@ def from_path(
142151
"""Load and validate an :class:`Infos` from a YAML file.
143152
144153
Args:
145-
path: Path to the YAML file (typically ``infos.yaml``) or to a
146-
directory containing it.
154+
path: Path to the YAML file (typically ``infos.yaml``). Any suffix
155+
other than ``.yaml`` (including none) is replaced with ``.yaml``.
147156
require_persisted: When True, require storage-derived metadata
148157
fields expected in a complete on-disk dataset.
149158
@@ -152,10 +161,15 @@ def from_path(
152161
153162
Raises:
154163
FileNotFoundError: If the resolved YAML file does not exist.
164+
IsADirectoryError: If ``path`` points to a directory.
155165
"""
156166
path = Path(path)
157167
if path.is_dir():
158-
path = path / "infos.yaml"
168+
raise IsADirectoryError(
169+
f'Expected a YAML file path, got directory "{path}"'
170+
)
171+
if path.suffix != ".yaml":
172+
path = path.with_suffix(".yaml")
159173
if not path.exists():
160174
raise FileNotFoundError(f'File "{path}" does not exist. Abort')
161175

@@ -171,22 +185,28 @@ def save_to_file(self, path: Union[str, Path]) -> None:
171185
"""Save infos to ``path`` as a YAML file.
172186
173187
Args:
174-
path: File path (or directory) where the YAML will be written. If
175-
``path`` is a directory it will be extended with ``infos.yaml``.
188+
path: File path where the YAML will be written. Any suffix other
189+
than ``.yaml`` (including none) is replaced with ``.yaml``.
190+
191+
Raises:
192+
IsADirectoryError: If ``path`` points to a directory.
176193
"""
194+
self.require_persisted()
195+
177196
path = Path(path)
178-
if path.suffix == "" and not path.exists():
179-
# Treat suffix-less paths as directories.
180-
path = path / "infos.yaml"
181-
elif path.is_dir():
182-
path = path / "infos.yaml"
197+
if path.is_dir():
198+
raise IsADirectoryError(
199+
f'Expected a YAML file path, got directory "{path}"'
200+
)
201+
183202
if path.suffix != ".yaml":
184203
path = path.with_suffix(".yaml")
185204

186205
path.parent.mkdir(parents=True, exist_ok=True)
187206

188207
data = self.model_dump(exclude_none=True, exclude_unset=True)
189208
ordered_data = {key: data[key] for key in _KEY_ORDER if key in data}
209+
190210
# Preserve any future fields.
191211
for key, value in data.items():
192212
if key not in ordered_data:

src/plaid/problem_definition.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121

2222
logger = logging.getLogger(__name__)
2323

24+
_KEY_ORDER = [
25+
"input_features",
26+
"output_features",
27+
"train_split",
28+
"test_split",
29+
]
30+
2431

2532
def _normalize_list(v):
2633
return sorted(map(str, v))
@@ -35,10 +42,6 @@ class ProblemDefinition(
3542
):
3643
"""Defines the input and output features for a machine learning problem."""
3744

38-
# model_config = ConfigDict(
39-
# revalidate_instances="always", validate_assignment=True, extra="forbid"
40-
# )
41-
4245
input_features: list[str]
4346
output_features: list[str]
4447
train_split: dict[str, Sequence[int] | Literal["all"]]
@@ -57,8 +60,13 @@ def from_path(cls, path: str | Path) -> "ProblemDefinition":
5760
5861
Raises:
5962
FileNotFoundError: If the resolved YAML file does not exist.
63+
IsADirectoryError: If ``path`` points to a directory.
6064
"""
6165
path = Path(path)
66+
if path.is_dir():
67+
raise IsADirectoryError(
68+
f'Expected a YAML file path, got directory "{path}"'
69+
)
6270
if path.suffix != ".yaml":
6371
path = path.with_suffix(".yaml")
6472
if not path.exists():
@@ -203,21 +211,18 @@ def save_to_file(self, path: Union[str, Path]) -> None:
203211
problem.save_to_file("/path/to/save_file")
204212
"""
205213
path = Path(path)
206-
path.parent.mkdir(parents=True, exist_ok=True)
214+
if path.is_dir():
215+
raise IsADirectoryError(
216+
f'Expected a YAML file path, got directory "{path}"'
217+
)
207218

208219
if path.suffix != ".yaml":
209220
path = path.with_suffix(".yaml")
210221

211-
data = self.model_dump()
212-
213-
key_order = [
214-
"input_features",
215-
"output_features",
216-
"train_split",
217-
"test_split",
218-
]
222+
path.parent.mkdir(parents=True, exist_ok=True)
219223

220-
ordered_data = {key: data[key] for key in key_order if key in data}
224+
data = self.model_dump()
225+
ordered_data = {key: data[key] for key in _KEY_ORDER if key in data}
221226

222227
# Save problem definition
223228
with path.open("w") as file:

0 commit comments

Comments
 (0)