@@ -195,14 +195,14 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
195195 < span class ="n "> n_bootstrap</ span > < span class ="p "> :</ span > < span class ="nb "> int</ span > < span class ="p "> ,</ span >
196196 < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> Tuple</ span > < span class ="p "> [</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ,</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> ndarray</ span > < span class ="p "> ]:</ span >
197197< span class ="w "> </ span > < span class ="sd "> """Compute expected DTEs."""</ span >
198- < span class ="n "> treatment_cdf</ span > < span class ="p "> ,</ span > < span class ="n "> _ </ span > < span class ="p "> ,</ span > < span class ="n "> treatment_cdf_mat </ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_cumulative_distribution</ span > < span class ="p "> (</ span >
198+ < span class ="n "> treatment_cdf</ span > < span class ="p "> ,</ span > < span class ="n "> treatment_cdf_mat </ span > < span class ="p "> ,</ span > < span class ="n "> _ </ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_cumulative_distribution</ span > < span class ="p "> (</ span >
199199 < span class ="n "> target_treatment_arm</ span > < span class ="p "> ,</ span >
200200 < span class ="n "> locations</ span > < span class ="p "> ,</ span >
201201 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> covariates</ span > < span class ="p "> ,</ span >
202202 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> treatment_arms</ span > < span class ="p "> ,</ span >
203203 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> outcomes</ span > < span class ="p "> ,</ span >
204204 < span class ="p "> )</ span >
205- < span class ="n "> control_cdf</ span > < span class ="p "> ,</ span > < span class ="n "> _ </ span > < span class ="p "> ,</ span > < span class ="n "> control_cdf_mat </ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_cumulative_distribution</ span > < span class ="p "> (</ span >
205+ < span class ="n "> control_cdf</ span > < span class ="p "> ,</ span > < span class ="n "> control_cdf_mat </ span > < span class ="p "> ,</ span > < span class ="n "> _ </ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_cumulative_distribution</ span > < span class ="p "> (</ span >
206206 < span class ="n "> control_treatment_arm</ span > < span class ="p "> ,</ span >
207207 < span class ="n "> locations</ span > < span class ="p "> ,</ span >
208208 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> covariates</ span > < span class ="p "> ,</ span >
@@ -665,18 +665,16 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
665665 < span class ="n "> binominal</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> outcomes</ span > < span class ="o "> .</ span > < span class ="n "> reshape</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="o "> <=</ span > < span class ="n "> locations</ span > < span class ="p "> )</ span > < span class ="o "> *</ span > < span class ="mi "> 1</ span > < span class ="c1 "> # (n_records, n_loc)</ span >
666666 < span class ="k "> for</ span > < span class ="n "> fold</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> folds</ span > < span class ="p "> ):</ span >
667667 < span class ="n "> fold_mask</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> folds</ span > < span class ="o "> !=</ span > < span class ="n "> fold</ span > < span class ="p "> )</ span > < span class ="o "> &</ span > < span class ="n "> treatment_mask</ span >
668- < span class ="n "> covariates_train</ span > < span class ="o "> =</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> fold_mask</ span > < span class ="p "> ]</ span >
669- < span class ="n "> binominal_train</ span > < span class ="o "> =</ span > < span class ="n "> binominal</ span > < span class ="p "> [</ span > < span class ="n "> fold_mask</ span > < span class ="p "> ]</ span >
670- < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> unique</ span > < span class ="p "> (</ span > < span class ="n "> binominal_train</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
671- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> =</ span > < span class ="n "> deepcopy</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> base_model</ span > < span class ="p "> )</ span >
672- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> fit</ span > < span class ="p "> (</ span > < span class ="n "> covariates_train</ span > < span class ="p "> ,</ span > < span class ="n "> binominal_train</ span > < span class ="p "> )</ span >
673668 < span class ="k "> for</ span > < span class ="n "> s</ span > < span class ="ow "> in</ span > < span class ="n "> s_list</ span > < span class ="p "> :</ span >
674669 < span class ="n "> s_mask</ span > < span class ="o "> =</ span > < span class ="n "> strata</ span > < span class ="o "> ==</ span > < span class ="n "> s</ span >
675670 < span class ="n "> weight</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> s_mask</ span > < span class ="o "> &</ span > < span class ="n "> treatment_mask</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> ()</ span > < span class ="o "> /</ span > < span class ="n "> s_mask</ span > < span class ="o "> .</ span > < span class ="n "> sum</ span > < span class ="p "> ()</ span >
676671 < span class ="n "> superset_mask</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> folds</ span > < span class ="o "> ==</ span > < span class ="n "> fold</ span > < span class ="p "> )</ span > < span class ="o "> &</ span > < span class ="n "> s_mask</ span >
677672 < span class ="n "> subset_train_mask</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> folds</ span > < span class ="o "> !=</ span > < span class ="n "> fold</ span > < span class ="p "> )</ span > < span class ="o "> &</ span > < span class ="n "> s_mask</ span > < span class ="o "> &</ span > < span class ="n "> treatment_mask</ span >
678673 < span class ="n "> covariates_train</ span > < span class ="o "> =</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> subset_train_mask</ span > < span class ="p "> ]</ span >
679674 < span class ="n "> binominal_train</ span > < span class ="o "> =</ span > < span class ="n "> binominal</ span > < span class ="p "> [</ span > < span class ="n "> subset_train_mask</ span > < span class ="p "> ]</ span >
675+ < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> unique</ span > < span class ="p "> (</ span > < span class ="n "> binominal_train</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
676+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> =</ span > < span class ="n "> deepcopy</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> base_model</ span > < span class ="p "> )</ span >
677+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> fit</ span > < span class ="p "> (</ span > < span class ="n "> covariates_train</ span > < span class ="p "> ,</ span > < span class ="n "> binominal_train</ span > < span class ="p "> )</ span >
680678
681679 < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _compute_model_prediction</ span > < span class ="p "> (</ span >
682680 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> superset_mask</ span > < span class ="p "> ]</ span >
@@ -695,8 +693,9 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
695693 < span class ="n "> fold_mask</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> folds</ span > < span class ="o "> !=</ span > < span class ="n "> fold</ span > < span class ="p "> )</ span > < span class ="o "> &</ span > < span class ="n "> treatment_mask</ span >
696694 < span class ="n "> covariates_train</ span > < span class ="o "> =</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> fold_mask</ span > < span class ="p "> ]</ span >
697695 < span class ="n "> binominal_train</ span > < span class ="o "> =</ span > < span class ="n "> binominal</ span > < span class ="p "> [</ span > < span class ="n "> fold_mask</ span > < span class ="p "> ]</ span >
698- < span class ="bp " > self </ span > < span class =" o " > . </ span > < span class =" n " > model </ span > < span class =" o " > = </ span > < span class =" n " > deepcopy </ span > < span class =" p " > ( </ span > < span class =" bp " > self </ span > < span class =" o " > . </ span > < span class =" n " > base_model </ span > < span class =" p " > ) </ span >
696+ < span class ="c1 " > # Pool the records across strata and train the model </ span >
699697 < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> unique</ span > < span class ="p "> (</ span > < span class ="n "> binominal_train</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
698+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> =</ span > < span class ="n "> deepcopy</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> base_model</ span > < span class ="p "> )</ span >
700699 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> fit</ span > < span class ="p "> (</ span > < span class ="n "> covariates_train</ span > < span class ="p "> ,</ span > < span class ="n "> binominal_train</ span > < span class ="p "> )</ span >
701700 < span class ="k "> for</ span > < span class ="n "> s</ span > < span class ="ow "> in</ span > < span class ="n "> s_list</ span > < span class ="p "> :</ span >
702701 < span class ="n "> s_mask</ span > < span class ="o "> =</ span > < span class ="n "> strata</ span > < span class ="o "> ==</ span > < span class ="n "> s</ span >
@@ -705,7 +704,12 @@ <h1>Source code for dte_adj</h1><div class="highlight"><pre>
705704 < span class ="n "> subset_train_mask</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> folds</ span > < span class ="o "> !=</ span > < span class ="n "> fold</ span > < span class ="p "> )</ span > < span class ="o "> &</ span > < span class ="n "> s_mask</ span > < span class ="o "> &</ span > < span class ="n "> treatment_mask</ span >
706705 < span class ="n "> covariates_train</ span > < span class ="o "> =</ span > < span class ="n "> covariates</ span > < span class ="p "> [</ span > < span class ="n "> subset_train_mask</ span > < span class ="p "> ]</ span >
707706 < span class ="n "> binominal_train</ span > < span class ="o "> =</ span > < span class ="n "> binominal</ span > < span class ="p "> [</ span > < span class ="n "> subset_train_mask</ span > < span class ="p "> ]</ span >
708- < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> unique</ span > < span class ="p "> (</ span > < span class ="n "> binominal_train</ span > < span class ="p "> ))</ span > < span class ="o "> ==</ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
707+ < span class ="c1 "> # TODO: revisit the logic here</ span >
708+ < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> unique</ span > < span class ="p "> (</ span > < span class ="n "> binominal_train</ span > < span class ="p "> ))</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
709+ < span class ="c1 "> # self.model = deepcopy(self.base_model)</ span >
710+ < span class ="c1 "> # self.model.fit(covariates_train, binominal_train)</ span >
711+ < span class ="k "> pass</ span >
712+ < span class ="k "> else</ span > < span class ="p "> :</ span >
709713 < span class ="n "> pred</ span > < span class ="o "> =</ span > < span class ="n "> binominal_train</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
710714 < span class ="n "> superset_prediction</ span > < span class ="p "> [</ span > < span class ="n "> superset_mask</ span > < span class ="p "> ,</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> pred</ span >
711715 < span class ="n "> prediction</ span > < span class ="p "> [</ span > < span class ="n "> superset_mask</ span > < span class ="p "> ,</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> (</ span >
0 commit comments