Skip to content

Commit cd6beb8

Browse files
authored
Refactor predict_pte (#43)
* remove width from predict_pte * simplify
1 parent 0747976 commit cd6beb8

3 files changed

Lines changed: 382 additions & 95 deletions

File tree

docs/source/get_started.rst

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ A convenience function is available to visualize distribution effects. This meth
8181
:align: center
8282

8383
To initialize the adjusted distribution function, the base model for conditional distribution function needs to be passed.
84-
In the following example, we use Logistic Regression. Please make sure that your base model implements ``fit`` and ``predict_proba`` methods.
84+
In the following example, Logistic Regression is used. Please make sure that your base model implements ``fit`` and ``predict_proba`` methods.
8585

8686
.. code-block:: python
8787
@@ -104,7 +104,7 @@ DTE can be computed and visualized in the following code.
104104
:width: 450px
105105
:align: center
106106

107-
Confidence bands can be computed in different ways. In the following code, we use moment method to calculate the confidence band.
107+
Confidence bands can be computed in different ways. In the following code, moment condition is used to calculate the confidence band.
108108

109109
.. code-block:: python
110110
@@ -130,20 +130,22 @@ Also, an uniform confidence band is used when ``uniform`` is specified for the `
130130
:width: 450px
131131
:align: center
132132

133-
To compute PTE, we can use ``predict_pte`` method.
133+
To compute PTE, you can use ``predict_pte`` method. The ``locations`` parameter defines interval boundaries, and the method returns probability treatment effects for each interval.
134+
For each interval, the starting point is not included but the ending point is included. For example, if the `locations` is [0, 1, 2], PTE is computed for `(0, 1]` and `(1, 2]`.
134135

135136
.. code-block:: python
136137
137-
pte, lower_bound, upper_bound = estimator.predict_pte(target_treatment_arm=1, control_treatment_arm=0, width=1, locations=locations, variance_type="simple")
138-
plot(locations, pte, lower_bound, upper_bound, chart_type="bar", title="PTE of adjusted estimator with simple confidence band")
138+
pte, lower_bound, upper_bound = estimator.predict_pte(target_treatment_arm=1, control_treatment_arm=0, locations=locations, variance_type="simple")
139+
# Note: pte will have shape (len(locations)-1,) since it computes intervals between locations
140+
plot(locations[:-1], pte, lower_bound, upper_bound, chart_type="bar", title="PTE of adjusted estimator with simple confidence band")
139141
140142
.. image:: _static/pte_empirical.png
141143
:alt: PTE of adjusted estimator with simple confidence band
142144
:height: 300px
143145
:width: 450px
144146
:align: center
145147

146-
To compute QTE, we use ``predict_qte`` method. The confidence band is computed by bootstrap method.
148+
To compute QTE, you can use ``predict_qte`` method. The confidence band is computed by bootstrap method.
147149

148150
.. code-block:: python
149151

0 commit comments

Comments
 (0)