Skip to content

Commit 8a3ddf5

Browse files
committed
🚀 simplify Sample and move to pydantic v2 conventions
1 parent 0c17f81 commit 8a3ddf5

11 files changed

Lines changed: 40 additions & 62 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818

1919
### Changed
2020

21+
- (sample) **Breaking**: remove the `path` constructor argument from `Sample`; use `Sample.load_from_dir(path)` or `sample.load(path)` for explicit disk loading.
2122
- (infos) add an infos.py file with a pydantic Info class at the same level as ProblemDefinition.
2223
- (cgns_backend) remove constant/variable mechanism in reading and writing paths for the cgns_backend.
2324
- (docs) update logos and migrate to Zensical.

docs/source/concepts/sample.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ sample.save_to_dir("sample_000000000", overwrite=True)
9393
sample = Sample.load_from_dir("sample_000000000")
9494
```
9595

96+
`Sample()` creates an empty in-memory sample. Loading from disk is explicit:
97+
use `Sample.load_from_dir(path)` to create a loaded sample, or instantiate an
98+
empty sample and call `sample.load(path)`.
99+
96100
The CGNS files can be opened in ParaView for advanced visualization and post-processing.
97101

98102
For large samples, `save_to_dir(..., memory_safe=True)` writes CGNS files through

docs/source/upgrade_guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ the [reader API](api/storage/reader.md), and the [backend API](api/storage/backe
165165
| `Dataset()` + `add_sample` / `add_samples` / `from_list_of_samples` | `save_to_disk(sample_constructor=..., ids=...)` |
166166
| `Dataset.save_to_dir(path)` / `add_to_dir` | `save_to_disk(path, sample_constructor=..., ids=...)` |
167167
| `Dataset.load(path)` | `init_from_disk(path)` → `(datasetdict, converterdict)` |
168+
| `Sample(path=sample_dir)` | `Sample.load_from_dir(sample_dir)` or `sample = Sample(); sample.load(sample_dir)` |
168169
| `dataset[i]` / `get_samples()` | `converter.to_plaid(dataset, i)` |
169170
| `len(dataset)` / `get_number_of_samples()` | `len(dataset)` (per-split backend object) |
170171
| `dataset.set_infos(...)` / `get_infos()` | pass [`Infos`](concepts/infos.md) to `save_to_disk(infos=...)`; read back with `Infos.from_path(path)` |

examples/containers/sample_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,10 @@
559559
sample.save_to_dir(sample_save_fname)
560560

561561
# %% [markdown]
562-
# #### Load a Sample from a directory via initialization
562+
# #### Load a Sample from a directory via the Sample class
563563

564564
# %%
565-
new_sample = Sample(path=sample_save_fname)
565+
new_sample = Sample.load_from_dir(sample_save_fname)
566566

567567
print(sample.summarize())
568568

src/plaid/containers/managers/default_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def set_default_time(self, time: float) -> None:
8383
.. code-block:: python
8484
8585
from plaid import Sample
86-
sample = Sample("path_to_plaid_sample")
86+
sample = Sample.load_from_dir("path_to_plaid_sample")
8787
print(sample)
8888
>>> Sample(2 scalars, 1 timestamp, 5 fields)
8989
print(sample.show_tree(0.5))
@@ -122,7 +122,7 @@ def set_default_base(self, base: str, time: Optional[float] = None) -> None:
122122
.. code-block:: python
123123
124124
from plaid import Sample
125-
sample = Sample("path_to_plaid_sample")
125+
sample = Sample.load_from_dir("path_to_plaid_sample")
126126
print(sample)
127127
>>> Sample(2 scalars, 1 timestamp, 5 fields)
128128
print(sample.get_physical_dim("BaseA", 0.5))
@@ -173,7 +173,7 @@ def set_default_zone_base(
173173
.. code-block:: python
174174
175175
from plaid import Sample
176-
sample = Sample("path_to_plaid_sample")
176+
sample = Sample.load_from_dir("path_to_plaid_sample")
177177
print(sample)
178178
>>> Sample(2 scalars, 1 timestamp, 5 fields)
179179
print(sample.get_zone_type("ZoneX", "BaseA", 0.5))

src/plaid/containers/sample.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
import subprocess
2121
from copy import deepcopy
2222
from pathlib import Path
23-
from typing import Any, Optional, Union
23+
from typing import Optional, Union
2424

2525
import CGNS.MAP as CGM
2626
import CGNS.PAT.cgnskeywords as CGK
2727
import CGNS.PAT.cgnslib as CGL
2828
import CGNS.PAT.cgnsutils as CGU
2929
import numpy as np
3030
from CGNS.PAT.cgnsutils import __CHILDREN__, __NAME__
31-
from pydantic import BaseModel, ConfigDict, model_validator
31+
from pydantic import BaseModel, model_validator
3232
from pydantic import Field as PydanticField
3333

3434
from ..constants import (
@@ -48,35 +48,24 @@
4848
CGNS_WORKER = Path(__file__).parent.parent / "utils" / "cgns_worker.py"
4949

5050

51-
class Sample(BaseModel):
51+
class Sample(BaseModel, arbitrary_types_allowed=True, extra="forbid"):
5252
"""Represents a single sample. It contains data and information related to a single observation or measurement within a dataset.
5353
54-
By default, the sample is empty but:
55-
- You can provide a path to a folder containing the sample data, and it will be loaded during initialization.
54+
By default, the sample is empty. Use :meth:`Sample.load_from_dir` or
55+
:meth:`Sample.load` to load sample data from disk.
5656
5757
Note:
5858
Mesh/field/global operations are directly implemented on ``Sample`` via
5959
inheritance from internal feature operations.
6060
"""
6161

62-
# Pydantic configuration
63-
# TODO(FB) check why arbitrary_types_allowed is needed, and if it can be removed
64-
model_config = ConfigDict(
65-
arbitrary_types_allowed=True, revalidate_instances="always", extra="forbid"
66-
)
67-
6862
# Attributes
69-
path: Optional[Union[str, Path]] = PydanticField(
70-
None,
71-
description="Path to the folder containing the sample data. If provided, the sample will be loaded from this path during initialization. Defaults to None.",
72-
)
73-
7463
data: dict[float, CGNSTree] = PydanticField(
7564
default_factory=dict,
7665
description="A dictionary mapping time steps to CGNS trees.",
7766
)
7867

79-
defaults: DefaultManager = PydanticField(
68+
defaults: Optional[DefaultManager] = PydanticField(
8069
default=None,
8170
exclude=True,
8271
repr=False,
@@ -90,13 +79,6 @@ def initialize_defaults(self) -> "Sample":
9079
self.defaults = DefaultManager(self)
9180
return self
9281

93-
def model_post_init(self, _context: Any) -> None:
94-
"""Run post-initialization hooks (e.g. load sample from path)."""
95-
# Load if path is provided
96-
if self.path is not None:
97-
path = Path(self.path)
98-
self.load(path)
99-
10082
def copy(self) -> Self: # pyright: ignore[reportIncompatibleMethodOverride]
10183
"""Create a deep copy of the current `Sample` instance.
10284

src/plaid/infos.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,11 @@
77
from typing import Any, Union
88

99
import yaml
10-
from pydantic import BaseModel, ConfigDict, Field, ValidationError
11-
from pydantic.dataclasses import dataclass
10+
from pydantic import BaseModel, Field, ValidationError
1211

1312
logger = logging.getLogger(__name__)
1413

15-
_PD_CONFIG = ConfigDict(
16-
extra="forbid", str_strip_whitespace=True, validate_assignment=True
17-
)
18-
19-
20-
@dataclass(config=_PD_CONFIG)
21-
class DataProduction:
14+
class DataProduction(BaseModel, revalidate_instances="always", str_strip_whitespace=True, validate_assignment=True, extra="forbid"):
2215
"""Dataset production context metadata."""
2316

2417
owner: str | None = None
@@ -43,10 +36,10 @@ class DataProduction:
4336
)
4437

4538

46-
class Infos(BaseModel):
39+
class Infos(BaseModel, revalidate_instances="always", str_strip_whitespace=True, validate_assignment=True, extra="forbid"):
4740
"""Structured representation of a PLAID dataset ``infos`` payload."""
4841

49-
model_config = _PD_CONFIG
42+
# Model configuration is supplied through the class keyword arguments; the former shared _PD_CONFIG constant was removed.
5043

5144
owner: str
5245
license: str

src/plaid/problem_definition.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Any, Literal, Sequence, Union
1616

1717
import yaml
18-
from pydantic import BaseModel, ConfigDict, field_validator
18+
from pydantic import BaseModel, field_validator
1919

2020
# %% Globals
2121

@@ -26,12 +26,12 @@ def _normalize_list(v):
2626
return sorted(map(str, v))
2727

2828

29-
class ProblemDefinition(BaseModel):
29+
class ProblemDefinition(BaseModel, revalidate_instances="always", str_strip_whitespace=True, validate_assignment=True, extra="forbid"):
3030
"""Defines the input and output features for a machine learning problem."""
3131

32-
model_config = ConfigDict(
33-
revalidate_instances="always", validate_assignment=True, extra="forbid"
34-
)
32+
# model_config = ConfigDict(
33+
# revalidate_instances="always", validate_assignment=True, extra="forbid"
34+
# )
3535

3636
input_features: list[str]
3737
output_features: list[str]

src/plaid/storage/cgns/reader.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __getitem__(self, idx: int) -> Sample:
7575
Sample: A PLAID Sample object.
7676
"""
7777
assert idx in self.ids
78-
return Sample(path=self.path / f"sample_{idx:09d}")
78+
return Sample.load_from_dir(self.path / f"sample_{idx:09d}")
7979

8080
def __len__(self) -> int:
8181
"""Get the number of samples in the dataset.
@@ -122,12 +122,9 @@ def sample_generator(
122122
allow_patterns=[f"data/{split}/sample_{idx:09d}/"],
123123
local_dir=temp_folder,
124124
)
125-
sample = Sample(
126-
path=Path(temp_folder) / "data" / f"{split}" / f"sample_{idx:09d}"
125+
sample = Sample.load_from_dir(
126+
Path(temp_folder) / "data" / f"{split}" / f"sample_{idx:09d}"
127127
)
128-
# Sample data are eagerly loaded in memory during initialization;
129-
# clear the transient on-disk path before leaving the temp dir.
130-
sample.path = None
131128
yield sample
132129

133130

src/plaid/storage/common/bridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def to_plaid_sample(
208208
for time, flat_tree in sample_dict.items():
209209
sample_data[time] = unflatten_cgns_tree(flat_tree, cgns_types)
210210

211-
return Sample(path=None, data=sample_data)
211+
return Sample(data=sample_data)
212212

213213

214214
def plaid_to_sample_dict(

0 commit comments

Comments
 (0)