Skip to content

Commit 9706144

Browse files
committed
Fix: NumPy Serialization
1 parent f514cc2 commit 9706144

4 files changed

Lines changed: 193 additions & 102 deletions

File tree

src/pals/functions.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ def load_file_to_dict(filename: str) -> dict:
5454
return pals_data
5555

5656

57+
def _numpy_to_native(obj):
58+
"""Convert a numpy scalar/array to its Python-native equivalent.
59+
60+
Returns ``None`` when the object is not a numpy type or when numpy is not
61+
installed; callers use that to decide whether to fall back to the default
62+
serializer behavior. numpy is an optional dependency.
63+
"""
64+
try:
65+
import numpy as np
66+
except ImportError:
67+
return None
68+
69+
if isinstance(obj, np.ndarray):
70+
return obj.tolist()
71+
if isinstance(obj, np.generic):
72+
return obj.item()
73+
return None
74+
75+
5776
def store_dict_to_file(filename: str, pals_dict: dict):
5877
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
5978
filename
@@ -63,14 +82,58 @@ def store_dict_to_file(filename: str, pals_dict: dict):
6382
if extension == ".json":
6483
import json
6584

66-
json_data = json.dumps(pals_dict, sort_keys=False, indent=2)
85+
def _json_default(obj):
86+
native = _numpy_to_native(obj)
87+
if native is not None:
88+
return native
89+
raise TypeError(
90+
f"Object of type {type(obj).__name__} is not JSON serializable"
91+
)
92+
93+
json_data = json.dumps(
94+
pals_dict, sort_keys=False, indent=2, default=_json_default
95+
)
6796
with open(filename, "w") as file:
6897
file.write(json_data)
6998

7099
elif extension == ".yaml":
71100
import yaml
72101

73-
yaml_data = yaml.dump(pals_dict, default_flow_style=False, sort_keys=False)
102+
# Subclass the safe dumper so numpy representers are scoped to PALS
103+
# serialization and do not leak into the global pyyaml state used by
104+
# other code in the same process.
105+
class _PALSDumper(yaml.SafeDumper):
106+
pass
107+
108+
try:
109+
import numpy as np
110+
except ImportError:
111+
np = None
112+
113+
if np is not None:
114+
115+
def _represent_numpy_scalar(dumper, value):
116+
native = value.item()
117+
if isinstance(native, bool):
118+
return dumper.represent_bool(native)
119+
if isinstance(native, int):
120+
return dumper.represent_int(native)
121+
if isinstance(native, float):
122+
return dumper.represent_float(native)
123+
return dumper.represent_data(native)
124+
125+
def _represent_numpy_array(dumper, value):
126+
return dumper.represent_list(value.tolist())
127+
128+
_PALSDumper.add_multi_representer(np.generic, _represent_numpy_scalar)
129+
_PALSDumper.add_representer(np.ndarray, _represent_numpy_array)
130+
131+
yaml_data = yaml.dump(
132+
pals_dict,
133+
Dumper=_PALSDumper,
134+
default_flow_style=False,
135+
sort_keys=False,
136+
)
74137
with open(filename, "w") as file:
75138
file.write(yaml_data)
76139

Lines changed: 9 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
1-
from pydantic import BaseModel, ConfigDict, model_validator
2-
from typing import Any
1+
from typing import ClassVar
32

4-
# Valid parameter prefixes, their expected format and description
5-
_PARAMETER_PREFIXES = {
6-
"tilt": ("tiltN", "Tilt"),
7-
"En": ("EnN", "Normal component"),
8-
"Es": ("EsN", "Skew component"),
9-
}
3+
from pals.parameters._multipole_base import _MultipoleBase
104

115

12-
def _validate_order(
13-
key_num: str, parameter_name: str, prefix: str, expected_format: str
14-
) -> None:
15-
"""Validate that the order number is a non-negative integer without leading zeros."""
16-
error_msg = (
17-
f"Invalid {parameter_name}: '{prefix}{key_num}'. "
18-
f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros."
19-
)
20-
if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"):
21-
raise ValueError(error_msg)
22-
23-
24-
class ElectricMultipoleParameters(BaseModel):
6+
class ElectricMultipoleParameters(_MultipoleBase):
257
"""Electric multipole parameters
268
279
Valid parameter formats:
@@ -33,31 +15,9 @@ class ElectricMultipoleParameters(BaseModel):
3315
Where N is a positive integer without leading zeros (except "0" itself).
3416
"""
3517

