Skip to content

Commit 53fcff5

Browse files
authored
Update register (#272)
* update register * fix test * update pymc branch
1 parent 9f8832d commit 53fcff5

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

pymc_bart/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import pymc as pm
14+
from pymc.sampling import mcmc
1515

1616
from pymc_bart.bart import BART
1717
from pymc_bart.pgbart import PGBART
@@ -47,4 +47,6 @@
4747
__version__ = "0.11.0"
4848

4949

50-
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
50+
methods = mcmc.STEP_METHODS
51+
if not any(method is PGBART for method in methods):
52+
methods.append(PGBART)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pytensor>=3.0.0
2-
pymc @ git+https://github.com/pymc-devs/pymc.git@v6
2+
pymc @ git+https://github.com/pymc-devs/pymc.git
33
arviz-stats[xarray]>=1.1.0
44
numba
55
matplotlib

tests/test_bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_shared_variable(response):
9797
idata = pm.sample(tune=100, draws=100, chains=2, random_seed=3415)
9898
ppc = pm.sample_posterior_predictive(idata)
9999
pm.set_data({"data_X": X[:3]})
100-
ppc2 = pm.sample_posterior_predictive(idata)
100+
ppc2 = pm.sample_posterior_predictive(idata, sample_vars=["mu", "y"])
101101

102102
assert ppc.posterior_predictive["y"].shape == (2, 100, 50)
103103
assert ppc2.posterior_predictive["y"].shape == (2, 100, 3)

0 commit comments

Comments
 (0)