Skip to content

Commit bb0c406

Browse files
committed
Fix tests and period aggregation
1 parent e609c32 commit bb0c406

2 files changed

Lines changed: 129 additions & 238 deletions

File tree

climada/trajectories/interpolated_trajectory.py

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -619,57 +619,38 @@ def _get_risk_periods(
619619
)
620620
]
621621

622-
def _make_period_bins(
623-
self, freq: str | None = None
624-
) -> tuple[pd.DatetimeIndex, list[str]]:
625-
"""Build bin edges and labels from snapshot dates or a given frequency.
622+
def _assign_snapshot_period_ids(self, dates: pd.Series) -> pd.Series:
623+
"""Assign each date to the index of the snapshot pair that contains it.
626624
627625
Parameters
628626
----------
629-
freq : str, optional
630-
Pandas frequency string (e.g. ``"2Y"``, ``"3M"``). If None, bins
631-
correspond to the intervals between consecutive snapshots.
627+
dates : pd.Series
628+
Series of Period dtype with frequency ``self.time_resolution``.
632629
633630
Returns
634631
-------
635-
bin_edges : pd.DatetimeIndex
636-
labels : list of str
632+
pd.Series
633+
Integer series of same index as ``dates``, with values in
634+
``range(len(self._snapshots) - 1)``. Dates outside all snapshot
635+
intervals are assigned ``NaN``.
637636
"""
638-
snapshot_dates = sorted(snap.date for snap in self._snapshots)
639-
start, end = snapshot_dates[0], snapshot_dates[-1]
640-
641-
if freq is None:
642-
edges = pd.DatetimeIndex(snapshot_dates)
643-
else:
644-
edges = pd.date_range(start=start, end=end, freq=freq, inclusive="left")
645-
if edges[-1] < end:
646-
edges = pd.date_range(
647-
start=start, periods=len(edges) + 1, freq=freq, inclusive="left"
648-
)
649-
650-
if edges[0] != start:
651-
LOGGER.warning(
652-
"The first bin edge %s does not match the start date %s. "
653-
"This is likely because '%s' is interpreted as an end-anchored frequency. "
654-
"Consider using an explicit start-anchored frequency instead "
655-
"(e.g. 'YS' instead of 'Y', 'MS' instead of 'M').",
656-
edges[0].date(),
657-
start.date(),
658-
freq,
659-
)
660-
661-
labels = [
662-
f"{edges[i].date()} to {edges[i + 1].date()}" for i in range(len(edges) - 1)
637+
snapshot_dates = sorted(snap.date for snap in self._snapshots) + [
638+
self._snapshots[-1].date + pd.DateOffset()
663639
]
664-
return edges, labels
640+
bins = pd.DatetimeIndex(snapshot_dates)
641+
ts = dates.dt.to_timestamp(how="start")
642+
return pd.cut(
643+
ts,
644+
bins=bins,
645+
labels=False,
646+
include_lowest=True,
647+
right=True,
648+
)
665649

666-
@classmethod
667650
def _date_to_period_agg(
668-
cls,
651+
self,
669652
metric_df: pd.DataFrame,
670653
grouper: list[str],
671-
bin_edges: pd.DatetimeIndex,
672-
labels: list[str],
673654
colname: str | list[str] = RISK_COL_NAME,
674655
aggfunc: str | Callable = "mean",
675656
) -> pd.DataFrame:
@@ -679,29 +660,32 @@ def _date_to_period_agg(
679660
----------
680661
metric_df : pd.DataFrame
681662
grouper : list of str
682-
bin_edges : pd.DatetimeIndex
683-
Edges of the period bins, as returned by ``_make_period_bins``.
684-
labels : list of str
685-
Labels for each bin interval.
686663
colname : str or list of str, optional
687664
aggfunc : str or callable, optional
688665
Aggregation function passed to ``groupby.agg``. Default is ``"mean"``.
666+
freq : str, optional
667+
If provided, resample the date column at this frequency.
668+
time_resolution : str, optional
669+
The time resolution of the date column, used to format labels when
670+
``freq`` is provided.
671+
snapshot_mapper : dict, optional
672+
Maps each ``pd.Period`` to a snapshot interval label. Used when
673+
``freq`` is None.
689674
"""
690675
if isinstance(colname, str):
691676
colname = [colname]
692677

693678
df = metric_df.copy()
694-
df[PERIOD_COL_NAME] = pd.cut(
695-
df[DATE_COL_NAME].dt.to_timestamp(how="start"),
696-
bins=bin_edges,
697-
labels=labels,
698-
include_lowest=True,
699-
right=False,
700-
)
701679

702680
if GROUP_COL_NAME in df.columns and GROUP_COL_NAME not in grouper:
703681
grouper = [GROUP_COL_NAME] + grouper
704682

683+
df[PERIOD_COL_NAME] = self._assign_snapshot_period_ids(df[DATE_COL_NAME])
684+
df[PERIOD_COL_NAME] = (
685+
df.groupby(PERIOD_COL_NAME)[DATE_COL_NAME].transform("first").astype(str)
686+
+ " to "
687+
+ df.groupby(PERIOD_COL_NAME)[DATE_COL_NAME].transform("last").astype(str)
688+
)
705689
return (
706690
df.groupby([PERIOD_COL_NAME] + grouper, dropna=False, observed=True)[
707691
colname
@@ -717,7 +701,6 @@ def per_period_risk_metrics(
717701
RETURN_PERIOD_METRIC_NAME,
718702
AAI_PER_GROUP_METRIC_NAME,
719703
),
720-
freq: str | None = None,
721704
colname: str | list[str] = RISK_COL_NAME,
722705
aggfunc: str | Callable = "mean",
723706
) -> pd.DataFrame:
@@ -739,9 +722,9 @@ def per_period_risk_metrics(
739722
Notes
740723
-----
741724
742-
Periods are left inclusing, right excluding, meaning for instance,
743-
"2018-01-01 to 2024-01-01" is the average risk from 2018-01-01 included
744-
to 2023-12-31 included.
725+
If freq is given, periods are left inclusing, right excluding,
726+
meaning for instance, "2018-01-01 to 2024-01-01" is the
727+
average risk from 2018-01-01 included to 2023-12-31 included.
745728
746729
If the last date is at odd with the frequency given, the aggfunc is
747730
still applied over the "whole" bin inclunding the date, for instance if
@@ -751,12 +734,10 @@ def per_period_risk_metrics(
751734
752735
"""
753736
metric_df = self.per_date_risk_metrics(metrics=metrics)
754-
bin_edges, labels = self._make_period_bins(freq=freq)
737+
755738
return self._date_to_period_agg(
756739
metric_df,
757740
grouper=self._grouper + [UNIT_COL_NAME],
758-
bin_edges=bin_edges,
759-
labels=labels,
760741
colname=colname,
761742
aggfunc=aggfunc,
762743
)

0 commit comments

Comments
 (0)