Skip to content

Commit 8fdc26d

Browse files
committed
Improving the readability of stratified analysis processing
1 parent dbfbae3 commit 8fdc26d

1 file changed

Lines changed: 61 additions & 55 deletions

File tree

docs/source/tutorials/oregon.rst

Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -406,75 +406,81 @@ The Oregon experiment allows us to examine how treatment effects vary across dif
406406
# Individual Stratum Analysis with Local Estimators
407407
print("\n=== Individual Stratum Analysis (Local Estimators) ===")
408408
409+
# Helper function to filter data for a specific stratum
410+
def filter_stratum_data(strata_values, stratum_name, X, Z, D, Y):
    """Select the observations that belong to a single stratum.

    Args:
        strata_values: Per-observation stratum labels (array-like). Coerced
            with ``np.asarray`` so a plain list compares element-wise.
        stratum_name: Label of the stratum to keep.
        X: Covariates, indexable with a boolean mask.
        Z: Values reported downstream as "Treatment assignment (Selected)";
            entries equal to 1 are counted in ``n_assigned``.
        D: Values reported downstream as "Treatment indicator (Enrolled)";
            entries equal to 1 are counted in ``n_enrolled``.
        Y: Outcome values, indexable with a boolean mask.

    Returns:
        dict with the masked ``'X'``, ``'Z'``, ``'D'``, ``'Y'``; a
        ``'strata'`` array of zeros (one per selected row, so the subset is
        treated as a single uniform stratum); and the counts ``'n_total'``,
        ``'n_assigned'`` and ``'n_enrolled'``.
    """
    # np.asarray guarantees an element-wise comparison; a bare list compared
    # with a scalar would collapse to a single bool.
    mask = np.asarray(strata_values) == stratum_name

    # Index once and reuse, instead of re-applying the mask for the counts.
    n_total = mask.sum()
    Z_stratum = Z[mask]
    D_stratum = D[mask]

    return {
        'X': X[mask],
        'Z': Z_stratum,
        'D': D_stratum,
        'Y': Y[mask],
        'strata': np.zeros(n_total, dtype=int),  # Uniform strata for subset
        'n_total': n_total,
        'n_assigned': (Z_stratum == 1).sum(),
        'n_enrolled': (D_stratum == 1).sum(),
    }
423+
424+
# Helper function to estimate LDTE for a stratum
425+
def estimate_stratum_ldte(stratum_data, location_step=3000, folds=3):
    """Fit the simple and regression-adjusted local estimators on one stratum
    and compute LDTE for both.

    Args:
        stratum_data: dict with keys ``'X'``, ``'Z'``, ``'D'``, ``'Y'``,
            ``'strata'``, ``'n_total'``, ``'n_assigned'`` and
            ``'n_enrolled'`` (the structure built by ``filter_stratum_data``).
        location_step: Spacing between evaluation locations on the outcome
            scale. Must be positive.
        folds: ``folds`` argument passed to
            ``dte_adj.AdjustedLocalDistributionEstimator``.

    Returns:
        dict with ``'simple'`` and ``'ml'`` entries (each holding ``'ldte'``,
        ``'lower'`` and ``'upper'`` from ``predict_ldte``), the evaluation
        ``'locations'``, and the counts ``'sample_size'``,
        ``'treatment_assignment_size'`` and ``'treatment_indicator_size'``.

    Raises:
        ValueError: If ``location_step`` is not positive.
    """
    # Fail early with a clear message: a zero step makes np.arange raise an
    # unrelated error, and a negative step silently yields no locations.
    if location_step <= 0:
        raise ValueError(
            f"location_step must be positive, got {location_step!r}"
        )

    X = stratum_data['X']
    Z = stratum_data['Z']
    D = stratum_data['D']
    Y = stratum_data['Y']
    strata = stratum_data['strata']

    # Initialize estimators
    simple_estimator = dte_adj.SimpleLocalDistributionEstimator()
    ml_estimator = dte_adj.AdjustedLocalDistributionEstimator(
        LinearRegression(),
        folds=folds,
    )

    # Fit both estimators before predicting (same order as before)
    simple_estimator.fit(X, Z, D, Y, strata)
    ml_estimator.fit(X, Z, D, Y, strata)

    # Define evaluation locations based on the stratum's data range
    y_min = Y.min()
    locations = np.arange(y_min, Y.max(), location_step)
    if locations.size == 0:
        # A constant outcome (min == max) leaves np.arange empty; evaluate at
        # the single observed value instead of passing no locations.
        locations = np.array([y_min])

    # Compute LDTE
    ldte_simple, lower_simple, upper_simple = simple_estimator.predict_ldte(
        target_treatment_arm=1, control_treatment_arm=0, locations=locations
    )
    ldte_ml, lower_ml, upper_ml = ml_estimator.predict_ldte(
        target_treatment_arm=1, control_treatment_arm=0, locations=locations
    )

    return {
        'simple': {'ldte': ldte_simple, 'lower': lower_simple, 'upper': upper_simple},
        'ml': {'ldte': ldte_ml, 'lower': lower_ml, 'upper': upper_ml},
        'locations': locations,
        'sample_size': stratum_data['n_total'],
        'treatment_assignment_size': stratum_data['n_assigned'],
        'treatment_indicator_size': stratum_data['n_enrolled'],
    }
