Skip to content

Commit 76d06c0

Browse files
committed
revert
1 parent 79f8842 commit 76d06c0

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

dte_adj/stratified.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ def _compute_cumulative_distribution(
262262
covariates_train = covariates[fold_mask]
263263
binomial_train = binomial[fold_mask]
264264
# Pool the records across strata and train the model
265-
# if len(np.unique(binomial_train)) > 1:
266-
# self.model = deepcopy(self.base_model)
267-
# self.model.fit(covariates_train, binomial_train)
265+
if len(np.unique(binomial_train)) > 1:
266+
self.model = deepcopy(self.base_model)
267+
self.model.fit(covariates_train, binomial_train)
268268
for s in s_list:
269269
s_mask = strata == s
270270
weight = (s_mask & treatment_mask).sum() / s_mask.sum()
@@ -274,8 +274,8 @@ def _compute_cumulative_distribution(
274274
binomial_train = binomial[subset_train_mask]
275275
# TODO: revisit the logic here
276276
if len(np.unique(binomial_train)) > 1:
277-
self.model = deepcopy(self.base_model)
278-
self.model.fit(covariates_train, binomial_train)
277+
# self.model = deepcopy(self.base_model)
278+
# self.model.fit(covariates_train, binomial_train)
279279
pass
280280
else:
281281
pred = binomial_train[0]

0 commit comments

Comments
 (0)