@@ -619,57 +619,38 @@ def _get_risk_periods(
619619 )
620620 ]
621621
622- def _make_period_bins (
623- self , freq : str | None = None
624- ) -> tuple [pd .DatetimeIndex , list [str ]]:
625- """Build bin edges and labels from snapshot dates or a given frequency.
622+ def _assign_snapshot_period_ids (self , dates : pd .Series ) -> pd .Series :
623+ """Assign each date to the index of the snapshot pair that contains it.
626624
627625 Parameters
628626 ----------
629- freq : str, optional
630- Pandas frequency string (e.g. ``"2Y"``, ``"3M"``). If None, bins
631- correspond to the intervals between consecutive snapshots.
627+ dates : pd.Series
628+ Series of Period dtype with frequency ``self.time_resolution``.
632629
633630 Returns
634631 -------
635- bin_edges : pd.DatetimeIndex
636- labels : list of str
632+ pd.Series
633+ Integer series of same index as ``dates``, with values in
634+ ``range(len(self._snapshots) - 1)``. Dates outside all snapshot
635+ intervals are assigned ``NaN``.
637636 """
638- snapshot_dates = sorted (snap .date for snap in self ._snapshots )
639- start , end = snapshot_dates [0 ], snapshot_dates [- 1 ]
640-
641- if freq is None :
642- edges = pd .DatetimeIndex (snapshot_dates )
643- else :
644- edges = pd .date_range (start = start , end = end , freq = freq , inclusive = "left" )
645- if edges [- 1 ] < end :
646- edges = pd .date_range (
647- start = start , periods = len (edges ) + 1 , freq = freq , inclusive = "left"
648- )
649-
650- if edges [0 ] != start :
651- LOGGER .warning (
652- "The first bin edge %s does not match the start date %s. "
653- "This is likely because '%s' is interpreted as an end-anchored frequency. "
654- "Consider using an explicit start-anchored frequency instead "
655- "(e.g. 'YS' instead of 'Y', 'MS' instead of 'M')." ,
656- edges [0 ].date (),
657- start .date (),
658- freq ,
659- )
660-
661- labels = [
662- f"{ edges [i ].date ()} to { edges [i + 1 ].date ()} " for i in range (len (edges ) - 1 )
637+ snapshot_dates = sorted (snap .date for snap in self ._snapshots ) + [
638+ self ._snapshots [- 1 ].date + pd .DateOffset ()
663639 ]
664- return edges , labels
640+ bins = pd .DatetimeIndex (snapshot_dates )
641+ ts = dates .dt .to_timestamp (how = "start" )
642+ return pd .cut (
643+ ts ,
644+ bins = bins ,
645+ labels = False ,
646+ include_lowest = True ,
647+ right = True ,
648+ )
665649
666- @classmethod
667650 def _date_to_period_agg (
668- cls ,
651+ self ,
669652 metric_df : pd .DataFrame ,
670653 grouper : list [str ],
671- bin_edges : pd .DatetimeIndex ,
672- labels : list [str ],
673654 colname : str | list [str ] = RISK_COL_NAME ,
674655 aggfunc : str | Callable = "mean" ,
675656 ) -> pd .DataFrame :
@@ -679,29 +660,32 @@ def _date_to_period_agg(
679660 ----------
680661 metric_df : pd.DataFrame
681662 grouper : list of str
682- bin_edges : pd.DatetimeIndex
683- Edges of the period bins, as returned by ``_make_period_bins``.
684- labels : list of str
685- Labels for each bin interval.
686663 colname : str or list of str, optional
687664 aggfunc : str or callable, optional
688665 Aggregation function passed to ``groupby.agg``. Default is ``"mean"``.
666+ freq : str, optional
667+ If provided, resample the date column at this frequency.
668+ time_resolution : str, optional
669+ The time resolution of the date column, used to format labels when
670+ ``freq`` is provided.
671+ snapshot_mapper : dict, optional
672+ Maps each ``pd.Period`` to a snapshot interval label. Used when
673+ ``freq`` is None.
689674 """
690675 if isinstance (colname , str ):
691676 colname = [colname ]
692677
693678 df = metric_df .copy ()
694- df [PERIOD_COL_NAME ] = pd .cut (
695- df [DATE_COL_NAME ].dt .to_timestamp (how = "start" ),
696- bins = bin_edges ,
697- labels = labels ,
698- include_lowest = True ,
699- right = False ,
700- )
701679
702680 if GROUP_COL_NAME in df .columns and GROUP_COL_NAME not in grouper :
703681 grouper = [GROUP_COL_NAME ] + grouper
704682
683+ df [PERIOD_COL_NAME ] = self ._assign_snapshot_period_ids (df [DATE_COL_NAME ])
684+ df [PERIOD_COL_NAME ] = (
685+ df .groupby (PERIOD_COL_NAME )[DATE_COL_NAME ].transform ("first" ).astype (str )
686+ + " to "
687+ + df .groupby (PERIOD_COL_NAME )[DATE_COL_NAME ].transform ("last" ).astype (str )
688+ )
705689 return (
706690 df .groupby ([PERIOD_COL_NAME ] + grouper , dropna = False , observed = True )[
707691 colname
@@ -717,7 +701,6 @@ def per_period_risk_metrics(
717701 RETURN_PERIOD_METRIC_NAME ,
718702 AAI_PER_GROUP_METRIC_NAME ,
719703 ),
720- freq : str | None = None ,
721704 colname : str | list [str ] = RISK_COL_NAME ,
722705 aggfunc : str | Callable = "mean" ,
723706 ) -> pd .DataFrame :
@@ -739,9 +722,9 @@ def per_period_risk_metrics(
739722 Notes
740723 -----
741724
742- Periods are left inclusing, right excluding, meaning for instance ,
743- "2018-01-01 to 2024-01-01" is the average risk from 2018-01-01 included
744- to 2023-12-31 included.
725+ If freq is given, periods are left-inclusive and right-exclusive,
726+ meaning for instance, "2018-01-01 to 2024-01-01" is the
727+ average risk from 2018-01-01 included to 2023-12-31 included.
745728
746729 If the last date is at odds with the frequency given, the aggfunc is
747730 still applied over the "whole" bin including the date, for instance if
@@ -751,12 +734,10 @@ def per_period_risk_metrics(
751734
752735 """
753736 metric_df = self .per_date_risk_metrics (metrics = metrics )
754- bin_edges , labels = self . _make_period_bins ( freq = freq )
737+
755738 return self ._date_to_period_agg (
756739 metric_df ,
757740 grouper = self ._grouper + [UNIT_COL_NAME ],
758- bin_edges = bin_edges ,
759- labels = labels ,
760741 colname = colname ,
761742 aggfunc = aggfunc ,
762743 )
0 commit comments