36-
model_config = ConfigDict(extra="allow")
37-
38-
@model_validator(mode="before")
39-
@classmethod
40-
def validate(cls, values: dict[str, Any]) -> dict[str, Any]:
41-
"""Validate all parameter names match the expected multipole format."""
42-
for key in values:
43-
# Check if key ends with 'L' for length-integrated values
44-
is_length_integrated = key.endswith("L")
45-
base_key = key[:-1] if is_length_integrated else key
46-
47-
# No length-integrated values allowed for tilt parameter
48-
if is_length_integrated and base_key.startswith("tilt"):
49-
raise ValueError(f"Invalid electric multipole parameter: '{key}'. ")
50-
51-
# Find matching prefix
52-
for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items():
53-
if base_key.startswith(prefix):
54-
key_num = base_key[len(prefix) :]
55-
_validate_order(key_num, description, prefix, expected_format)
56-
break
57-
else:
58-
raise ValueError(
59-
f"Invalid electric multipole parameter: '{key}'. "
60-
f"Parameters must be of the form 'tiltN', 'EnN', or 'EsN' "
61-
f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer."
62-
)
63-
return values
18+
_KIND_NAME: ClassVar[str] = "electric"
19+
_PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = {
20+
"tilt": ("tiltN", "Tilt"),
21+
"En": ("EnN", "Normal component"),
22+
"Es": ("EsN", "Skew component"),
23+
}
Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,9 @@
1-
from pydantic import BaseModel, ConfigDict, model_validator
2-
from typing import Any
1+
from typing import ClassVar
32

4-
# Valid parameter prefixes, their expected format and description
5-
_PARAMETER_PREFIXES = {
6-
"tilt": ("tiltN", "Tilt"),
7-
"Bn": ("BnN", "Normal component"),
8-
"Bs": ("BsN", "Skew component"),
9-
"Kn": ("KnN", "Normalized normal component"),
10-
"Ks": ("KsN", "Normalized skew component"),
11-
}
3+
from pals.parameters._multipole_base import _MultipoleBase
124

135

14-
def _validate_order(
15-
key_num: str, parameter_name: str, prefix: str, expected_format: str
16-
) -> None:
17-
"""Validate that the order number is a non-negative integer without leading zeros."""
18-
error_msg = (
19-
f"Invalid {parameter_name}: '{prefix}{key_num}'. "
20-
f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros."
21-
)
22-
if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"):
23-
raise ValueError(error_msg)
24-
25-
26-
class MagneticMultipoleParameters(BaseModel):
6+
class MagneticMultipoleParameters(_MultipoleBase):
277
"""Magnetic multipole parameters
288
299
Valid parameter formats:
@@ -37,31 +17,11 @@ class MagneticMultipoleParameters(BaseModel):
3717
Where N is a positive integer without leading zeros (except "0" itself).
3818
"""
3919

