@@ -105,7 +105,6 @@ def _reset_metrics(self):
105105 self ._risk_components_metrics = None
106106 self ._aai_per_group_metrics = None
107107 self ._all_risk_metrics = None
108- self ._metrics_up_to_date = False
109108
110109 @property
111110 def default_rp (self ):
@@ -249,48 +248,67 @@ def _generic_metrics(
249248
250249 return getattr (self , attr_name )
251250
252- def aai_metrics (self , npv = True ):
253- return self ._generic_metrics (
254- npv = npv , metric_name = "aai" , metric_meth = "calc_aai_metric"
251+ def _compute_metrics (
252+ self , metric_name , metric_meth , total = False , npv = True , * args , ** kwargs
253+ ):
254+ """Helper method to compute metrics and optionally return total risk."""
255+ df = self ._generic_metrics (
256+ npv = npv , metric_name = metric_name , metric_meth = metric_meth , * args , ** kwargs
255257 )
258+ if total :
259+ return self ._per_period_risk (df )
260+ return df
256261
257- def return_periods_metrics (self , return_periods = None , npv = True ):
262+ def aai_metrics (self , total = False , npv = True , * args , ** kwargs ):
263+ return self ._compute_metrics (
264+ total = total ,
265+ npv = npv ,
266+ metric_name = "aai" ,
267+ metric_meth = "calc_aai_metric" ,
268+ * args ,
269+ ** kwargs ,
270+ )
271+
272+ def return_periods_metrics (
273+ self , total = False , return_periods = None , npv = True , * args , ** kwargs
274+ ):
258275 return_periods = return_periods if return_periods else self .default_rp
259- return self ._generic_metrics (
276+ return self ._compute_metrics (
260277 npv = npv ,
261278 metric_name = "return_periods" ,
262279 metric_meth = "calc_return_periods_metric" ,
263280 return_periods = return_periods ,
281+ * args ,
282+ ** kwargs ,
264283 )
265284
266- def aai_per_group_metrics (self , npv = True ):
267- return self ._generic_metrics (
285+ def aai_per_group_metrics (self , npv = True , * args , ** kwargs ):
286+ return self ._compute_metrics (
268287 npv = npv ,
269288 metric_name = "aai_per_group" ,
270289 metric_meth = "calc_aai_per_group_metric" ,
290+ * args ,
291+ ** kwargs ,
271292 )
272293
273- def risk_components_metrics (self , npv = True ):
274- return self ._generic_metrics (
294+ def risk_components_metrics (self , npv = True , * args , ** kwargs ):
295+ return self ._compute_metrics (
275296 npv = npv ,
276297 metric_name = "risk_components" ,
277298 metric_meth = "calc_risk_components_metric" ,
299+ * args ,
300+ ** kwargs ,
278301 )
279302
280303 def all_risk_metrics (
281- self , return_periods = [50 , 100 , 500 ], npv = True
304+ self , return_periods = [50 , 100 , 500 ], npv = True , * args , ** kwargs
282305 ) -> pd .DataFrame | pd .Series :
283- if not self ._metrics_up_to_date or self ._all_risk_metrics is None :
284- aai = self .aai_metrics (npv )
285- rp = self .return_periods_metrics (return_periods , npv )
286- aai_per_group = self .aai_per_group_metrics (npv )
287- risk_components = self .risk_components_metrics (npv )
288- self ._all_risk_metrics = pd .concat (
289- [aai , rp , aai_per_group , risk_components ]
290- )
291- self ._metrics_up_to_date = True
292306
293- return self ._all_risk_metrics
307+ aai = self .aai_metrics (npv , * args , ** kwargs )
308+ rp = self .return_periods_metrics (return_periods , npv , * args , ** kwargs )
309+ aai_per_group = self .aai_per_group_metrics (npv , * args , ** kwargs )
310+ risk_components = self .risk_components_metrics (npv , * args , ** kwargs )
311+ return pd .concat ([aai , rp , aai_per_group , risk_components ])
294312
295313 @staticmethod
296314 def _get_risk_periods (
@@ -321,7 +339,7 @@ def identify_continuous_periods(group, time_unit):
321339 return group
322340
323341 grouper = cls ._grouper
324- if "group" in df .columns :
342+ if "group" in df .columns and "group" not in grouper :
325343 grouper = ["group" ] + grouper
326344
327345 df_sorted = df .sort_values (by = cls ._grouper + ["date" ])
@@ -330,14 +348,20 @@ def identify_continuous_periods(group, time_unit):
330348 identify_continuous_periods , time_unit
331349 )
332350
351+ if isinstance (colname , str ):
352+ colname = [colname ]
353+
354+ agg_dict = {
355+ "start_date" : pd .NamedAgg (column = "date" , aggfunc = "min" ),
356+ "end_date" : pd .NamedAgg (column = "date" , aggfunc = "max" ),
357+ }
358+ for col in colname :
359+ agg_dict [col ] = pd .NamedAgg (column = col , aggfunc = "sum" )
333360 # Group by the identified periods and calculate start and end dates
361+ print (df_periods )
334362 df_periods = (
335363 df_periods .groupby (grouper + ["period_id" ], dropna = False )
336- .agg (
337- start_date = pd .NamedAgg (column = "date" , aggfunc = "min" ),
338- end_date = pd .NamedAgg (column = "date" , aggfunc = "max" ),
339- total = pd .NamedAgg (column = colname , aggfunc = "sum" ),
340- )
364+ .agg (** agg_dict )
341365 .reset_index ()
342366 )
343367
@@ -346,28 +370,22 @@ def identify_continuous_periods(group, time_unit):
346370 + " to "
347371 + df_periods ["end_date" ].astype (str )
348372 )
349- df_periods = df_periods .rename (columns = {"total" : f"{ colname } " })
373+ # df_periods = df_periods.rename(columns={"total": f"{colname}"})
350374 df_periods = df_periods .drop (["period_id" , "start_date" , "end_date" ], axis = 1 )
351375 return df_periods [
352376 ["period" ] + [col for col in df_periods .columns if col != "period" ]
353377 ]
354378
355379 @property
356- def per_date_risk_metrics (self ) -> pd .DataFrame | pd .Series :
380+ def per_date_risk_metrics (self , * args , ** kwargs ) -> pd .DataFrame | pd .Series :
357381 """Returns a tidy dataframe of the risk metrics for all dates."""
358- return self ._prepare_risk_metrics ( total = False , npv = True )
382+ return self .all_risk_metrics ( * args , ** kwargs )
359383
360384 @property
361- def total_risk_metrics (self ) -> pd .DataFrame | pd .Series :
385+ def total_risk_metrics (self , * args , ** kwargs ) -> pd .DataFrame | pd .Series :
362386 """Returns a tidy dataframe of the risk metrics with the total for each different period."""
363- return self ._prepare_risk_metrics (total = True , npv = True )
364-
365- def _prepare_risk_metrics (self , total = False , npv = True ) -> pd .DataFrame | pd .Series :
366- df = self .all_risk_metrics (npv = npv )
367- if total :
368- return self ._per_period_risk (df )
369-
370- return df
387+ df = self .all_risk_metrics (* args , ** kwargs )
388+ return self ._per_period_risk (df )
371389
372390 def _calc_waterfall_plot_data (self , start_date = None , end_date = None , npv = True ):
373391 start_date = self .start_date if start_date is None else start_date
0 commit comments