Skip to content

Commit 4e1de05

Browse files
EliEli
authored andcommitted
Fixed 1) flow not working with selected stations and 2) failure if the wrong number of time basis entries are given.
1 parent 751c182 commit 4e1de05

2 files changed

Lines changed: 36 additions & 27 deletions

File tree

schimpy/batch_metrics.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,8 @@ def plot(self):
400400
if isinstance(outputs_dir, str):
401401
outputs_dir = outputs_dir.split()
402402
if isinstance(params["time_basis"], list):
403-
assert len(outputs_dir) == len(params["time_basis"])
404-
403+
if len(outputs_dir) > len(params["time_basis"]):
404+
raise ValueError("time basis provided as list but doesn't have as many entries as output_dir")
405405
time_basis = [
406406
process_time_str(date_str) for date_str in params["time_basis"]
407407
]
@@ -463,28 +463,39 @@ def plot(self):
463463
"station, RMSE, lag, bias, NSE, Willmott_skill, Correlation\n"
464464
)
465465

466-
if selected_stations is not None:
467-
idx = pd.IndexSlice
468-
sim_outputs[0] = sim_outputs[0].loc[:, idx[selected_stations, :]]
469-
self.logger.info("==================================================")
470-
self.logger.info(
471-
"'selected_stations' enabled. Only following will be processed."
472-
)
473-
for station_id in selected_stations:
474-
self.logger.info(" {}".format(station_id))
475-
476-
if excluded_stations is not None:
477-
sim_outputs[0] = sim_outputs[0].loc[
478-
:,
479-
idx[
480-
~sim_outputs[0].columns.get_level_values(0).isin(excluded_stations),
481-
:,
482-
],
483-
]
484-
self.logger.info("==================================================")
485-
self.logger.info("'excluded_stations' enabled. Following will be skipped.")
486-
for station_id in excluded_stations:
487-
self.logger.info(" {}".format(station_id))
466+
if selected_stations:
467+
# Subset quietly to stations actually present for this variable.
468+
# Preserve the user’s order where possible.
469+
order = [s for s in selected_stations if isinstance(s, str)]
470+
cols = sim_outputs[0].columns
471+
if isinstance(cols, pd.MultiIndex):
472+
level0 = cols.get_level_values(0)
473+
keep_mask = level0.isin(order)
474+
df = sim_outputs[0].loc[:, keep_mask]
475+
# Reorder blocks by the user-provided order (within each station keep native order)
476+
if not df.empty:
477+
new_cols = []
478+
for s in order:
479+
block = [c for c in df.columns if c[0] == s]
480+
if block:
481+
new_cols.extend(block)
482+
if new_cols:
483+
df = df.loc[:, new_cols]
484+
sim_outputs[0] = df
485+
else:
486+
present = [s for s in order if s in cols]
487+
# If none of the requested stations are present, return an empty selection
488+
sim_outputs[0] = sim_outputs[0].loc[:, present] if present else sim_outputs[0].iloc[:, 0:0]
489+
490+
491+
if excluded_stations:
492+
cols = sim_outputs[0].columns
493+
if isinstance(cols, pd.MultiIndex):
494+
mask = ~cols.get_level_values(0).isin(excluded_stations)
495+
sim_outputs[0] = sim_outputs[0].loc[:, mask]
496+
else:
497+
mask = ~cols.isin(excluded_stations)
498+
sim_outputs[0] = sim_outputs[0].loc[:, mask]
488499

489500
# Iterate through the stations in the first simulation outputs
490501
for stn in sim_outputs[0].columns:

schimpy/metricsplot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def safe_window(ts, window):
5151
unit = ts.unit
5252

5353
if ts.last_valid_index() is None or ts.first_valid_index() is None:
54-
print(" Valid index None")
54+
print(" No valid time indexes found")
5555
return None
5656
if window[0] > window[1]:
5757
raise ValueError("The left value of the window is larger than the right.")
@@ -299,13 +299,11 @@ def gen_metrics_grid():
299299

300300
def plot_tss(ax, tss, labels, window=None, cell_method="inst"):
301301
"""Simply plot lines from a list of time series"""
302-
print(f"Checking time window validity for {cell_method} plot")
303302
if window is not None:
304303
tss_plot = []
305304
if len(tss) > 0:
306305
for it in range(len(tss)):
307306
ts = tss[it]
308-
print(f" {labels[it]}...")
309307
tss_plot.append(safe_window(ts, window))
310308
else:
311309
tss_plot = tss

0 commit comments

Comments
 (0)