Skip to content

Commit f514cc2

Browse files
committed
Test: NumPy Multipole Serialization
1 parent 1db2ef3 commit f514cc2

3 files changed

Lines changed: 125 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ classifiers = [
4141
]
4242

4343
[project.optional-dependencies]
44-
test = ["pytest"]
44+
test = ["pytest", "numpy"]
4545

4646
[project.urls]
4747
Documentation = "https://pals-project.readthedocs.io"

tests/test_parameters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,53 @@ def test_ParameterClasses():
134134
# Test BeamBeamParameters
135135
beambeam = BeamBeamParameters()
136136
assert beambeam is not None
137+
138+
139+
def test_multipole_numpy_coercion():
140+
"""Regression test for issue #67: numpy scalars passed to multipole parameter
141+
classes must be coerced to Python-native numeric types at construction time,
142+
so YAML/JSON serialization produces clean output regardless of input type."""
143+
np = pytest.importorskip("numpy")
144+
145+
# MagneticMultipoleParameters: cover all prefixes and several numpy dtypes
146+
mmp = MagneticMultipoleParameters(
147+
tilt1=np.float64(0.1),
148+
Bn1=np.float64(1.5),
149+
Bn2=np.float32(2.5),
150+
Bs1=np.int64(3),
151+
Kn0=np.int32(-1),
152+
Ks1=np.float64(0.25),
153+
)
154+
assert type(mmp.tilt1) is float and mmp.tilt1 == 0.1
155+
assert type(mmp.Bn1) is float and mmp.Bn1 == 1.5
156+
assert type(mmp.Bn2) is float and mmp.Bn2 == 2.5
157+
assert type(mmp.Bs1) is int and mmp.Bs1 == 3
158+
assert type(mmp.Kn0) is int and mmp.Kn0 == -1
159+
assert type(mmp.Ks1) is float and mmp.Ks1 == 0.25
160+
161+
# 0-d numpy array also works
162+
mmp_arr = MagneticMultipoleParameters(Bn1=np.array(4.2))
163+
assert type(mmp_arr.Bn1) is float and mmp_arr.Bn1 == 4.2
164+
165+
# Length-integrated variants
166+
mmp_L = MagneticMultipoleParameters(Bn1L=np.float64(7.0), Ks1L=np.float64(8.0))
167+
assert type(mmp_L.Bn1L) is float and mmp_L.Bn1L == 7.0
168+
assert type(mmp_L.Ks1L) is float and mmp_L.Ks1L == 8.0
169+
170+
# ElectricMultipoleParameters: cover all prefixes
171+
emp = ElectricMultipoleParameters(
172+
tilt1=np.float64(0.2),
173+
En1=np.float64(0.5),
174+
Es1=np.int64(2),
175+
)
176+
assert type(emp.tilt1) is float and emp.tilt1 == 0.2
177+
assert type(emp.En1) is float and emp.En1 == 0.5
178+
assert type(emp.Es1) is int and emp.Es1 == 2
179+
180+
emp_L = ElectricMultipoleParameters(En1L=np.float64(1.0), Es1L=np.float64(0.5))
181+
assert type(emp_L.En1L) is float and emp_L.En1L == 1.0
182+
assert type(emp_L.Es1L) is float and emp_L.Es1L == 0.5
183+
184+
# Plain Python values must still pass through unchanged
185+
mmp_plain = MagneticMultipoleParameters(Bn1=1.5)
186+
assert type(mmp_plain.Bn1) is float and mmp_plain.Bn1 == 1.5

tests/test_serialization.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
import pytest
4+
35
import pals
46

57

@@ -332,3 +334,75 @@ def test_comprehensive_lattice():
332334
# Clean up temporary files
333335
os.remove(yaml_file)
334336
os.remove(json_file)
337+
338+
339+
def _build_numpy_lattice(np):
340+
"""Build a small lattice using numpy-typed scalar values throughout."""
341+
quad = pals.Quadrupole(
342+
name="q_np",
343+
length=np.float64(0.061),
344+
MagneticMultipoleP=pals.MagneticMultipoleParameters(
345+
Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1)
346+
),
347+
)
348+
oct_ = pals.Octupole(
349+
name="o_np",
350+
length=np.float64(0.25),
351+
ElectricMultipoleP=pals.ElectricMultipoleParameters(
352+
En3=np.float64(0.75), Es3=np.float32(0.125)
353+
),
354+
)
355+
return pals.BeamLine(name="line_np", line=[quad, oct_])
356+
357+
358+
def test_yaml_roundtrip_with_numpy():
359+
"""Regression test for issue #67: writing YAML with numpy-typed values
360+
must not produce !!python/object tags, and round-tripping must yield
361+
Python-native floats with the correct numeric values."""
362+
np = pytest.importorskip("numpy")
363+
364+
line = _build_numpy_lattice(np)
365+
yaml_file = "numpy_roundtrip.pals.yaml"
366+
line.to_file(yaml_file)
367+
try:
368+
with open(yaml_file, "r") as f:
369+
text = f.read()
370+
371+
# The bug symptom: YAML contains opaque numpy object tags.
372+
assert "!!python/object" not in text, (
373+
f"YAML output still contains unsafe numpy object tags:\n{text}"
374+
)
375+
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"
376+
377+
loaded = pals.BeamLine.from_file(yaml_file)
378+
loaded_quad = loaded.line[0]
379+
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
380+
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
381+
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
382+
assert loaded_quad.MagneticMultipoleP.Kn0 == -1
383+
384+
loaded_oct = loaded.line[1]
385+
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
386+
assert type(loaded_oct.ElectricMultipoleP.En3) is float
387+
finally:
388+
if os.path.exists(yaml_file):
389+
os.remove(yaml_file)
390+
391+
392+
def test_json_roundtrip_with_numpy():
393+
"""JSON path also needs to handle numpy values cleanly (defense-in-depth)."""
394+
np = pytest.importorskip("numpy")
395+
396+
line = _build_numpy_lattice(np)
397+
json_file = "numpy_roundtrip.pals.json"
398+
line.to_file(json_file)
399+
try:
400+
loaded = pals.BeamLine.from_file(json_file)
401+
loaded_quad = loaded.line[0]
402+
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
403+
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
404+
loaded_oct = loaded.line[1]
405+
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
406+
finally:
407+
if os.path.exists(json_file):
408+
os.remove(json_file)

0 commit comments

Comments
 (0)