40-
model_config = ConfigDict(extra="allow")
41-
42-
@model_validator(mode="before")
43-
@classmethod
44-
def validate(cls, values: dict[str, Any]) -> dict[str, Any]:
45-
"""Validate all parameter names match the expected multipole format."""
46-
for key in values:
47-
# Check if key ends with 'L' for length-integrated values
48-
is_length_integrated = key.endswith("L")
49-
base_key = key[:-1] if is_length_integrated else key
50-
51-
# No length-integrated values allowed for tilt parameter
52-
if is_length_integrated and base_key.startswith("tilt"):
53-
raise ValueError(f"Invalid magnetic multipole parameter: '{key}'. ")
54-
55-
# Find matching prefix
56-
for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items():
57-
if base_key.startswith(prefix):
58-
key_num = base_key[len(prefix) :]
59-
_validate_order(key_num, description, prefix, expected_format)
60-
break
61-
else:
62-
raise ValueError(
63-
f"Invalid magnetic multipole parameter: '{key}'. "
64-
f"Parameters must be of the form 'tiltN', 'BnN', 'BsN', 'KnN', or 'KsN' "
65-
f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer."
66-
)
67-
return values
20+
_KIND_NAME: ClassVar[str] = "magnetic"
21+
_PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = {
22+
"tilt": ("tiltN", "Tilt"),
23+
"Bn": ("BnN", "Normal component"),
24+
"Bs": ("BsN", "Skew component"),
25+
"Kn": ("KnN", "Normalized normal component"),
26+
"Ks": ("KsN", "Normalized skew component"),
27+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Private shared base class for multipole parameter groups.
2+
3+
Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters`
4+
allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``).
5+
Because these fields are not declared with a type, Pydantic would otherwise
6+
store them as-is, preserving non-native numeric inputs like ``numpy.float64``.
7+
That breaks downstream YAML serialization (PyYAML emits unsafe Python-object
8+
tags for numpy scalars). See pals-project/pals-python#67.
9+
10+
This module centralizes the name-validation logic and adds numpy-to-native
11+
coercion at construction time.
12+
"""
13+
14+
from typing import Any, ClassVar
15+
16+
from pydantic import BaseModel, ConfigDict, model_validator
17+
18+
19+
def _coerce_numpy_value(value: Any) -> Any:
20+
"""Convert numpy scalars/arrays to Python-native equivalents.
21+
22+
Recurses through ``list``/``tuple``/``dict`` containers so nested
23+
structures are also cleaned. Returns ``value`` unchanged when numpy is
24+
not installed or the value is not a numpy type. numpy remains an optional
25+
dependency of this project.
26+
"""
27+
try:
28+
import numpy as np
29+
except ImportError:
30+
return value
31+
32+
if isinstance(value, np.ndarray):
33+
if value.ndim == 0:
34+
return value.item()
35+
return _coerce_numpy_value(value.tolist())
36+
if isinstance(value, np.generic):
37+
return value.item()
38+
if isinstance(value, list):
39+
return [_coerce_numpy_value(v) for v in value]
40+
if isinstance(value, tuple):
41+
return tuple(_coerce_numpy_value(v) for v in value)
42+
if isinstance(value, dict):
43+
return {k: _coerce_numpy_value(v) for k, v in value.items()}
44+
return value
45+
46+
47+
def _validate_order(
48+
key_num: str, parameter_name: str, prefix: str, expected_format: str
49+
) -> None:
50+
"""Validate that the order number is a non-negative integer without leading zeros."""
51+
error_msg = (
52+
f"Invalid {parameter_name}: '{prefix}{key_num}'. "
53+
f"Parameter must be of the form '{expected_format}', "
54+
f"where 'N' is a non-negative integer without leading zeros."
55+
)
56+
if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"):
57+
raise ValueError(error_msg)
58+
59+
60+
class _MultipoleBase(BaseModel):
61+
"""Private shared base for multipole parameter groups.
62+
63+
Subclasses must set :attr:`_PARAMETER_PREFIXES` and :attr:`_KIND_NAME`.
64+
Both are ``ClassVar`` and are not exposed as Pydantic fields.
65+
"""
66+
67+
# Subclasses override these:
68+
_PARAMETER_PREFIXES: ClassVar[dict[str, tuple[str, str]]] = {}
69+
_KIND_NAME: ClassVar[str] = "multipole"
70+
71+
model_config = ConfigDict(extra="allow")
72+
73+
@model_validator(mode="before")
74+
@classmethod
75+
def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]:
76+
"""Validate parameter names and coerce numpy values to Python natives."""
77+
coerced: dict[str, Any] = {}
78+
for key, value in values.items():
79+
is_length_integrated = key.endswith("L")
80+
base_key = key[:-1] if is_length_integrated else key
81+
82+
if is_length_integrated and base_key.startswith("tilt"):
83+
raise ValueError(
84+
f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. "
85+
)
86+
87+
for prefix, (
88+
expected_format,
89+
description,
90+
) in cls._PARAMETER_PREFIXES.items():
91+
if base_key.startswith(prefix):
92+
key_num = base_key[len(prefix) :]
93+
_validate_order(key_num, description, prefix, expected_format)
94+
break
95+
else:
96+
prefix_list = ", ".join(
97+
f"'{p}N'" for p in cls._PARAMETER_PREFIXES if p != "tilt"
98+
)
99+
raise ValueError(
100+
f"Invalid {cls._KIND_NAME} multipole parameter: '{key}'. "
101+
f"Parameters must be of the form 'tiltN', {prefix_list} "
102+
f"(with optional 'L' suffix for length-integrated), "
103+
f"where 'N' is a non-negative integer."
104+
)
105+
106+
coerced[key] = _coerce_numpy_value(value)
107+
108+
return coerced

0 commit comments

Comments
 (0)