|
16 | 16 | ComponentType = Chopper | Detector |
17 | 17 |
|
18 | 18 |
|
19 | | -def _input_to_dict( |
20 | | - obj: None | list[ComponentType] | tuple[ComponentType, ...] | ComponentType, |
21 | | - kind: type, |
22 | | -): |
23 | | - if isinstance(obj, list | tuple): |
24 | | - out = {} |
25 | | - for item in obj: |
26 | | - new = _input_to_dict(item, kind=kind) |
27 | | - for key in new.keys(): |
28 | | - if key in out: |
29 | | - raise ValueError(f"More than one component named '{key}' found.") |
30 | | - out.update(new) |
31 | | - return out |
32 | | - elif isinstance(obj, kind): |
33 | | - return {obj.name: obj} |
34 | | - elif obj is None: |
35 | | - return {} |
36 | | - else: |
37 | | - raise TypeError( |
38 | | - "Invalid input type. Must be a Chopper or a Detector, " |
39 | | - "or a list/tuple of Choppers or Detectors." |
40 | | - ) |
41 | | - |
42 | | - |
43 | 19 | def _array_or_none(container: dict, key: str) -> sc.Variable | None: |
44 | 20 | return ( |
45 | 21 | sc.array( |
@@ -131,12 +107,20 @@ class Model: |
131 | 107 | def __init__( |
132 | 108 | self, |
133 | 109 | source: Source | None = None, |
134 | | - choppers: Chopper | list[Chopper] | tuple[Chopper, ...] | None = None, |
135 | | - detectors: Detector | list[Detector] | tuple[Detector, ...] | None = None, |
| 110 | + choppers: list[Chopper] | tuple[Chopper, ...] | None = None, |
| 111 | + detectors: list[Detector] | tuple[Detector, ...] | None = None, |
136 | 112 | ): |
137 | | - self.choppers = _input_to_dict(choppers, kind=Chopper) |
138 | | - self.detectors = _input_to_dict(detectors, kind=Detector) |
| 113 | + self.choppers = {} |
| 114 | + self.detectors = {} |
139 | 115 | self.source = source |
| 116 | + for components, kind in ((choppers, Chopper), (detectors, Detector)): |
| 117 | + for c in components or (): |
| 118 | + if not isinstance(c, kind): |
| 119 | + raise TypeError( |
| 120 | + f"Beamline components: expected {kind.__name__} instance, " |
| 121 | + f"got {type(c)}." |
| 122 | + ) |
| 123 | + self.add(c) |
140 | 124 |
|
141 | 125 | @classmethod |
142 | 126 | def from_json(cls, filename: str) -> Model: |
@@ -212,31 +196,33 @@ def to_json(self, filename: str): |
212 | 196 | with open(filename, 'w') as f: |
213 | 197 | json.dump(self.as_json(), f, indent=2) |
214 | 198 |
|
215 | | - def add(self, component): |
| 199 | + def add(self, component: Chopper | Detector): |
216 | 200 | """ |
217 | 201 | Add a component to the instrument. |
218 | 202 | Component names must be unique across choppers and detectors. |
| 203 | + The name "source" is reserved for the source, and can thus not be used for other |
| 204 | + components. |
219 | 205 |
|
220 | 206 | Parameters |
221 | 207 | ---------- |
222 | 208 | component: |
223 | 209 | A chopper or detector. |
224 | 210 | """ |
225 | | - if component.name in chain(self.choppers, self.detectors): |
| 211 | + if not isinstance(component, (Chopper | Detector)): |
| 212 | + raise TypeError( |
| 213 | + f"Cannot add component of type {type(component)} to the model. " |
| 214 | + "Only Chopper and Detector instances are allowed." |
| 215 | + ) |
| 216 | + # Note that the name "source" is reserved for the source. |
| 217 | + if component.name in chain(self.choppers, self.detectors, ("source",)): |
226 | 218 | raise KeyError( |
227 | 219 | f"Component with name {component.name} already exists. " |
228 | 220 | "If you wish to replace/update an existing component, use " |
229 | 221 | "``model.choppers['name'] = new_chopper`` or " |
230 | 222 | "``model.detectors['name'] = new_detector``." |
231 | 223 | ) |
232 | | - if isinstance(component, Chopper): |
233 | | - self.choppers[component.name] = component |
234 | | - elif isinstance(component, Detector): |
235 | | - self.detectors[component.name] = component |
236 | | - else: |
237 | | - raise TypeError( |
238 | | - f"Cannot add component of type {type(component)} to the model." |
239 | | - ) |
| 224 | + container = self.choppers if isinstance(component, Chopper) else self.detectors |
| 225 | + container[component.name] = component |
240 | 226 |
|
241 | 227 | def remove(self, name: str): |
242 | 228 | """ |
|
0 commit comments