22from typing import Tuple
33from scipy .stats import norm
44from copy import deepcopy
5- from .util import compute_confidence_intervals , find_le
5+ from abc import ABC
6+ from .util import compute_confidence_intervals
67
78__all__ = ["SimpleDistributionEstimator" , "AdjustedDistributionEstimator" ]
89
910
10- class DistributionFunctionMixin ( object ):
11+ class DistributionEstimatorBase ( ABC ):
1112 """An abstract base class providing several convenience functions to compute and display distribution functions."""
1213
1314 def __init__ (self ):
@@ -311,55 +312,18 @@ def find_quantile(quantile, arm):
311312
312313 return result
313314
314- def predict (self , treatment_arms : np .ndarray , outcomes : np .ndarray ) -> np .ndarray :
315- """Compute cumulative distribution values.
316-
317- Args:
318- treatment_arms (np.ndarray): The index of the treatment arm.
319- outcomes (np.ndarray): Scalar values to be used for computing the cumulative distribution.
320-
321- Returns:
322- np.ndarray: Estimated cumulative distribution values for the input.
323- """
324- raise NotImplementedError ()
325-
326- def _compute_cumulative_distribution (
327- self ,
328- target_treatment_arms : np .ndarray ,
329- locations : np .ndarray ,
330- confoundings : np .ndarray ,
331- treatment_arms : np .ndarray ,
332- outcomes : np .array ,
333- ) -> np .ndarray :
334- """Compute the cumulative distribution values."""
335- raise NotImplementedError ()
336-
337-
338- class SimpleDistributionEstimator (DistributionFunctionMixin ):
339- """A class for computing the empirical distribution function and the distributional parameters
340- based on the distribution function.
341- """
342-
343- def __init__ (self ):
344- """Initializes the SimpleDistributionEstimator.
345-
346- Returns:
347- SimpleDistributionEstimator: An instance of the estimator.
348- """
349- super ().__init__ ()
350-
351315 def fit (
352316 self , confoundings : np .ndarray , treatment_arms : np .ndarray , outcomes : np .ndarray
353- ) -> "SimpleDistributionEstimator " :
354- """Train the SimpleDistributionEstimator .
317+ ) -> "DistributionEstimatorBase " :
318+ """Train the DistributionEstimatorBase .
355319
356320 Args:
357321 confoundings (np.ndarray): Pre-treatment covariates.
358322 treatment_arms (np.ndarray): The index of the treatment arm.
359323 outcomes (np.ndarray): Scalar-valued observed outcome.
360324
361325 Returns:
362- SimpleDistributionEstimator : The fitted estimator.
326+ DistributionEstimatorBase : The fitted estimator.
363327 """
364328 if confoundings .shape [0 ] != treatment_arms .shape [0 ]:
365329 raise ValueError (
@@ -380,7 +344,7 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
380344
381345 Args:
382346 treatment_arms (np.ndarray): The index of the treatment arm.
383- locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
347+ locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
384348
385349 Returns:
386350 np.ndarray: Estimated cumulative distribution values for the input.
@@ -390,6 +354,13 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
390354 "This estimator has not been trained yet. Please call fit first"
391355 )
392356
357+ unincluded_arms = set (treatment_arms ) - set (self .treatment_arms )
358+
359+ if len (unincluded_arms ) > 0 :
360+ raise ValueError (
361+ f"This treatment_arms argument contains arms not included in the training data: { unincluded_arms } "
362+ )
363+
393364 return self ._compute_cumulative_distribution (
394365 treatment_arms ,
395366 locations ,
@@ -398,6 +369,31 @@ def predict(self, treatment_arms: np.ndarray, locations: np.ndarray) -> np.ndarr
398369 self .outcomes ,
399370 )[0 ]
400371
372+ def _compute_cumulative_distribution (
373+ self ,
374+ target_treatment_arms : np .ndarray ,
375+ locations : np .ndarray ,
376+ confoundings : np .ndarray ,
377+ treatment_arms : np .ndarray ,
378+ outcomes : np .array ,
379+ ) -> np .ndarray :
380+ """Compute the cumulative distribution values."""
381+ raise NotImplementedError ()
382+
383+
384+ class SimpleDistributionEstimator (DistributionEstimatorBase ):
385+ """A class for computing the empirical distribution function and the distributional parameters
386+ based on the distribution function.
387+ """
388+
389+ def __init__ (self ):
390+ """Initializes the SimpleDistributionEstimator.
391+
392+ Returns:
393+ SimpleDistributionEstimator: An instance of the estimator.
394+ """
395+ super ().__init__ ()
396+
401397 def _compute_cumulative_distribution (
402398 self ,
403399 target_treatment_arms : np .ndarray ,
@@ -432,12 +428,12 @@ def _compute_cumulative_distribution(
432428 cumulative_distribution = np .zeros (locations .shape )
433429 for i , (outcome , arm ) in enumerate (zip (locations , target_treatment_arms )):
434430 cumulative_distribution [i ] = (
435- find_le (d_outcome [arm ], outcome ) + 1
431+ np . searchsorted (d_outcome [arm ], outcome , side = "right" )
436432 ) / d_outcome [arm ].shape [0 ]
437433 return cumulative_distribution , np .zeros ((n_obs , n_loc ))
438434
439435
440- class AdjustedDistributionEstimator (DistributionFunctionMixin ):
436+ class AdjustedDistributionEstimator (DistributionEstimatorBase ):
441437 """A class for estimating the adjusted distribution function and computing the distributional parameters based on the trained conditional estimator."""
442438
443439 def __init__ (self , base_model , folds = 3 ):
@@ -450,60 +446,16 @@ def __init__(self, base_model, folds=3):
450446 Returns:
451447 AdjustedDistributionEstimator: An instance of the estimator.
452448 """
449+ if (not hasattr (base_model , "predict" )) and (
450+ not hasattr (base_model , "predict_proba" )
451+ ):
452+ raise ValueError (
453+ "Base model should implement either predict_proba or predict"
454+ )
453455 self .base_model = base_model
454456 self .folds = folds
455457 super ().__init__ ()
456458
457- def fit (
458- self , confoundings : np .ndarray , treatment_arms : np .ndarray , outcomes : np .ndarray
459- ) -> "AdjustedDistributionEstimator" :
460- """Train the AdjustedDistributionEstimator.
461-
462- Args:
463- confoundings (np.ndarray): Pre-treatment covariates.
464- treatment_arms (np.ndarray): The index of the treatment arm.
465- outcomes (np.ndarray): Scalar-valued observed outcome.
466-
467- Returns:
468- AdjustedDistributionEstimator: The fitted estimator.
469- """
470- if confoundings .shape [0 ] != treatment_arms .shape [0 ]:
471- raise ValueError (
472- "The shape of confounding and treatment_arm should be same"
473- )
474-
475- if confoundings .shape [0 ] != outcomes .shape [0 ]:
476- raise ValueError ("The shape of confounding and outcome should be same" )
477-
478- self .confoundings = confoundings
479- self .treatment_arms = treatment_arms
480- self .outcomes = outcomes
481-
482- return self
483-
484- def predict (self , treatment_arms : np .ndarray , locations : np .ndarray ) -> np .ndarray :
485- """Compute cumulative distribution values.
486-
487- Args:
488- treatment_arms (np.ndarray): The index of the treatment arm.
489- locations (np.ndarray): Scalar values to be used for computing the cumulative distribution.
490-
491- Returns:
492- np.ndarray: Estimated cumulative distribution values for the input.
493- """
494- if self .outcomes is None :
495- raise ValueError (
496- "This estimator has not been trained yet. Please call fit first"
497- )
498-
499- return self ._compute_cumulative_distribution (
500- treatment_arms ,
501- locations ,
502- self .confoundings ,
503- self .treatment_arms ,
504- self .outcomes ,
505- )[0 ]
506-
507459 def _compute_cumulative_distribution (
508460 self ,
509461 target_treatment_arms : np .ndarray ,
@@ -548,13 +500,19 @@ def _compute_cumulative_distribution(
548500 continue
549501 model = deepcopy (self .base_model )
550502 model .fit (confounding_train , binominal_train )
551- subset_prediction [subset_mask ] = model . predict_proba ( confounding_fit )[
552- :, 1
553- ]
554- superset_prediction [superset_mask , i ] = model . predict_proba (
555- confoundings [superset_mask ]
556- )[:, 1 ]
503+ subset_prediction [subset_mask ] = self . _compute_model_prediction (
504+ model , confounding_fit
505+ )
506+ superset_prediction [superset_mask , i ] = self . _compute_model_prediction (
507+ model , confoundings [superset_mask ]
508+ )
557509 cumulative_distribution [i ] = (
558510 cdf - subset_prediction .mean () + superset_prediction [:, i ].mean ()
559511 )
560512 return cumulative_distribution , superset_prediction
513+
514+ def _compute_model_prediction (self , model , confoundings : np .ndarray ) -> np .ndarray :
515+ if hasattr (model , "predict_proba" ):
516+ return model .predict_proba (confoundings )[:, 1 ]
517+ else :
518+ return model .predict (confoundings )
0 commit comments