Skip to content

Commit 6015ade

Browse files
committed
Add ElementReference
Used inside `BeamLine::line` and later on similarly used for `Lattice::branches` and other beamline extensions.
1 parent f3ceb0a commit 6015ade

6 files changed

Lines changed: 163 additions & 13 deletions

File tree

src/pals/kinds/ElementReference.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Element reference class for referring to elements by name."""
2+
3+
from pydantic import BaseModel, Field, model_serializer
4+
from typing import Annotated
5+
6+
from .mixin import BaseElement
7+
8+
9+
class ElementReference(BaseModel):
10+
"""A pydantic model that represents a reference to a named element.
11+
12+
This class behaves like a string (via __str__ and __eq__) but stores
13+
a true reference to the actual element object once it's resolved.
14+
15+
The element field holds a reference (not a copy) to the actual element.
16+
17+
Attributes:
18+
name: The name of the referenced element
19+
element: A reference to the resolved element object (None until resolved)
20+
21+
Example:
22+
>>> ref = ElementReference(name="drift1")
23+
>>> ref.name
24+
'drift1'
25+
>>> str(ref)
26+
'drift1'
27+
>>> ref == "drift1"
28+
True
29+
>>> ref.element # None until resolved
30+
>>> drift = pals.Drift(name="drift1", length=1.0)
31+
>>> ref.element = drift
32+
>>> ref.is_resolved()
33+
True
34+
>>> ref.element is drift # True - it's a reference, not a copy
35+
True
36+
"""
37+
38+
name: str = Field(..., description="The name of the referenced element")
39+
element: Annotated[
40+
"BaseElement | None",
41+
Field(default=None, description="Reference to the resolved element object"),
42+
] = None
43+
44+
@model_serializer(mode="plain")
45+
def _serialize_as_name(self) -> str:
46+
"""Serialize this reference as just its name.
47+
48+
This makes `model_dump()` return a string (the element name), so nested
49+
serialization (e.g. inside BeamLine.line) produces plain strings too.
50+
"""
51+
return self.name
52+
53+
def __init__(self, name: str | None = None, /, **data):
54+
"""Initialize with either positional name or keyword arguments."""
55+
if name is not None:
56+
super().__init__(name=name, **data)
57+
else:
58+
super().__init__(**data)
59+
60+
def __str__(self) -> str:
61+
"""Return the element name as string."""
62+
return self.name
63+
64+
def __eq__(self, other: object) -> bool:
65+
"""Enable string comparison."""
66+
if isinstance(other, str):
67+
return self.name == other
68+
if isinstance(other, ElementReference):
69+
return self.name == other.name and self.element is other.element
70+
return False
71+
72+
def __hash__(self) -> int:
73+
"""Make hashable like a string."""
74+
return hash(self.name)
75+
76+
def is_resolved(self) -> bool:
77+
"""Check if this reference has been resolved to an actual element."""
78+
return self.element is not None
79+
80+
def __repr__(self) -> str:
81+
"""Return a representation of the ElementReference."""
82+
resolved = "resolved" if self.is_resolved() else "unresolved"
83+
return f"ElementReference('{self.name}', {resolved})"

src/pals/kinds/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .CrabCavity import CrabCavity # noqa: F401
1111
from .Drift import Drift # noqa: F401
1212
from .EGun import EGun # noqa: F401
13+
from .ElementReference import ElementReference # noqa: F401
1314
from .Feedback import Feedback # noqa: F401
1415
from .Fiducial import Fiducial # noqa: F401
1516
from .FloorShift import FloorShift # noqa: F401

src/pals/kinds/all_elements.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
avoiding duplication between BeamLine.line and UnionEle.elements.
55
"""
66

7-
from typing import Annotated, Union
7+
from typing import Union
88

9-
from pydantic import Field
109

1110
from .ACKicker import ACKicker
1211
from .BeamBeam import BeamBeam
@@ -15,6 +14,7 @@
1514
from .CrabCavity import CrabCavity
1615
from .Drift import Drift
1716
from .EGun import EGun
17+
from .ElementReference import ElementReference
1818
from .Feedback import Feedback
1919
from .Fiducial import Fiducial
2020
from .FloorShift import FloorShift
@@ -84,11 +84,11 @@ def get_all_element_types(extra_types: tuple = None):
8484
def get_all_elements_as_annotation(extra_types: tuple = None):
8585
"""Return the Union type of all allowed elements with their name as the discriminator field.
8686
87-
Note: When str is included in the union (for string references), we cannot use
88-
discriminator since str doesn't have a 'kind' field. Pydantic will still properly
89-
validate the union by trying each type in order.
87+
Note: ElementReference is included to support string references to named elements.
88+
Since ElementReference doesn't have a 'kind' field, we cannot use discriminator.
89+
Pydantic will still properly validate the union by trying each type in order in
90+
our unpack_element_list_structure method.
9091
"""
91-
types = get_all_element_types(extra_types)
92-
# Add str to support string references to named elements
93-
# We can't use discriminator with str in the union since str has no 'kind' field
94-
return Union[types + (str,)]
92+
types = get_all_element_types(extra_types) + (ElementReference,)
93+
# We can't use discriminator with ElementReference in the union since it has no 'kind' field
94+
return Union[types]

