Skip to content

Commit 019f2ee

Browse files
committed
feat: add bar chart
1 parent 78da0ef commit 019f2ee

4 files changed

Lines changed: 335 additions & 311 deletions

File tree

dte_adj/__init__.py

Lines changed: 16 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from scipy.stats import norm
44
import math
55
from copy import deepcopy
6+
from .util import compute_confidence_intervals, find_le
67

78

89
class DistributionFunctionMixin(object):
@@ -33,15 +34,15 @@ def predict_dte(
3334
target_treatment_arm (int): The index of the treatment arm of the treatment group.
3435
control_treatment_arm (int): The index of the treatment arm of the control group.
3536
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
36-
alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
37+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
3738
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
3839
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
3940
4041
Returns:
4142
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
4243
- Expected DTEs
43-
- Upper bands
44-
- Lower bands
44+
- Lower bounds
45+
- Upper bounds
4546
"""
4647
return self._compute_dtes(
4748
target_treatment_arm,
@@ -68,14 +69,14 @@ def predict_pte(
6869
control_treatment_arm (int): The index of the treatment arm of the control group.
6970
locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
7071
width (float): The width of each outcome interval.
71-
alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
72+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
7273
variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
7374
7475
Returns:
7576
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
7677
- Expected PTEs
77-
- Upper bands
78-
- Lower bands
78+
- Lower bounds
79+
- Upper bounds
7980
"""
8081
return self._compute_ptes(
8182
target_treatment_arm,
@@ -102,14 +103,14 @@ def predict_qte(
102103
target_treatment_arm (int): The index of the treatment arm of the treatment group.
103104
control_treatment_arm (int): The index of the treatment arm of the control group.
104105
quantiles (np.ndarray, optional): Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].
105-
alpha (float, optional): Significance level of the confidence band. Defaults to 0.05.
106+
alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
106107
n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 500.
107108
108109
Returns:
109110
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
110111
- Expected QTEs
111-
- Upper bands
112-
- Lower bands
112+
- Upper bounds
113+
- Lower bounds
113114
"""
114115
qte = self._compute_qtes(
115116
target_treatment_arm,
@@ -170,7 +171,7 @@ def _compute_dtes(
170171

171172
mat_indicator = (self.outcome[:, np.newaxis] <= locations).astype(int)
172173

173-
lower_band, upper_band = compute_confidence_intervals(
174+
lower_bound, upper_bound = compute_confidence_intervals(
174175
vec_y=self.outcome,
175176
vec_d=self.treatment_arm,
176177
vec_loc=locations,
@@ -188,8 +189,8 @@ def _compute_dtes(
188189

189190
return (
190191
dte,
191-
lower_band,
192-
upper_band,
192+
lower_bound,
193+
upper_bound,
193194
)
194195

195196
def _compute_ptes(
@@ -248,7 +249,7 @@ def _compute_ptes(
248249
int
249250
)
250251

251-
lower_band, upper_band = compute_confidence_intervals(
252+
lower_bound, upper_bound = compute_confidence_intervals(
252253
vec_y=self.outcome,
253254
vec_d=self.treatment_arm,
254255
vec_loc=locations,
@@ -266,8 +267,8 @@ def _compute_ptes(
266267

267268
return (
268269
pte,
269-
lower_band,
270-
upper_band,
270+
lower_bound,
271+
upper_bound,
271272
)
272273

273274
def _compute_qtes(
@@ -326,28 +327,6 @@ def _compute_cumulative_distribution(
326327
raise NotImplementedError()
327328

328329

329-
def find_le(array: np.ndarray, threshold) -> int:
    """Find the rightmost value less than or equal to threshold in a sorted array.

    Args:
        array (np.ndarray): One-dimensional array sorted in ascending order.
        threshold (float): The threshold value.

    Returns:
        int: Index of the rightmost element that is less than or equal to
            ``threshold``, or -1 when no element satisfies the condition
            (including when ``array`` is empty).
    """
    # side="right" yields the insertion point *after* any run of elements
    # equal to the threshold, so the element just before it is the rightmost
    # one that is <= threshold. An insertion point of 0 maps to -1.
    return int(np.searchsorted(array, threshold, side="right")) - 1
349-
350-
351330
class SimpleDistributionEstimator(DistributionFunctionMixin):
352331
"""A class for computing the empirical distribution function and the distributional parameters
353332
based on the distribution function.
@@ -564,108 +543,3 @@ def _compute_cumulative_distribution(
564543
)
565544
return cumulative_distribution, superset_prediction
566545

567-
568-
def compute_confidence_intervals(
    vec_y: np.ndarray,
    vec_d: np.ndarray,
    vec_loc: np.ndarray,
    mat_y_u: np.ndarray,
    vec_prediction_target: np.ndarray,
    vec_prediction_control: np.ndarray,
    mat_entire_predictions_target: np.ndarray,
    mat_entire_predictions_control: np.ndarray,
    ind_target: int,
    ind_control: int,
    alpha: float = 0.05,
    variance_type: str = "moment",
    n_bootstrap: int = 500,
):
    """Computes the confidence intervals of distribution parameters.

    Args:
        vec_y (np.ndarray): Outcome variable vector.
        vec_d (np.ndarray): Treatment indicator vector.
        vec_loc (np.ndarray): Locations where the distribution parameters are estimated.
        mat_y_u (np.ndarray): Indicator function for 1{Y<=y}. Shape is n_obs * n_loc.
        vec_prediction_target (np.ndarray): Estimated values from the conditional model for the treatment group.
        vec_prediction_control (np.ndarray): Estimated values from the conditional model for the control group.
        mat_entire_predictions_target (np.ndarray): Prediction of the conditional distribution estimator for target group.
        mat_entire_predictions_control (np.ndarray): Prediction of the conditional distribution estimator for control group.
        ind_target (int): Index of the target treatment indicator.
        ind_control (int): Index of the control treatment indicator.
        alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05.
        variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are moment, simple, and uniform.
        n_bootstrap (int, optional): Number of bootstrap samples. Only used when variance_type is "uniform". Defaults to 500.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing:
            - np.ndarray: lower bound.
            - np.ndarray: upper bound.

    Raises:
        RuntimeError: If ``variance_type`` is not one of moment, simple, or uniform.
    """
    num_obs = vec_y.shape[0]
    vec_dte = vec_prediction_target - vec_prediction_control

    num_target = (vec_d == ind_target).sum()
    num_control = (vec_d == ind_control).sum()
    w_target = num_obs / num_target
    w_control = num_obs / num_control

    # Column view of the treatment vector: broadcasting against the
    # (n_obs, n_loc) matrices replaces the explicit np.tile copies.
    col_d = vec_d[:, np.newaxis]
    influence_function = (
        w_target * (col_d == ind_target) * (mat_y_u - mat_entire_predictions_target)
        + mat_entire_predictions_target
        - w_control
        * (col_d == ind_control)
        * (mat_y_u - mat_entire_predictions_control)
        - mat_entire_predictions_control
        - vec_dte
    )

    omega = (influence_function**2).mean(axis=0)
    std_err = np.sqrt(omega / num_obs)

    if variance_type == "moment":
        vec_dte_lower_moment = vec_dte + norm.ppf(alpha / 2) * std_err
        vec_dte_upper_moment = vec_dte + norm.ppf(1 - alpha / 2) * std_err
        return vec_dte_lower_moment, vec_dte_upper_moment

    if variance_type == "uniform":
        boot_draw = np.zeros((n_bootstrap, len(vec_loc)))

        # The draws stay in a loop (eta1 then eta2 per replicate) so the
        # sequence consumed from the global NumPy RNG is unchanged.
        for b in range(n_bootstrap):
            eta1 = np.random.normal(0, 1, num_obs)
            eta2 = np.random.normal(0, 1, num_obs)
            xi = eta1 / np.sqrt(2) + (eta2**2 - 1) / 2

            boot_draw[b, :] = (
                1 / num_obs * np.sum(xi[:, np.newaxis] * influence_function, axis=0)
            )

        # NOTE(review): the last location is left out of the max-t statistic,
        # presumably because its variance can be zero when every outcome lies
        # below it - confirm. With a single location this slice is empty and
        # np.max raises.
        tstats = np.abs(boot_draw)[:, :-1] / std_err[:-1]
        max_tstats = np.max(tstats, axis=1)
        quantile_max_tstats = np.quantile(max_tstats, 1 - alpha)

        vec_dte_lower_boot = vec_dte - quantile_max_tstats * std_err
        vec_dte_upper_boot = vec_dte + quantile_max_tstats * std_err
        return vec_dte_lower_boot, vec_dte_upper_boot

    if variance_type == "simple":
        vec_dte_var = w_target * (
            vec_prediction_target * (1 - vec_prediction_target)
        ) + w_control * vec_prediction_control * (1 - vec_prediction_control)

        vec_dte_lower_simple = vec_dte + norm.ppf(alpha / 2) / np.sqrt(
            num_obs
        ) * np.sqrt(vec_dte_var)
        vec_dte_upper_simple = vec_dte + norm.ppf(1 - alpha / 2) / np.sqrt(
            num_obs
        ) * np.sqrt(vec_dte_var)

        return vec_dte_lower_simple, vec_dte_upper_simple

    raise RuntimeError(f"Invalid variance type was specified: {variance_type}")

dte_adj/plot/__init__.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66

77
def plot(
8-
x_values: np.ndarray,
9-
y_values: np.ndarray,
10-
upper_bands: np.ndarray,
11-
lower_bands: np.ndarray,
8+
X: np.ndarray,
9+
means: np.ndarray,
10+
upper_bounds: np.ndarray,
11+
lower_bounds: np.ndarray,
12+
chart_type="line",
1213
ax: Optional[axis.Axis] = None,
1314
title: Optional[str] = None,
1415
xlabel: Optional[str] = None,
@@ -17,10 +18,11 @@ def plot(
1718
"""Visualize distributional parameters and their confidence intervals.
1819
1920
Args:
20-
x_values (np.Array): values to be used for x axis.
21-
y_values (np.Array): Expected distributional parameters.
22-
upper_bands (np.Array): Upper band for the distributional parameters.
23-
lower_bands (np.Array): Lower band for the distributional parameters.
21+
X (np.ndarray): Values to be used for the x axis.
22+
means (np.ndarray): Expected distributional parameters.
23+
upper_bounds (np.ndarray): Upper bound for the distributional parameters.
24+
lower_bounds (np.ndarray): Lower bound for the distributional parameters.
25+
chart_type (str, optional): Chart type to draw. Available values are "line" and "bar". Defaults to "line".
2426
ax (matplotlib.axes.Axes, optional): Target axes instance. If None, a new figure and axes will be created.
2527
title (str, optional): Axes title.
2628
xlabel (str, optional): X-axis title label.
@@ -32,15 +34,28 @@ def plot(
3234
if ax is None:
3335
fig, ax = plt.subplots()
3436

35-
ax.plot(x_values, y_values, label="Values", color="blue")
36-
ax.fill_between(
37-
x_values,
38-
lower_bands,
39-
upper_bands,
40-
color="gray",
41-
alpha=0.3,
42-
label="Confidence Interval",
43-
)
37+
if chart_type == "line":
38+
ax.plot(X, means, label="Values", color="blue")
39+
ax.fill_between(
40+
X,
41+
lower_bounds,
42+
upper_bounds,
43+
color="gray",
44+
alpha=0.3,
45+
label="Confidence Interval",
46+
)
47+
elif chart_type == "bar":
48+
ax.bar(
49+
X,
50+
means,
51+
yerr=[
52+
np.clip(means - lower_bounds, 0, None),
53+
np.clip(upper_bounds - means, 0, None),
54+
],
55+
capsize=5,
56+
)
57+
else:
58+
raise ValueError(f"Chart type {chart_type} is not supported")
4459

4560
if title is not None:
4661
ax.set_title(title)

0 commit comments

Comments
 (0)