@@ -485,21 +485,27 @@ def calc_aai_metric(self):
485485
486486 def calc_aai_per_group_metric (self ):
487487 aai_per_group_df = []
488- for group in np .unique (np .concatenate (self ._group_id_E0 , self ._group_id_E1 )):
488+ for group in np .unique (np .concatenate ([ self ._group_id_E0 , self ._group_id_E1 ] )):
489489 group_idx_E0 = np .where (self ._group_id_E0 != group )
490490 group_idx_E1 = np .where (self ._group_id_E1 != group )
491491 per_date_aai_H0 , per_date_aai_H1 = (
492- self .per_date_eai_H0 [group_idx_E0 ].sum (),
493- self .per_date_eai_H1 [group_idx_E1 ].sum (),
492+ self .per_date_eai_H0 [:, group_idx_E0 ].sum (),
493+ self .per_date_eai_H1 [:, group_idx_E1 ].sum (),
494494 )
495495 per_date_aai = (
496496 self ._prop_H0 * per_date_aai_H0 + self ._prop_H1 * per_date_aai_H1
497497 )
498498 df = pd .DataFrame (index = self .date_idx , columns = ["risk" ], data = per_date_aai )
499- df ["group" ] = pd . NA
500- aai_per_group_df += df
499+ df ["group" ] = group
500+ aai_per_group_df . append ( df )
501501
502- return pd .concat (aai_per_group_df )
502+ aai_per_group_df = pd .concat (aai_per_group_df )
503+ aai_per_group_df ["metric" ] = "aai"
504+ aai_per_group_df ["measure" ] = (
505+ self .measure .name if self .measure else "no_measure"
506+ )
507+ aai_per_group_df .reset_index (inplace = True )
508+ return aai_per_group_df
503509
504510 def calc_return_periods_metric (self , return_periods ):
505511 rp_0 , rp_1 = self .per_date_return_periods_H0 (
0 commit comments