|
1 | 1 | import os |
2 | 2 |
|
| 3 | +import pytest |
| 4 | + |
3 | 5 | import pals |
4 | 6 |
|
5 | 7 |
|
@@ -332,3 +334,75 @@ def test_comprehensive_lattice(): |
332 | 334 | # Clean up temporary files |
333 | 335 | os.remove(yaml_file) |
334 | 336 | os.remove(json_file) |
| 337 | + |
| 338 | + |
| 339 | +def _build_numpy_lattice(np): |
| 340 | + """Build a small lattice using numpy-typed scalar values throughout.""" |
| 341 | + quad = pals.Quadrupole( |
| 342 | + name="q_np", |
| 343 | + length=np.float64(0.061), |
| 344 | + MagneticMultipoleP=pals.MagneticMultipoleParameters( |
| 345 | + Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1) |
| 346 | + ), |
| 347 | + ) |
| 348 | + oct_ = pals.Octupole( |
| 349 | + name="o_np", |
| 350 | + length=np.float64(0.25), |
| 351 | + ElectricMultipoleP=pals.ElectricMultipoleParameters( |
| 352 | + En3=np.float64(0.75), Es3=np.float32(0.125) |
| 353 | + ), |
| 354 | + ) |
| 355 | + return pals.BeamLine(name="line_np", line=[quad, oct_]) |
| 356 | + |
| 357 | + |
| 358 | +def test_yaml_roundtrip_with_numpy(): |
| 359 | + """Regression test for issue #67: writing YAML with numpy-typed values |
| 360 | + must not produce !!python/object tags, and round-tripping must yield |
| 361 | + Python-native floats with the correct numeric values.""" |
| 362 | + np = pytest.importorskip("numpy") |
| 363 | + |
| 364 | + line = _build_numpy_lattice(np) |
| 365 | + yaml_file = "numpy_roundtrip.pals.yaml" |
| 366 | + line.to_file(yaml_file) |
| 367 | + try: |
| 368 | + with open(yaml_file, "r") as f: |
| 369 | + text = f.read() |
| 370 | + |
| 371 | + # The bug symptom: YAML contains opaque numpy object tags. |
| 372 | + assert "!!python/object" not in text, ( |
| 373 | + f"YAML output still contains unsafe numpy object tags:\n{text}" |
| 374 | + ) |
| 375 | + assert "numpy" not in text, f"YAML output still references numpy:\n{text}" |
| 376 | + |
| 377 | + loaded = pals.BeamLine.from_file(yaml_file) |
| 378 | + loaded_quad = loaded.line[0] |
| 379 | + assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0 |
| 380 | + assert type(loaded_quad.MagneticMultipoleP.Bn1) is float |
| 381 | + assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5 |
| 382 | + assert loaded_quad.MagneticMultipoleP.Kn0 == -1 |
| 383 | + |
| 384 | + loaded_oct = loaded.line[1] |
| 385 | + assert loaded_oct.ElectricMultipoleP.En3 == 0.75 |
| 386 | + assert type(loaded_oct.ElectricMultipoleP.En3) is float |
| 387 | + finally: |
| 388 | + if os.path.exists(yaml_file): |
| 389 | + os.remove(yaml_file) |
| 390 | + |
| 391 | + |
| 392 | +def test_json_roundtrip_with_numpy(): |
| 393 | + """JSON path also needs to handle numpy values cleanly (defense-in-depth).""" |
| 394 | + np = pytest.importorskip("numpy") |
| 395 | + |
| 396 | + line = _build_numpy_lattice(np) |
| 397 | + json_file = "numpy_roundtrip.pals.json" |
| 398 | + line.to_file(json_file) |
| 399 | + try: |
| 400 | + loaded = pals.BeamLine.from_file(json_file) |
| 401 | + loaded_quad = loaded.line[0] |
| 402 | + assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0 |
| 403 | + assert type(loaded_quad.MagneticMultipoleP.Bn1) is float |
| 404 | + loaded_oct = loaded.line[1] |
| 405 | + assert loaded_oct.ElectricMultipoleP.En3 == 0.75 |
| 406 | + finally: |
| 407 | + if os.path.exists(json_file): |
| 408 | + os.remove(json_file) |
0 commit comments