Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions flixopt/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import ItemsView, Iterator, KeysView, ValuesView

from .flow_system import FlowSystem

Expand Down Expand Up @@ -158,10 +158,19 @@ class Comparison:
"""

def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = None) -> None:
from .flow_system import FlowSystem

if not isinstance(flow_systems, list):
raise TypeError(f'flow_systems must be a list, got {type(flow_systems).__name__}')

non_fs = [(i, type(fs).__name__) for i, fs in enumerate(flow_systems) if not isinstance(fs, FlowSystem)]
if non_fs:
raise TypeError(f'flow_systems must contain only FlowSystem instances; got {non_fs} (index, type)')
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if len(flow_systems) < 2:
raise ValueError('Comparison requires at least 2 FlowSystems')

self._systems = flow_systems
self._systems: list[FlowSystem] = flow_systems
self._names = names or [fs.name or f'System {i}' for i, fs in enumerate(flow_systems)]

if len(self._names) != len(self._systems):
Expand Down Expand Up @@ -224,14 +233,30 @@ def __getitem__(self, key: int | str) -> FlowSystem:
return self._systems[idx]
raise KeyError(f"Case '{key}' not found. Available: {self._names}")

def __iter__(self) -> Iterator[tuple[str, FlowSystem]]:
"""Iterate over (name, FlowSystem) pairs."""
yield from zip(self._names, self._systems, strict=True)
def __iter__(self) -> Iterator[str]:
"""Iterate over case names, matching the ``dict`` / ``Mapping`` protocol.

Use :meth:`items` for ``(name, FlowSystem)`` pairs or :meth:`values`
for FlowSystems.
"""
return iter(self._names)

def __contains__(self, key: str) -> bool:
"""Check if a case name exists."""
return key in self._names

def keys(self) -> KeysView[str]:
"""Return a view of case names, like :meth:`dict.keys`."""
return self.flow_systems.keys()

def values(self) -> ValuesView[FlowSystem]:
"""Return a view of FlowSystems, like :meth:`dict.values`."""
return self.flow_systems.values()

def items(self) -> ItemsView[str, FlowSystem]:
"""Return a view of ``(name, FlowSystem)`` pairs, like :meth:`dict.items`."""
return self.flow_systems.items()

@property
def flow_systems(self) -> dict[str, FlowSystem]:
"""Access underlying FlowSystems as a dict mapping name → FlowSystem."""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ full = [

# Development tools and testing
dev = [
"xarray<2026.3", # TODO: drop once linopy ships xarray 2026.3+ compat fix
"tsam==3.3.0", # Time series aggregation for clustering
"pytest==9.0.3",
"pytest-xdist==3.8.0",
Expand Down
32 changes: 28 additions & 4 deletions tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,16 @@ def test_comparison_rejects_unoptimized_system(self, base_flow_system, optimized
with pytest.raises(RuntimeError, match='no solution'):
_ = comp.solution

def test_comparison_rejects_non_list(self, optimized_base, optimized_with_chp):
"""Comparison rejects non-list flow_systems input."""
with pytest.raises(TypeError, match='must be a list'):
fx.Comparison((optimized_base, optimized_with_chp))

def test_comparison_rejects_non_flowsystem_items(self, optimized_base):
"""Comparison rejects list items that are not FlowSystem instances."""
with pytest.raises(TypeError, match='FlowSystem instances'):
fx.Comparison([optimized_base, 'not a flow system'])


# ============================================================================
# CONTAINER PROTOCOL TESTS
Expand Down Expand Up @@ -212,11 +222,25 @@ def test_getitem_invalid_index_raises(self, optimized_base, optimized_with_chp):
with pytest.raises(IndexError):
_ = comp[99]

def test_iter(self, optimized_base, optimized_with_chp):
"""Iteration yields (name, FlowSystem) pairs."""
def test_iter_yields_names(self, optimized_base, optimized_with_chp):
"""Iteration yields case names, matching the dict/Mapping protocol."""
comp = fx.Comparison([optimized_base, optimized_with_chp])
assert list(comp) == ['Base', 'WithCHP']

def test_keys(self, optimized_base, optimized_with_chp):
"""keys() returns case names."""
comp = fx.Comparison([optimized_base, optimized_with_chp])
assert list(comp.keys()) == ['Base', 'WithCHP']

def test_values(self, optimized_base, optimized_with_chp):
"""values() returns FlowSystems."""
comp = fx.Comparison([optimized_base, optimized_with_chp])
assert list(comp.values()) == [optimized_base, optimized_with_chp]

def test_items(self, optimized_base, optimized_with_chp):
"""items() returns (name, FlowSystem) pairs without warning."""
comp = fx.Comparison([optimized_base, optimized_with_chp])
items = list(comp)
assert items == [('Base', optimized_base), ('WithCHP', optimized_with_chp)]
assert list(comp.items()) == [('Base', optimized_base), ('WithCHP', optimized_with_chp)]

def test_contains(self, optimized_base, optimized_with_chp):
"""'in' operator checks for case name."""
Expand Down
Loading