Skip to content

Commit 086a450

Browse files
FBumannclaude
andcommitted
refactor!: remove legacy variable category fallback from _Expander
Remove backwards-compat pattern matching (_is_state_variable, _is_first_timestep_variable, _build_segment_total_varnames) that was only needed for FlowSystems saved before _variable_categories was added in v6.0. Now requires _variable_categories to always be present. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 51274ee commit 086a450

1 file changed

Lines changed: 11 additions & 64 deletions

File tree

flixopt/transform_accessor.py

Lines changed: 11 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,6 @@ def __init__(self, fs: FlowSystem, clustering: Clustering):
299299

300300
# Pre-compute clustering dimensions
301301
self._timesteps_per_cluster = clustering.timesteps_per_cluster
302-
self._n_segments = clustering.n_segments
303-
self._time_dim_size = self._n_segments if self._n_segments else self._timesteps_per_cluster
304302
self._n_clusters = clustering.n_clusters
305303
self._n_original_clusters = clustering.n_original_clusters
306304

@@ -319,70 +317,17 @@ def __init__(self, fs: FlowSystem, clustering: Clustering):
319317
self._n_original_clusters - 1,
320318
)
321319

322-
# Build variable category sets
323-
self._variable_categories = getattr(fs, '_variable_categories', {})
324-
if self._variable_categories:
325-
self._state_vars = {name for name, cat in self._variable_categories.items() if cat in EXPAND_INTERPOLATE}
326-
self._first_timestep_vars = {
327-
name for name, cat in self._variable_categories.items() if cat in EXPAND_FIRST_TIMESTEP
328-
}
329-
self._segment_total_vars = {name for name, cat in self._variable_categories.items() if cat in EXPAND_DIVIDE}
330-
else:
331-
# Fallback to pattern matching for old FlowSystems without categories
332-
self._state_vars = set()
333-
self._first_timestep_vars = set()
334-
self._segment_total_vars = self._build_segment_total_varnames() if clustering.is_segmented else set()
320+
# Build variable category sets from registered categories
321+
variable_categories = fs._variable_categories
322+
self._state_vars = {name for name, cat in variable_categories.items() if cat in EXPAND_INTERPOLATE}
323+
self._first_timestep_vars = {name for name, cat in variable_categories.items() if cat in EXPAND_FIRST_TIMESTEP}
324+
self._segment_total_vars = {name for name, cat in variable_categories.items() if cat in EXPAND_DIVIDE}
335325

336326
# Pre-compute expansion divisor for segmented systems (segment durations on original time)
337327
self._expansion_divisor = None
338328
if clustering.is_segmented:
339329
self._expansion_divisor = clustering.disaggregate(clustering.segment_durations).ffill(dim='time')
340330

341-
def _is_state_variable(self, var_name: str) -> bool:
342-
"""Check if variable is a state variable requiring interpolation."""
343-
return var_name in self._state_vars or (not self._variable_categories and var_name.endswith('|charge_state'))
344-
345-
def _is_first_timestep_variable(self, var_name: str) -> bool:
346-
"""Check if variable is a first-timestep-only variable (startup/shutdown)."""
347-
return var_name in self._first_timestep_vars or (
348-
not self._variable_categories and (var_name.endswith('|startup') or var_name.endswith('|shutdown'))
349-
)
350-
351-
def _build_segment_total_varnames(self) -> set[str]:
352-
"""Build segment total variable names - BACKWARDS COMPATIBILITY FALLBACK.
353-
354-
This method is only used when variable_categories is empty (old FlowSystems
355-
saved before category registration was implemented). New FlowSystems use
356-
the VariableCategory registry with EXPAND_DIVIDE categories (PER_TIMESTEP, SHARE).
357-
358-
Returns:
359-
Set of variable names that should be divided by expansion divisor.
360-
"""
361-
segment_total_vars: set[str] = set()
362-
effect_names = list(self._fs.effects.keys())
363-
364-
# 1. Per-timestep totals for each effect
365-
for effect in effect_names:
366-
segment_total_vars.add(f'{effect}(temporal)|per_timestep')
367-
368-
# 2. Flow contributions to effects
369-
for flow_label in self._fs.flows:
370-
for effect in effect_names:
371-
segment_total_vars.add(f'{flow_label}->{effect}(temporal)')
372-
373-
# 3. Component contributions to effects
374-
for component_label in self._fs.components:
375-
for effect in effect_names:
376-
segment_total_vars.add(f'{component_label}->{effect}(temporal)')
377-
378-
# 4. Effect-to-effect contributions
379-
for target_effect_name, target_effect in self._fs.effects.items():
380-
if target_effect.share_from_temporal:
381-
for source_effect_name in target_effect.share_from_temporal:
382-
segment_total_vars.add(f'{source_effect_name}(temporal)->{target_effect_name}(temporal)')
383-
384-
return segment_total_vars
385-
386331
def _append_final_state(self, expanded: xr.DataArray, da: xr.DataArray) -> xr.DataArray:
387332
"""Append final state value from original data to expanded data."""
388333
cluster_assignments = self._clustering.cluster_assignments
@@ -418,8 +363,8 @@ def expand_dataarray(self, da: xr.DataArray, var_name: str = '', is_solution: bo
418363

419364
clustering = self._clustering
420365
has_cluster_dim = 'cluster' in da.dims
421-
is_state = self._is_state_variable(var_name) and has_cluster_dim
422-
is_first_timestep = self._is_first_timestep_variable(var_name) and has_cluster_dim
366+
is_state = var_name in self._state_vars and has_cluster_dim
367+
is_first_timestep = var_name in self._first_timestep_vars and has_cluster_dim
423368
is_segment_total = is_solution and var_name in self._segment_total_vars
424369

425370
# Solution variables have n+1 timesteps (extra boundary value).
@@ -615,8 +560,10 @@ def expand_flow_system(self) -> FlowSystem:
615560
n_combinations = (len(self._fs.periods) if has_periods else 1) * (
616561
len(self._fs.scenarios) if has_scenarios else 1
617562
)
618-
n_reduced_timesteps = self._n_clusters * self._time_dim_size
619-
segmented_info = f' ({self._n_segments} segments)' if self._n_segments else ''
563+
n_segments = self._clustering.n_segments
564+
time_dim_size = n_segments if n_segments else self._timesteps_per_cluster
565+
n_reduced_timesteps = self._n_clusters * time_dim_size
566+
segmented_info = f' ({n_segments} segments)' if n_segments else ''
620567
logger.info(
621568
f'Expanded FlowSystem from {n_reduced_timesteps} to {self._n_original_timesteps} timesteps '
622569
f'({self._n_clusters} clusters{segmented_info}'

0 commit comments

Comments
 (0)