Skip to content

Commit 3792a34

Browse files
committed
Generalize NumPy Serialization
1 parent 9706144 commit 3792a34

1 file changed

Lines changed: 10 additions & 42 deletions

File tree

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,20 @@
11
"""Private shared base class for multipole parameter groups.
22
33
Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters`
4-
allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``).
5-
Because these fields are not declared with a type, Pydantic would otherwise
6-
store them as-is, preserving non-native numeric inputs like ``numpy.float64``.
7-
That breaks downstream YAML serialization (PyYAML emits unsafe Python-object
8-
tags for numpy scalars). See pals-project/pals-python#67.
4+
allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``) and
5+
share the same name-validation logic. This module centralizes that logic.
96
10-
This module centralizes the name-validation logic and adds numpy-to-native
11-
coercion at construction time.
7+
numpy interoperability (see pals-project/pals-python#67) is handled at the
8+
serialization boundary in :mod:`pals.functions`, which keeps the fix general:
9+
any numpy scalar reaching ``yaml.dump`` or ``json.dumps`` is converted to a
10+
Python-native equivalent regardless of which model produced it.
1211
"""
1312

1413
from typing import Any, ClassVar
1514

1615
from pydantic import BaseModel, ConfigDict, model_validator
1716

1817

19-
def _coerce_numpy_value(value: Any) -> Any:
20-
"""Convert numpy scalars/arrays to Python-native equivalents.
21-
22-
Recurses through ``list``/``tuple``/``dict`` containers so nested
23-
structures are also cleaned. Returns ``value`` unchanged when numpy is
24-
not installed or the value is not a numpy type. numpy remains an optional
25-
dependency of this project.
26-
"""
27-
try:
28-
import numpy as np
29-
except ImportError:
30-
return value
31-
32-
if isinstance(value, np.ndarray):
33-
if value.ndim == 0:
34-
return value.item()
35-
return _coerce_numpy_value(value.tolist())
36-
if isinstance(value, np.generic):
37-
return value.item()
38-
if isinstance(value, list):
39-
return [_coerce_numpy_value(v) for v in value]
40-
if isinstance(value, tuple):
41-
return tuple(_coerce_numpy_value(v) for v in value)
42-
if isinstance(value, dict):
43-
return {k: _coerce_numpy_value(v) for k, v in value.items()}
44-
return value
45-
46-
4718
def _validate_order(
4819
key_num: str, parameter_name: str, prefix: str, expected_format: str
4920
) -> None:
@@ -72,10 +43,9 @@ class _MultipoleBase(BaseModel):
7243

7344
@model_validator(mode="before")
7445
@classmethod
75-
def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]:
76-
"""Validate parameter names and coerce numpy values to Python natives."""
77-
coerced: dict[str, Any] = {}
78-
for key, value in values.items():
46+
def _validate(cls, values: dict[str, Any]) -> dict[str, Any]:
47+
"""Validate that all parameter names match the expected multipole format."""
48+
for key in values:
7949
is_length_integrated = key.endswith("L")
8050
base_key = key[:-1] if is_length_integrated else key
8151

@@ -103,6 +73,4 @@ def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]:
10373
f"where 'N' is a non-negative integer."
10474
)
10575

106-
coerced[key] = _coerce_numpy_value(value)
107-
108-
return coerced
76+
return values

0 commit comments

Comments
 (0)