Skip to content

Commit 10103d2

Browse files
FBumannclaude
andauthored
refactor: make Comparison follow dict/Mapping protocol (#671)
* refactor: make Comparison follow dict/Mapping protocol Iteration now yields case names instead of (name, FlowSystem) pairs, matching Python's dict/Mapping convention. Added keys(), values(), and items() methods for explicit access. Users iterating pairs should switch to `for name, fs in comp.items():`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * ci: pin xarray <2026.3 in dev deps for linopy compatibility linopy 0.6.x breaks with xarray 2026.3+ (TypeError on Dataset constructor). Keep the release dependency range unchanged; only pin for dev/CI until linopy ships a fix. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat: validate flow_systems input types in Comparison Raise TypeError if flow_systems is not a list, or if any item is not a FlowSystem instance — replaces the cryptic AttributeError that would otherwise surface later when attributes like .name or .solution are accessed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix: validate element types before length in Comparison Check that each element is a FlowSystem before the minimum-length check, so `Comparison([not_a_flow_system])` surfaces the real problem (wrong type) instead of "requires at least 2 FlowSystems". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5c55d36 commit 10103d2

3 files changed

Lines changed: 59 additions & 9 deletions

File tree

flixopt/comparison.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121
if TYPE_CHECKING:
22-
from collections.abc import Iterator
22+
from collections.abc import ItemsView, Iterator, KeysView, ValuesView
2323

2424
from .flow_system import FlowSystem
2525

@@ -158,10 +158,19 @@ class Comparison:
158158
"""
159159

160160
def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = None) -> None:
161+
from .flow_system import FlowSystem
162+
163+
if not isinstance(flow_systems, list):
164+
raise TypeError(f'flow_systems must be a list, got {type(flow_systems).__name__}')
165+
166+
non_fs = [(i, type(fs).__name__) for i, fs in enumerate(flow_systems) if not isinstance(fs, FlowSystem)]
167+
if non_fs:
168+
raise TypeError(f'flow_systems must contain only FlowSystem instances; got {non_fs} (index, type)')
169+
161170
if len(flow_systems) < 2:
162171
raise ValueError('Comparison requires at least 2 FlowSystems')
163172

164-
self._systems = flow_systems
173+
self._systems: list[FlowSystem] = flow_systems
165174
self._names = names or [fs.name or f'System {i}' for i, fs in enumerate(flow_systems)]
166175

167176
if len(self._names) != len(self._systems):
@@ -224,14 +233,30 @@ def __getitem__(self, key: int | str) -> FlowSystem:
224233
return self._systems[idx]
225234
raise KeyError(f"Case '{key}' not found. Available: {self._names}")
226235

227-
def __iter__(self) -> Iterator[tuple[str, FlowSystem]]:
228-
"""Iterate over (name, FlowSystem) pairs."""
229-
yield from zip(self._names, self._systems, strict=True)
236+
def __iter__(self) -> Iterator[str]:
237+
"""Iterate over case names, matching the ``dict`` / ``Mapping`` protocol.
238+
239+
Use :meth:`items` for ``(name, FlowSystem)`` pairs or :meth:`values`
240+
for FlowSystems.
241+
"""
242+
return iter(self._names)
230243

231244
def __contains__(self, key: str) -> bool:
232245
"""Check if a case name exists."""
233246
return key in self._names
234247

248+
def keys(self) -> KeysView[str]:
249+
"""Return a view of case names, like :meth:`dict.keys`."""
250+
return self.flow_systems.keys()
251+
252+
def values(self) -> ValuesView[FlowSystem]:
253+
"""Return a view of FlowSystems, like :meth:`dict.values`."""
254+
return self.flow_systems.values()
255+
256+
def items(self) -> ItemsView[str, FlowSystem]:
257+
"""Return a view of ``(name, FlowSystem)`` pairs, like :meth:`dict.items`."""
258+
return self.flow_systems.items()
259+
235260
@property
236261
def flow_systems(self) -> dict[str, FlowSystem]:
237262
"""Access underlying FlowSystems as a dict mapping name → FlowSystem."""

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ full = [
7777

7878
# Development tools and testing
7979
dev = [
80+
"xarray<2026.3", # TODO: drop once linopy ships xarray 2026.3+ compat fix
8081
"tsam==3.3.0", # Time series aggregation for clustering
8182
"pytest==9.0.3",
8283
"pytest-xdist==3.8.0",

tests/test_comparison.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ def test_comparison_rejects_unoptimized_system(self, base_flow_system, optimized
173173
with pytest.raises(RuntimeError, match='no solution'):
174174
_ = comp.solution
175175

176+
def test_comparison_rejects_non_list(self, optimized_base, optimized_with_chp):
177+
"""Comparison rejects non-list flow_systems input."""
178+
with pytest.raises(TypeError, match='must be a list'):
179+
fx.Comparison((optimized_base, optimized_with_chp))
180+
181+
def test_comparison_rejects_non_flowsystem_items(self, optimized_base):
182+
"""Comparison rejects list items that are not FlowSystem instances."""
183+
with pytest.raises(TypeError, match='FlowSystem instances'):
184+
fx.Comparison([optimized_base, 'not a flow system'])
185+
176186

177187
# ============================================================================
178188
# CONTAINER PROTOCOL TESTS
@@ -212,11 +222,25 @@ def test_getitem_invalid_index_raises(self, optimized_base, optimized_with_chp):
212222
with pytest.raises(IndexError):
213223
_ = comp[99]
214224

215-
def test_iter(self, optimized_base, optimized_with_chp):
216-
"""Iteration yields (name, FlowSystem) pairs."""
225+
def test_iter_yields_names(self, optimized_base, optimized_with_chp):
226+
"""Iteration yields case names, matching the dict/Mapping protocol."""
227+
comp = fx.Comparison([optimized_base, optimized_with_chp])
228+
assert list(comp) == ['Base', 'WithCHP']
229+
230+
def test_keys(self, optimized_base, optimized_with_chp):
231+
"""keys() returns case names."""
232+
comp = fx.Comparison([optimized_base, optimized_with_chp])
233+
assert list(comp.keys()) == ['Base', 'WithCHP']
234+
235+
def test_values(self, optimized_base, optimized_with_chp):
236+
"""values() returns FlowSystems."""
237+
comp = fx.Comparison([optimized_base, optimized_with_chp])
238+
assert list(comp.values()) == [optimized_base, optimized_with_chp]
239+
240+
def test_items(self, optimized_base, optimized_with_chp):
241+
"""items() returns (name, FlowSystem) pairs without warning."""
217242
comp = fx.Comparison([optimized_base, optimized_with_chp])
218-
items = list(comp)
219-
assert items == [('Base', optimized_base), ('WithCHP', optimized_with_chp)]
243+
assert list(comp.items()) == [('Base', optimized_base), ('WithCHP', optimized_with_chp)]
220244

221245
def test_contains(self, optimized_base, optimized_with_chp):
222246
"""'in' operator checks for case name."""

0 commit comments

Comments
 (0)