Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions flixopt/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import threading
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Literal, overload

import xarray as xr
Expand All @@ -19,6 +21,7 @@
)

if TYPE_CHECKING:
import pathlib
from collections.abc import ItemsView, Iterator, KeysView, ValuesView

from .flow_system import FlowSystem
Expand All @@ -28,6 +31,12 @@
# All unique slot names across SLOT_ORDERS (imported from xarray_plotly);
# used to recognize which dimensions can serve as the per-case axis.
_CASE_SLOTS = frozenset(slot for slots in SLOT_ORDERS.values() for slot in slots)

# The netCDF4 C library is not thread-safe: concurrent `xr.load_dataset` calls
# with engine='netcdf4' can segfault because the HDF5 error stack and library
# state are global. We serialize only the file-read step (the CPU-heavy
# deserialization that follows runs in parallel).
_NETCDF_READ_LOCK = threading.Lock()


def _extract_nonindex_coords(datasets: list[xr.Dataset]) -> tuple[list[xr.Dataset], dict[str, tuple[str, dict]]]:
"""Extract and merge non-index coords, returning cleaned datasets and merged mappings.
Expand Down Expand Up @@ -186,6 +195,66 @@ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = Non
self._statistics: ComparisonStatistics | None = None
self._inputs: xr.Dataset | None = None

@classmethod
def from_netcdf(
    cls,
    paths: list[str | pathlib.Path] | dict[str | pathlib.Path, str],
    max_workers: int | None = None,
) -> Comparison:
    """Load multiple FlowSystems from NetCDF files and combine them into a Comparison.

    The file read itself is serialized (the netCDF4 C library is not
    thread-safe — concurrent reads can segfault), but the CPU-heavy
    deserialization — JSON attrs and rebuilding the FlowSystem from the
    dataset — runs in parallel across a thread pool. This typically still
    gives a solid speedup because deserialization dominates the total load
    time for non-trivial systems.

    Args:
        paths: Either a list of file paths (case names are derived from the
            filename stems), or a dict mapping file paths to explicit case
            names. Explicit names are applied both as the Comparison case
            labels and as each loaded FlowSystem's ``name``.
        max_workers: Maximum number of threads used to deserialize loaded
            datasets. ``None`` uses the default of
            :class:`concurrent.futures.ThreadPoolExecutor`. Set to ``1`` to
            run sequentially.

    Returns:
        A new :class:`Comparison` containing the loaded FlowSystems.

    Examples:
        ```python
        # From a list (names come from filenames)
        comp = fx.Comparison.from_netcdf(['results/base.nc', 'results/modified.nc'])

        # With explicit names
        comp = fx.Comparison.from_netcdf({'results/base.nc': 'baseline', 'results/modified.nc': 'variant'})
        ```
    """
    import pathlib as _pl

    from .flow_system import FlowSystem
    from .io import load_dataset_from_netcdf

    if isinstance(paths, dict):
        path_list = list(paths.keys())
        names: list[str] | None = list(paths.values())
    else:
        path_list = list(paths)
        names = None

    # Use the explicit case name when one was supplied so FlowSystem.name
    # matches the Comparison label; otherwise fall back to the filename stem.
    fs_names = names if names is not None else [_pl.Path(p).stem for p in path_list]

    def _load_one(item: tuple[str | _pl.Path, str]) -> FlowSystem:
        path, fs_name = item
        # Only the raw file read is serialized (netCDF4 C library is not
        # thread-safe); the expensive FlowSystem.from_dataset step below
        # runs concurrently across threads.
        with _NETCDF_READ_LOCK:
            ds = load_dataset_from_netcdf(path)
        fs = FlowSystem.from_dataset(ds)
        fs.name = fs_name
        return fs

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # path_list and fs_names are equal-length by construction.
        flow_systems = list(executor.map(_load_one, zip(path_list, fs_names)))

    return cls(flow_systems, names=names)

def __repr__(self) -> str:
"""Return a detailed string representation."""
lines = ['Comparison', '=' * 10]
Expand Down Expand Up @@ -417,6 +486,43 @@ def inputs(self) -> xr.Dataset:
self._inputs = _apply_merged_coords(result, merged_coords)
return self._inputs

def expand(self, max_workers: int | None = None) -> Comparison:
"""Expand clustered FlowSystems back to full timesteps in parallel.

Calls :meth:`FlowSystem.transform.expand` on every contained FlowSystem
that has a ``clustering`` attribute. FlowSystems without clustering are
passed through unchanged, so mixed comparisons are safe.

Expansion is CPU-bound but vectorized through xarray/numpy, which
release the GIL for most operations — a thread pool is typically
enough to get a speedup.

Args:
max_workers: Maximum number of threads used to expand systems.
``None`` uses the default of
:class:`concurrent.futures.ThreadPoolExecutor`. Set to ``1`` to
expand sequentially.

Returns:
A new :class:`Comparison` with expanded FlowSystems, preserving
the original case names.

Examples:
```python
comp_reduced = fx.Comparison([fs_clustered_a, fs_clustered_b])
comp_full = comp_reduced.expand()
comp_full.stats.plot.balance('Heat') # Full-resolution plots
```
"""

def _expand_one(fs: FlowSystem) -> FlowSystem:
return fs.transform.expand() if fs.clustering is not None else fs

with ThreadPoolExecutor(max_workers=max_workers) as executor:
expanded = list(executor.map(_expand_one, self._systems))

return type(self)(expanded, names=list(self._names))


class ComparisonStatistics:
"""Combined statistics accessor for comparing FlowSystems.
Expand Down
128 changes: 128 additions & 0 deletions tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,131 @@ def test_diff_invalid_reference_raises(self, optimized_base, optimized_with_chp)