459+
409460
# Get strata values (already consolidated in preprocessing)
410461
strata_consolidated_values = df['strata'].values
411462
unique_consolidated_strata = np.unique(strata_consolidated_values)
412463
413-
# Individual estimations for each stratum
464+
# Analyze each stratum
414465
individual_results = {}
415-
416466
for stratum in unique_consolidated_strata:
417467
print(f"\nAnalyzing stratum: {stratum}")
418468
419469
# Filter data for this stratum
420-
stratum_mask = strata_consolidated_values == stratum
421-
X_stratum = X[stratum_mask]
422-
treatment_arms_stratum = Z[stratum_mask]
423-
treatment_indicator_stratum = D[stratum_mask]
424-
Y_stratum = Y_ED_CHARG_TOT_ED[stratum_mask]
425-
426-
# Create uniform strata for this subset (all observations in same stratum)
427-
strata_stratum = np.zeros(len(X_stratum), dtype=int)
428-
429-
print(f" Sample size: {len(treatment_indicator_stratum):,}")
430-
print(f" Treatment assignment (Selected): {(treatment_arms_stratum == 1).sum():,}")
431-
print(f" Treatment indicator (Enrolled): {(treatment_indicator_stratum == 1).sum():,}")
432-
433-
# Initialize local estimators for this stratum
434-
simple_stratum_estimator = dte_adj.SimpleLocalDistributionEstimator()
435-
ml_stratum_estimator = dte_adj.AdjustedLocalDistributionEstimator(
436-
LinearRegression(),
437-
folds=3 # Reduced folds due to smaller sample size
470+
stratum_data = filter_stratum_data(
471+
strata_consolidated_values, stratum, X, Z, D, Y_ED_CHARG_TOT_ED
438472
)
439473
440-
# Fit estimators on stratum data
441-
simple_stratum_estimator.fit(X_stratum, treatment_arms_stratum, treatment_indicator_stratum, Y_stratum, strata_stratum)
442-
ml_stratum_estimator.fit(X_stratum, treatment_arms_stratum, treatment_indicator_stratum, Y_stratum, strata_stratum)
443-
444-
# Define locations for this stratum based on its data range
445-
outcome_ed_costs_locations_stratum = np.arange(Y_stratum.min(), Y_stratum.max(), 3000)
474+
# Print stratum statistics
475+
print(f" Sample size: {stratum_data['n_total']:,}")
476+
print(f" Treatment assignment (Selected): {stratum_data['n_assigned']:,}")
477+
print(f" Treatment indicator (Enrolled): {stratum_data['n_enrolled']:,}")
446478
447-
# Compute LDTE for this stratum using stratum-specific locations
448-
ldte_simple_stratum, lower_simple_stratum, upper_simple_stratum = simple_stratum_estimator.predict_ldte(
449-
target_treatment_arm=1,
450-
control_treatment_arm=0,
451-
locations=outcome_ed_costs_locations_stratum
479+
# Estimate LDTE for this stratum
480+
individual_results[stratum] = estimate_stratum_ldte(
481+
stratum_data, location_step=3000, folds=3
452482
)
453483
454-
ldte_ml_stratum, lower_ml_stratum, upper_ml_stratum = ml_stratum_estimator.predict_ldte(
455-
target_treatment_arm=1,
456-
control_treatment_arm=0,
457-
locations=outcome_ed_costs_locations_stratum
458-
)
459-
460-
# Store results including the locations
461-
individual_results[stratum] = {
462-
'simple': {
463-
'ldte': ldte_simple_stratum,
464-
'lower': lower_simple_stratum,
465-
'upper': upper_simple_stratum
466-
},
467-
'ml': {
468-
'ldte': ldte_ml_stratum,
469-
'lower': lower_ml_stratum,
470-
'upper': upper_ml_stratum
471-
},
472-
'locations': outcome_ed_costs_locations_stratum,
473-
'sample_size': len(treatment_indicator_stratum),
474-
'treatment_assignment_size': (treatment_arms_stratum == 1).sum(),
475-
'treatment_indicator_size': (treatment_indicator_stratum == 1).sum()
476-
}
477-
478484
Visualization: Comparing Overall Population vs Stratified Results
479485
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
480486

0 commit comments

Comments
 (0)