src/pals/kinds/mixin/all_element_mixin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from . import BaseElement
8+
from ..ElementReference import ElementReference
89

910

1011
def unpack_element_list_structure(
@@ -44,7 +45,12 @@ def unpack_element_list_structure(
4445
for item in data[field_name]:
4546
# An element can be a string that refers to another element
4647
if isinstance(item, str):
47-
# Keep the string reference as-is - it will be validated later
48+
# Wrap the string in an ElementReference object
49+
new_list.append(ElementReference(item))
50+
continue
51+
# An element can be an ElementReference instance directly
52+
elif isinstance(item, ElementReference):
53+
# Keep the ElementReference as-is
4854
new_list.append(item)
4955
continue
5056
# An element can be a dict
@@ -71,7 +77,7 @@ def unpack_element_list_structure(
7177
continue
7278

7379
raise TypeError(
74-
f"Value must be a reference string or a dict, but we got {item!r}"
80+
f"Value must be a reference string, ElementReference, or a dict, but we got {item!r}"
7581
)
7682

7783
data[field_name] = new_list

src/pals/schema_version.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from typing import Optional
2+
3+
# PALS schema version - null for now, will be set when version scheme is finalized
4+
PALS_SCHEMA_VERSION: Optional[str] = None

tests/test_elements.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,11 +534,67 @@ def test_BeamLine_with_string_references():
534534

535535
assert beamline.name == "fodo_cell"
536536
assert len(beamline.line) == 3
537-
# First element should be a string reference
537+
538+
# First element should be an ElementReference that behaves like the string "drift1"
539+
assert isinstance(beamline.line[0], pals.ElementReference)
538540
assert beamline.line[0] == "drift1"
539-
# Second element should be a string reference
541+
assert beamline.line[0].name == "drift1"
542+
assert beamline.line[0].element is None # Not yet resolved
543+
assert not beamline.line[0].is_resolved()
544+
545+
# Second element should be an ElementReference that behaves like the string "quad1"
546+
assert isinstance(beamline.line[1], pals.ElementReference)
540547
assert beamline.line[1] == "quad1"
548+
assert beamline.line[1].name == "quad1"
549+
assert beamline.line[1].element is None # Not yet resolved
550+
assert not beamline.line[1].is_resolved()
551+
541552
# Third element should be a Drift object
542553
assert isinstance(beamline.line[2], pals.Drift)
543554
assert beamline.line[2].name == "drift2"
544555
assert beamline.line[2].length == 0.5
556+
557+
# Test that we can resolve the reference later
558+
drift_element = pals.Drift(name="drift1", length=1.0)
559+
beamline.line[0].element = drift_element
560+
assert beamline.line[0].is_resolved()
561+
assert beamline.line[0].element.name == "drift1"
562+
assert beamline.line[0].element.length == 1.0
563+
564+
565+
def test_ElementReference_direct():
566+
"""Test ElementReference creation and behavior directly"""
567+
# Test creation with positional argument
568+
ref1 = pals.ElementReference("test_element")
569+
assert ref1.name == "test_element"
570+
assert str(ref1) == "test_element"
571+
assert ref1 == "test_element"
572+
assert not ref1.is_resolved()
573+
574+
# Test creation with keyword argument
575+
ref2 = pals.ElementReference(name="another_element")
576+
assert ref2.name == "another_element"
577+
assert str(ref2) == "another_element"
578+
assert ref2 == "another_element"
579+
580+
# Test hash (for use in sets/dicts)
581+
ref_set = {ref1, ref2}
582+
assert len(ref_set) == 2
583+
assert ref1 in ref_set
584+
585+
# Test resolution
586+
drift = pals.Drift(name="test_element", length=2.5)
587+
ref1.element = drift
588+
assert ref1.is_resolved()
589+
assert ref1.element.length == 2.5
590+
591+
# Test repr
592+
assert "test_element" in repr(ref1)
593+
assert "resolved" in repr(ref1)
594+
assert "unresolved" in repr(ref2)
595+
596+
# Test that element is a reference, not a copy
597+
assert ref1.element is drift # Same object identity
598+
# Modify the original and verify the reference sees the change
599+
drift.length = 3.0
600+
assert ref1.element.length == 3.0 # Change is visible through reference

0 commit comments

Comments
 (0)