@@ -173,6 +173,7 @@ def pairwise(container: list):
173173 next (b , None )
174174 return zip (a , b )
175175
176+ LOGGER .debug (f"{ self .__class__ .__name__ } : Calc risk periods" )
176177 # impfset = self._merge_impfset(snapshots)
177178 return [
178179 CalcRiskPeriod (
@@ -227,26 +228,25 @@ def _generic_metrics(
227228 # Construct the attribute name for storing the metric results
228229 attr_name = f"_{ metric_name } _metrics"
229230
230- if getattr (self , attr_name , None ) is None :
231- tmp = []
232- for calc_period in self .risk_periods :
233- # Call the specified method on the calc_period object
234- tmp .append (getattr (calc_period , metric_meth )(** kwargs ))
235-
236- tmp = pd .concat (tmp )
237- tmp .drop_duplicates (inplace = True )
238- tmp ["group" ] = tmp ["group" ].fillna (self ._all_groups_name )
239- columns_to_front = ["group" , "date" , "measure" , "metric" ]
240- tmp = tmp [
241- columns_to_front
242- + [
243- col
244- for col in tmp .columns
245- if col not in columns_to_front + ["group" , "risk" , "rp" ]
246- ]
247- + ["risk" ]
231+ tmp = []
232+ for calc_period in self .risk_periods :
233+ # Call the specified method on the calc_period object
234+ tmp .append (getattr (calc_period , metric_meth )(** kwargs ))
235+
236+ tmp = pd .concat (tmp )
237+ tmp .drop_duplicates (inplace = True )
238+ tmp ["group" ] = tmp ["group" ].fillna (self ._all_groups_name )
239+ columns_to_front = ["group" , "date" , "measure" , "metric" ]
240+ tmp = tmp [
241+ columns_to_front
242+ + [
243+ col
244+ for col in tmp .columns
245+ if col not in columns_to_front + ["group" , "risk" , "rp" ]
248246 ]
249- setattr (self , attr_name , tmp )
247+ + ["risk" ]
248+ ]
249+ setattr (self , attr_name , tmp )
250250
251251 if npv :
252252 return self .npv_transform (getattr (self , attr_name ), self .risk_disc )
@@ -271,40 +271,39 @@ def _compute_metrics(
271271 )
272272 return df
273273
274- def eai_metrics (self , npv : bool = True ):
274+ def eai_metrics (self , npv : bool = True , ** kwargs ):
275275 return self ._compute_metrics (
276- npv = npv ,
277- metric_name = "eai" ,
278- metric_meth = "calc_eai_gdf" ,
276+ npv = npv , metric_name = "eai" , metric_meth = "calc_eai_gdf" , ** kwargs
279277 )
280278
281- def aai_metrics (self , npv : bool = True ):
279+ def aai_metrics (self , npv : bool = True , ** kwargs ):
282280 return self ._compute_metrics (
283- npv = npv ,
284- metric_name = "aai" ,
285- metric_meth = "calc_aai_metric" ,
281+ npv = npv , metric_name = "aai" , metric_meth = "calc_aai_metric" , ** kwargs
286282 )
287283
288- def return_periods_metrics (self , return_periods , npv : bool = True ):
284+ def return_periods_metrics (self , return_periods , npv : bool = True , ** kwargs ):
289285 return self ._compute_metrics (
290286 npv = npv ,
291287 metric_name = "return_periods" ,
292288 metric_meth = "calc_return_periods_metric" ,
293289 return_periods = return_periods ,
290+ ** kwargs ,
294291 )
295292
296- def aai_per_group_metrics (self , npv : bool = True ):
293+ def aai_per_group_metrics (self , npv : bool = True , ** kwargs ):
297294 return self ._compute_metrics (
298295 npv = npv ,
299296 metric_name = "aai_per_group" ,
300297 metric_meth = "calc_aai_per_group_metric" ,
298+ ** kwargs ,
301299 )
302300
303- def risk_components_metrics (self , npv : bool = True ):
301+ def risk_components_metrics (self , npv : bool = True , ** kwargs ):
304302 return self ._compute_metrics (
305303 npv = npv ,
306304 metric_name = "risk_components" ,
307305 metric_meth = "calc_risk_components_metric" ,
306+ ** kwargs ,
308307 )
309308
310309 def per_date_risk_metrics (
0 commit comments