Skip to content

Commit 4525331

Browse files
authored
Use PyMC v6, Pytensor v3 and ArviZ 1.1 (#269)
* use v6, v3 and 1.1 * update python version
1 parent 02aa8e6 commit 4525331

4 files changed

Lines changed: 14 additions & 13 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ["3.11", "3.12", "3.13"]
14+
python-version: ["3.12", "3.13", "3.14"]
1515
linker: [cvm, numba]
1616

1717
name: Tests py${{ matrix.python-version }} ${{ matrix.linker }}

pymc_bart/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def plot_convergence(
9090
9191
Parameters
9292
----------
93-
idata : InferenceData
94-
InferenceData object containing the posterior samples.
93+
idata : DataTree
94+
DataTree object containing the posterior samples.
9595
var_name : Optional[str]
9696
Name of the BART variable to plot. Defaults to None.
9797
kind : str
@@ -683,8 +683,8 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
683683
684684
Parameters
685685
----------
686-
idata : InferenceData
687-
InferenceData with a variable "variable_inclusion" in ``sample_stats`` group
686+
idata : DataTree
687+
DataTree with a variable "variable_inclusion" in ``sample_stats`` group
688688
X : npt.NDArray
689689
The covariate matrix.
690690
model : Optional[pm.Model]
@@ -745,8 +745,8 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
745745
746746
Parameters
747747
----------
748-
idata : InferenceData
749-
InferenceData containing a collection of BART_trees in sample_stats group
748+
idata : DataTree
749+
DataTree containing a collection of BART_trees in sample_stats group
750750
X : npt.NDArray
751751
The covariate matrix.
752752
labels : Optional[list[str]]
@@ -813,8 +813,8 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
813813
814814
Parameters
815815
----------
816-
idata : InferenceData
817-
InferenceData containing a "variable_inclusion" variable in the sample_stats group.
816+
idata : DataTree
817+
DataTree containing a "variable_inclusion" variable in the sample_stats group.
818818
bartrv : BART Random Variable
819819
BART variable once the model that include it has been fitted.
820820
X : npt.NDArray

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
pymc>=5.24.0
2-
arviz-stats[xarray]>=0.6.0
1+
pytensor @ git+https://github.com/pymc-devs/pytensor.git@v3
2+
pymc @ git+https://github.com/pymc-devs/pymc.git@v6
3+
arviz-stats[xarray]>=1.1.0
34
numba
45
matplotlib
56
numpy>=2.0

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
"Development Status :: 5 - Production/Stable",
3030
"Programming Language :: Python",
3131
"Programming Language :: Python :: 3",
32-
"Programming Language :: Python :: 3.11",
3332
"Programming Language :: Python :: 3.12",
3433
"Programming Language :: Python :: 3.13",
34+
"Programming Language :: Python :: 3.14",
3535
"License :: OSI Approved :: Apache Software License",
3636
"Intended Audience :: Science/Research",
3737
"Topic :: Scientific/Engineering",
@@ -76,6 +76,6 @@ def get_version():
7676
packages=find_packages(),
7777
include_package_data=True,
7878
classifiers=classifiers,
79-
python_requires=">=3.11",
79+
python_requires=">=3.12",
8080
install_requires=install_reqs,
8181
)

0 commit comments

Comments
 (0)