Skip to content

Commit 7d0051a

Browse files
committed
comment
1 parent 0f7c307 commit 7d0051a

1 file changed

Lines changed: 6 additions & 10 deletions

File tree

dte_adj/stratified.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,8 @@ def _compute_cumulative_distribution(
7373
for s in s_list:
7474
s_mask = strata == s
7575
w_s[s] = (s_mask & treatment_mask).sum() / s_mask.sum()
76-
n_obs = outcomes.shape[0]
77-
n_loc = locations.shape[0]
78-
for i, outcome in enumerate(locations):
79-
for j in range(n_obs):
76+
for i, outcome in enumerate(n_loc):
77+
for j in range(n_records):
8078
s = strata[j]
8179
prediction[j, i] = (outcomes[j] <= outcome) / w_s[s] * treatment_mask[j]
8280

@@ -123,10 +121,8 @@ def _compute_interval_probability(
123121
for s in s_list:
124122
s_mask = strata == s
125123
w_s[s] = (s_mask & treatment_mask).sum() / s_mask.sum()
126-
n_obs = outcomes.shape[0]
127-
n_loc = locations.shape[0]
128124
for i, outcome in enumerate(locations):
129-
for j in range(n_obs):
125+
for j in range(n_records):
130126
s = strata[j]
131127
prediction[j, i] = (outcomes[j] <= outcome) / w_s[s] * treatment_mask[j]
132128

@@ -349,7 +345,7 @@ def _compute_interval_probability(
349345
self.model.fit(covariates_train, binomial_train)
350346
for s in s_list:
351347
s_mask = strata == s
352-
wight = (s_mask & treatment_mask).sum() / s_mask.sum()
348+
weight = (s_mask & treatment_mask).sum() / s_mask.sum()
353349
superset_mask = (folds == fold) & s_mask
354350
subset_train_mask = (folds != fold) & s_mask & treatment_mask
355351
covariates_train = covariates[subset_train_mask]
@@ -361,7 +357,7 @@ def _compute_interval_probability(
361357
pred
362358
+ treatment_mask[superset_mask]
363359
* (binomial[superset_mask] - pred)
364-
/ wight
360+
/ weight
365361
)
366362
continue
367363
pred = self._compute_model_prediction(
@@ -371,7 +367,7 @@ def _compute_interval_probability(
371367
pred
372368
+ treatment_mask[superset_mask]
373369
* (binomial[superset_mask] - pred)
374-
/ wight
370+
/ weight
375371
)
376372
superset_prediction[superset_mask, i] = pred
377373

0 commit comments

Comments
 (0)