@@ -406,75 +406,81 @@ The Oregon experiment allows us to examine how treatment effects vary across dif
406406 # Individual Stratum Analysis with Local Estimators
407407 print (" \n === Individual Stratum Analysis (Local Estimators) ===" )
408408
409+ # Helper function to filter data for a specific stratum
410+ def filter_stratum_data (strata_values , stratum_name , X , Z , D , Y ):
411+ """ Filter and extract data for a specific stratum"""
412+ mask = strata_values == stratum_name
413+ return {
414+ ' X' : X[mask],
415+ ' Z' : Z[mask],
416+ ' D' : D[mask],
417+ ' Y' : Y[mask],
418+ ' strata' : np.zeros(mask.sum(), dtype = int ), # Uniform strata for subset
419+ ' n_total' : mask.sum(),
420+ ' n_assigned' : (Z[mask] == 1 ).sum(),
421+ ' n_enrolled' : (D[mask] == 1 ).sum()
422+ }
423+
424+ # Helper function to estimate LDTE for a stratum
425+ def estimate_stratum_ldte (stratum_data , location_step = 3000 , folds = 3 ):
426+ """ Initialize estimators, fit data, and compute LDTE for a stratum"""
427+ # Initialize estimators
428+ simple_estimator = dte_adj.SimpleLocalDistributionEstimator()
429+ ml_estimator = dte_adj.AdjustedLocalDistributionEstimator(
430+ LinearRegression(),
431+ folds = folds
432+ )
433+
434+ # Fit estimators
435+ simple_estimator.fit(stratum_data[' X' ], stratum_data[' Z' ],
436+ stratum_data[' D' ], stratum_data[' Y' ], stratum_data[' strata' ])
437+ ml_estimator.fit(stratum_data[' X' ], stratum_data[' Z' ],
438+ stratum_data[' D' ], stratum_data[' Y' ], stratum_data[' strata' ])
439+
440+ # Define evaluation locations based on stratum's data range
441+ locations = np.arange(stratum_data[' Y' ].min(), stratum_data[' Y' ].max(), location_step)
442+
443+ # Compute LDTE
444+ ldte_simple, lower_simple, upper_simple = simple_estimator.predict_ldte(
445+ target_treatment_arm = 1 , control_treatment_arm = 0 , locations = locations
446+ )
447+ ldte_ml, lower_ml, upper_ml = ml_estimator.predict_ldte(
448+ target_treatment_arm = 1 , control_treatment_arm = 0 , locations = locations
449+ )
450+
451+ return {
452+ ' simple' : {' ldte' : ldte_simple, ' lower' : lower_simple, ' upper' : upper_simple},
453+ ' ml' : {' ldte' : ldte_ml, ' lower' : lower_ml, ' upper' : upper_ml},
454+ ' locations' : locations,
455+ ' sample_size' : stratum_data[' n_total' ],
456+ ' treatment_assignment_size' : stratum_data[' n_assigned' ],
457+ ' treatment_indicator_size' : stratum_data[' n_enrolled' ]
458+ }
459+
409460 # Get strata values (already consolidated in preprocessing)
410461 strata_consolidated_values = df[' strata' ].values
411462 unique_consolidated_strata = np.unique(strata_consolidated_values)
412463
413- # Individual estimations for each stratum
464+ # Analyze each stratum
414465 individual_results = {}
415-
416466 for stratum in unique_consolidated_strata:
417467 print (f " \n Analyzing stratum: { stratum} " )
418468
419469 # Filter data for this stratum
420- stratum_mask = strata_consolidated_values == stratum
421- X_stratum = X[stratum_mask]
422- treatment_arms_stratum = Z[stratum_mask]
423- treatment_indicator_stratum = D[stratum_mask]
424- Y_stratum = Y_ED_CHARG_TOT_ED [stratum_mask]
425-
426- # Create uniform strata for this subset (all observations in same stratum)
427- strata_stratum = np.zeros(len (X_stratum), dtype = int )
428-
429- print (f " Sample size: { len (treatment_indicator_stratum):, } " )
430- print (f " Treatment assignment (Selected): { (treatment_arms_stratum == 1 ).sum():, } " )
431- print (f " Treatment indicator (Enrolled): { (treatment_indicator_stratum == 1 ).sum():, } " )
432-
433- # Initialize local estimators for this stratum
434- simple_stratum_estimator = dte_adj.SimpleLocalDistributionEstimator()
435- ml_stratum_estimator = dte_adj.AdjustedLocalDistributionEstimator(
436- LinearRegression(),
437- folds = 3 # Reduced folds due to smaller sample size
470+ stratum_data = filter_stratum_data(
471+ strata_consolidated_values, stratum, X, Z, D, Y_ED_CHARG_TOT_ED
438472 )
439473
440- # Fit estimators on stratum data
441- simple_stratum_estimator.fit(X_stratum, treatment_arms_stratum, treatment_indicator_stratum, Y_stratum, strata_stratum)
442- ml_stratum_estimator.fit(X_stratum, treatment_arms_stratum, treatment_indicator_stratum, Y_stratum, strata_stratum)
443-
444- # Define locations for this stratum based on its data range
445- outcome_ed_costs_locations_stratum = np.arange(Y_stratum.min(), Y_stratum.max(), 3000 )
474+ # Print stratum statistics
475+ print (f " Sample size: { stratum_data[' n_total' ]:, } " )
476+ print (f " Treatment assignment (Selected): { stratum_data[' n_assigned' ]:, } " )
477+ print (f " Treatment indicator (Enrolled): { stratum_data[' n_enrolled' ]:, } " )
446478
447- # Compute LDTE for this stratum using stratum-specific locations
448- ldte_simple_stratum, lower_simple_stratum, upper_simple_stratum = simple_stratum_estimator.predict_ldte(
449- target_treatment_arm = 1 ,
450- control_treatment_arm = 0 ,
451- locations = outcome_ed_costs_locations_stratum
479+ # Estimate LDTE for this stratum
480+ individual_results[stratum] = estimate_stratum_ldte(
481+ stratum_data, location_step = 3000 , folds = 3
452482 )
453483
454- ldte_ml_stratum, lower_ml_stratum, upper_ml_stratum = ml_stratum_estimator.predict_ldte(
455- target_treatment_arm = 1 ,
456- control_treatment_arm = 0 ,
457- locations = outcome_ed_costs_locations_stratum
458- )
459-
460- # Store results including the locations
461- individual_results[stratum] = {
462- ' simple' : {
463- ' ldte' : ldte_simple_stratum,
464- ' lower' : lower_simple_stratum,
465- ' upper' : upper_simple_stratum
466- },
467- ' ml' : {
468- ' ldte' : ldte_ml_stratum,
469- ' lower' : lower_ml_stratum,
470- ' upper' : upper_ml_stratum
471- },
472- ' locations' : outcome_ed_costs_locations_stratum,
473- ' sample_size' : len (treatment_indicator_stratum),
474- ' treatment_assignment_size' : (treatment_arms_stratum == 1 ).sum(),
475- ' treatment_indicator_size' : (treatment_indicator_stratum == 1 ).sum()
476- }
477-
478484 Visualization: Comparing Overall Population vs Stratified Results
479485~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
480486
0 commit comments