with pytest.raises(ValueError, match='not found'):
comp.diff(reference='NonexistentCase')


# ============================================================================
# PARALLEL LOAD / EXPAND TESTS
# ============================================================================


class TestComparisonFromNetcdf:
    """Tests for Comparison.from_netcdf classmethod."""

    def test_from_netcdf_list_of_paths(self, tmp_path, optimized_base, optimized_with_chp):
        """List of paths loads systems with names derived from filenames."""
        base_path = tmp_path / 'base.nc'
        chp_path = tmp_path / 'with_chp.nc'
        optimized_base.to_netcdf(base_path)
        optimized_with_chp.to_netcdf(chp_path)

        comp = fx.Comparison.from_netcdf([base_path, chp_path])

        assert len(comp) == 2
        assert comp.names == ['base', 'with_chp']
        assert comp.is_optimized

    def test_from_netcdf_dict_paths_to_names(self, tmp_path, optimized_base, optimized_with_chp):
        """Dict input uses explicit names instead of filenames."""
        base_path = tmp_path / 'base.nc'
        chp_path = tmp_path / 'chp.nc'
        optimized_base.to_netcdf(base_path)
        optimized_with_chp.to_netcdf(chp_path)

        comp = fx.Comparison.from_netcdf({base_path: 'baseline', chp_path: 'variant'})

        assert comp.names == ['baseline', 'variant']

    def test_from_netcdf_serial_matches_parallel(self, tmp_path, optimized_base, optimized_with_chp):
        """max_workers=1 produces the same result as the default parallel load."""
        base_path = tmp_path / 'base.nc'
        chp_path = tmp_path / 'chp.nc'
        optimized_base.to_netcdf(base_path)
        optimized_with_chp.to_netcdf(chp_path)

        parallel_comp = fx.Comparison.from_netcdf([base_path, chp_path])
        serial_comp = fx.Comparison.from_netcdf([base_path, chp_path], max_workers=1)

        assert parallel_comp.names == serial_comp.names
        xr.testing.assert_identical(parallel_comp.solution, serial_comp.solution)


class TestComparisonExpand:
    """Tests for Comparison.expand method."""

    @pytest.fixture(scope='class')
    def clustered_systems(self):
        """Two clustered, optimized FlowSystems (class-scoped so we solve only once)."""
        pytest.importorskip('tsam')
        n_hours = 168  # one week of hourly steps
        timesteps = pd.date_range('2024-01-01', periods=n_hours, freq='h', name='time')
        demand = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2

        def _make_system(system_name: str, price: float) -> fx.FlowSystem:
            # Minimal grid-to-demand system: one source, one sink, one bus.
            fs = fx.FlowSystem(timesteps, name=system_name)
            fs.add_elements(
                fx.Effect('costs', '€', 'Costs', is_standard=True, is_objective=True),
                fx.Bus('Electricity'),
                fx.Source(
                    'Grid',
                    outputs=[fx.Flow('P_el', bus='Electricity', size=100, effects_per_flow_hour={'costs': price})],
                ),
                fx.Sink(
                    'Demand',
                    inputs=[
                        fx.Flow(
                            'P_demand',
                            bus='Electricity',
                            size=100,
                            fixed_relative_profile=fx.TimeSeriesData(demand / 100),
                        )
                    ],
                ),
            )
            return fs

        solver = fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=60, log_to_console=False)

        def _cluster_and_solve(system_name: str, price: float) -> fx.FlowSystem:
            clustered = _make_system(system_name, price).transform.cluster(n_clusters=2, cluster_duration='1D')
            clustered.optimize(solver)
            return clustered

        return [_cluster_and_solve(system_name, price) for system_name, price in (('A', 0.3), ('B', 0.25))]

    def test_expand_returns_new_comparison(self, clustered_systems):
        """expand() returns a new Comparison instance, preserving names."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])
        expanded = comp.expand()

        assert isinstance(expanded, fx.Comparison)
        assert expanded is not comp
        assert expanded.names == ['a', 'b']

    def test_expand_restores_full_timesteps(self, clustered_systems):
        """Each expanded FlowSystem has the full (original) timestep count."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])
        expanded = comp.expand()

        for fs in expanded.values():
            # Original was 168 hours; clustering exposes 2D shape but expand
            # restores a single time axis with 168 steps (+1 boundary).
            assert 'time' in fs.solution.dims
            assert fs.solution.sizes['time'] == 168 + 1

    def test_expand_serial_matches_parallel(self, clustered_systems):
        """max_workers=1 gives identical results to the default parallel path."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])

        parallel_result = comp.expand()
        serial_result = comp.expand(max_workers=1)

        xr.testing.assert_identical(parallel_result.solution, serial_result.solution)

    def test_expand_passes_through_non_clustered(self, clustered_systems, optimized_base):
        """Systems without clustering are passed through unchanged (mixed comparison)."""
        comp = fx.Comparison([clustered_systems[0], optimized_base], names=['clustered', 'plain'])
        expanded = comp.expand()

        # The non-clustered system is the same object, untouched.
        assert expanded['plain'] is optimized_base
        # The clustered system was actually expanded (new object).
        assert expanded['clustered'] is not clustered_systems[0]
Loading