|
1 | 1 | """Private shared base class for multipole parameter groups. |
2 | 2 |
|
3 | 3 | Both :class:`MagneticMultipoleParameters` and :class:`ElectricMultipoleParameters` |
4 | | -allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``). |
5 | | -Because these fields are not declared with a type, Pydantic would otherwise |
6 | | -store them as-is, preserving non-native numeric inputs like ``numpy.float64``. |
7 | | -That breaks downstream YAML serialization (PyYAML emits unsafe Python-object |
8 | | -tags for numpy scalars). See pals-project/pals-python#67. |
| 4 | +allow arbitrary order-indexed extra fields (e.g. ``Bn1``, ``Es3``, ``Kn0L``) and |
| 5 | +share the same name-validation logic. This module centralizes that logic. |
9 | 6 |
|
10 | | -This module centralizes the name-validation logic and adds numpy-to-native |
11 | | -coercion at construction time. |
| 7 | +numpy interoperability (see pals-project/pals-python#67) is handled at the |
| 8 | +serialization boundary in :mod:`pals.functions`, which keeps the fix general: |
| 9 | +any numpy scalar reaching ``yaml.dump`` or ``json.dumps`` is converted to a |
| 10 | +Python-native equivalent regardless of which model produced it. |
12 | 11 | """ |
13 | 12 |
|
14 | 13 | from typing import Any, ClassVar |
15 | 14 |
|
16 | 15 | from pydantic import BaseModel, ConfigDict, model_validator |
17 | 16 |
|
18 | 17 |
|
19 | | -def _coerce_numpy_value(value: Any) -> Any: |
20 | | - """Convert numpy scalars/arrays to Python-native equivalents. |
21 | | -
|
22 | | - Recurses through ``list``/``tuple``/``dict`` containers so nested |
23 | | - structures are also cleaned. Returns ``value`` unchanged when numpy is |
24 | | - not installed or the value is not a numpy type. numpy remains an optional |
25 | | - dependency of this project. |
26 | | - """ |
27 | | - try: |
28 | | - import numpy as np |
29 | | - except ImportError: |
30 | | - return value |
31 | | - |
32 | | - if isinstance(value, np.ndarray): |
33 | | - if value.ndim == 0: |
34 | | - return value.item() |
35 | | - return _coerce_numpy_value(value.tolist()) |
36 | | - if isinstance(value, np.generic): |
37 | | - return value.item() |
38 | | - if isinstance(value, list): |
39 | | - return [_coerce_numpy_value(v) for v in value] |
40 | | - if isinstance(value, tuple): |
41 | | - return tuple(_coerce_numpy_value(v) for v in value) |
42 | | - if isinstance(value, dict): |
43 | | - return {k: _coerce_numpy_value(v) for k, v in value.items()} |
44 | | - return value |
45 | | - |
46 | | - |
47 | 18 | def _validate_order( |
48 | 19 | key_num: str, parameter_name: str, prefix: str, expected_format: str |
49 | 20 | ) -> None: |
@@ -72,10 +43,9 @@ class _MultipoleBase(BaseModel): |
72 | 43 |
|
73 | 44 | @model_validator(mode="before") |
74 | 45 | @classmethod |
75 | | - def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]: |
76 | | - """Validate parameter names and coerce numpy values to Python natives.""" |
77 | | - coerced: dict[str, Any] = {} |
78 | | - for key, value in values.items(): |
| 46 | + def _validate(cls, values: dict[str, Any]) -> dict[str, Any]: |
| 47 | + """Validate that all parameter names match the expected multipole format.""" |
| 48 | + for key in values: |
79 | 49 | is_length_integrated = key.endswith("L") |
80 | 50 | base_key = key[:-1] if is_length_integrated else key |
81 | 51 |
|
@@ -103,6 +73,4 @@ def _validate_and_coerce(cls, values: dict[str, Any]) -> dict[str, Any]: |
103 | 73 | f"where 'N' is a non-negative integer." |
104 | 74 | ) |
105 | 75 |
|
106 | | - coerced[key] = _coerce_numpy_value(value) |
107 | | - |
108 | | - return coerced |
| 76 | + return values |
0 commit comments