Skip to content

Commit 73ed25f

Browse files
committed
style(cli): align dataset builder with plaid conventions
1 parent 7efc4b6 commit 73ed25f

2 files changed

Lines changed: 89 additions & 14 deletions

File tree

src/plaid/cli/build_dataset.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import csv
77
import shutil
88
from pathlib import Path
9-
from typing import Optional
109

1110
import numpy as np
1211

@@ -112,7 +111,7 @@ def _validate_raw_layout(
112111
input_scalars_dir_name: str,
113112
output_scalars_dir_name: str,
114113
file_glob: str,
115-
field_dirs: Optional[list[str]],
114+
field_dirs: list[str] | None,
116115
) -> tuple[FileMap, FileMap, FieldMaps]:
117116
"""Validate raw layout and return discovered file maps."""
118117
in_scalars_dir = input_dir / input_scalars_dir_name
@@ -191,11 +190,23 @@ def _initialize_sample_geometry(
191190
)
192191
raise ValueError(msg)
193192

194-
first_field = next(iter(fields.values()))
193+
first_field_name, first_field = next(iter(fields.items()))
195194
n_nodes = first_field.shape[0] if first_field.ndim > 1 else first_field.size
196195
if n_nodes <= 0:
197196
raise ValueError("Field size must be positive.")
198197

198+
for field_name, field_array in fields.items():
199+
local_n_nodes = (
200+
field_array.shape[0] if field_array.ndim > 1 else field_array.size
201+
)
202+
if local_n_nodes != n_nodes:
203+
msg = (
204+
"All fields must share the same number of support points. "
205+
f"Reference field '{first_field_name}' has {n_nodes}, "
206+
f"field '{field_name}' has {local_n_nodes}."
207+
)
208+
raise ValueError(msg)
209+
199210
if sample.features is None:
200211
raise ValueError("Sample features are not initialized.")
201212
features = sample.features
@@ -220,7 +231,7 @@ def build_dataset_from_raw(
220231
output_dir: Path,
221232
input_scalars_dir_name: str = "input_scalars",
222233
output_scalars_dir_name: str = "output_scalars",
223-
field_dirs: Optional[list[str]] = None,
234+
field_dirs: list[str] | None = None,
224235
field_location: str = "Vertex",
225236
base_name: str = "Base_1_1",
226237
zone_name: str = "Zone",
@@ -312,11 +323,33 @@ def build_parser() -> argparse.ArgumentParser:
312323
"in a single command."
313324
)
314325
)
315-
parser.add_argument("--input-dir", type=Path, required=True)
316-
parser.add_argument("--output-dir", type=Path, required=True)
317-
parser.add_argument("--overwrite", action="store_true")
318-
parser.add_argument("--input-scalars-dir", default="input_scalars")
319-
parser.add_argument("--output-scalars-dir", default="output_scalars")
326+
parser.add_argument(
327+
"--input-dir",
328+
type=Path,
329+
required=True,
330+
help="Path to the raw input directory.",
331+
)
332+
parser.add_argument(
333+
"--output-dir",
334+
type=Path,
335+
required=True,
336+
help="Path where the PLAID dataset will be written.",
337+
)
338+
parser.add_argument(
339+
"--overwrite",
340+
action="store_true",
341+
help="Overwrite output directory if it already exists.",
342+
)
343+
parser.add_argument(
344+
"--input-scalars-dir",
345+
default="input_scalars",
346+
help="Subdirectory name that contains input scalar CSV files.",
347+
)
348+
parser.add_argument(
349+
"--output-scalars-dir",
350+
default="output_scalars",
351+
help="Subdirectory name that contains output scalar CSV files.",
352+
)
320353
parser.add_argument(
321354
"--field-dirs",
322355
nargs="*",
@@ -326,11 +359,32 @@ def build_parser() -> argparse.ArgumentParser:
326359
"all non-scalar subdirs are used."
327360
),
328361
)
329-
parser.add_argument("--field-location", default="Vertex")
330-
parser.add_argument("--base-name", default="Base_1_1")
331-
parser.add_argument("--zone-name", default="Zone")
332-
parser.add_argument("--file-glob", default="scalars_*.csv")
333-
parser.add_argument("--verbose", action="store_true")
362+
parser.add_argument(
363+
"--field-location",
364+
default="Vertex",
365+
choices=["Vertex"],
366+
help="Field location in PLAID samples (currently only 'Vertex').",
367+
)
368+
parser.add_argument(
369+
"--base-name",
370+
default="Base_1_1",
371+
help="Base name used when initializing sample features.",
372+
)
373+
parser.add_argument(
374+
"--zone-name",
375+
default="Zone",
376+
help="Zone name used when initializing sample features.",
377+
)
378+
parser.add_argument(
379+
"--file-glob",
380+
default="scalars_*.csv",
381+
help="Glob pattern used to discover CSV files in each subdirectory.",
382+
)
383+
parser.add_argument(
384+
"--verbose",
385+
action="store_true",
386+
help="Print a completion message with the generated output path.",
387+
)
334388
return parser
335389

336390

tests/cli/test_build_dataset.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,27 @@ def test_build_dataset_from_raw_invalid_location(tmp_path: Path) -> None:
153153
)
154154

155155

156+
def test_build_dataset_from_raw_field_size_mismatch(tmp_path: Path) -> None:
157+
raw = tmp_path / "raw"
158+
out = tmp_path / "out"
159+
160+
_write_csv(
161+
raw / "input_scalars" / "scalars_00000.csv",
162+
"in_a",
163+
"1.0",
164+
)
165+
_write_csv(
166+
raw / "output_scalars" / "scalars_00000.csv",
167+
"out_y",
168+
"2.0",
169+
)
170+
_write_field(raw / "field_1" / "scalars_00000.csv", [1.0, 2.0, 3.0])
171+
_write_field(raw / "field_2" / "scalars_00000.csv", [4.0, 5.0])
172+
173+
with pytest.raises(ValueError, match="share the same number"):
174+
build_dataset_from_raw(input_dir=raw, output_dir=out)
175+
176+
156177
def test_build_dataset_from_raw_overwrite(tmp_path: Path) -> None:
157178
raw = tmp_path / "raw"
158179
out = tmp_path / "out"

0 commit comments

Comments
 (0)