Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ jobs:
name: build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- uses: actions/setup-python@v5
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: "3.12"
- name: Build the sdist and the wheel
Expand All @@ -41,7 +41,7 @@ jobs:
echo "Checking import and version number (on release)"
venv-bdist/bin/python -c "import pymc_extras as pmx; assert pmx.__version__ == '${{ github.ref_name }}'[1:] if '${{ github.ref_type }}' == 'tag' else pmx.__version__; print(pmx.__version__)"
cd ..
- uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: artifact
path: dist/*
Expand All @@ -58,8 +58,8 @@ jobs:
# write id-token is necessary for trusted publishing (OIDC)
id-token: write
steps:
- uses: actions/download-artifact@v4
- uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
name: artifact
path: dist
- uses: pypa/gh-action-pypi-publish@release/v1
- uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
16 changes: 9 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ jobs:
run:
shell: bash -leo pipefail {0}
steps:
- uses: actions/checkout@v4
- uses: mamba-org/setup-micromamba@v2
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: mamba-org/setup-micromamba@d7c9bd84e824b79d2af72a2d4196c7f4300d3476 # v3.0.0
with:
environment-file: conda-envs/environment-test.yml
create-args: >-
Expand All @@ -62,7 +62,7 @@ jobs:
run: |
python -m pytest --color=yes -vv --cov=pymc_extras --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
env_vars: TEST_SUBSET
name: ${{ matrix.os }}
Expand All @@ -89,11 +89,13 @@ jobs:
run:
shell: cmd /C call {0}
steps:
- uses: actions/checkout@v4
- uses: mamba-org/setup-micromamba@v2
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: mamba-org/setup-micromamba@d7c9bd84e824b79d2af72a2d4196c7f4300d3476 # v3.0.0
with:
environment-file: conda-envs/environment-test.yml
micromamba-version: "latest"
# Pinned: micromamba 2.6.1-0 ships a Windows binary with a missing DLL
# dependency that breaks GitHub-hosted runners (STATUS_DLL_NOT_FOUND).
micromamba-version: "2.6.0-0"
create-args: >-
python=${{matrix.python-version}}
environment-name: pymc-extras-test
Expand All @@ -110,7 +112,7 @@ jobs:
run: >-
python -m pytest --color=yes -vv --cov=pymc_extras --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
env_vars: TEST_SUBSET
name: ${{ matrix.os }}
Expand Down
6 changes: 3 additions & 3 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- nodefaults
dependencies:
- scikit-learn
- better-optimize>=0.3.2
- better-optimize>=0.4.1
- dask<2025.1.1
- xhistogram
- statsmodels
Expand All @@ -13,9 +13,9 @@ dependencies:
- pytest-cov
- pydantic>=2.0.0
- h5netcdf
- pymc>=6.0,<7.0
- pip
- pip:
- jax
- blackjax
- pymc @ git+https://github.com/pymc-devs/pymc.git
- preliz @ git+https://github.com/arviz-devs/preliz.git
- preliz>=0.25
4 changes: 2 additions & 2 deletions pymc_extras/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass, field, replace
from enum import Enum, auto
from typing import Any, Literal, Self, TypeAlias
from typing import Any, Literal, Self

import numpy as np
import pymc as pm
Expand Down Expand Up @@ -79,7 +79,7 @@
REGULARISATION_TERM = 1e-8
DEFAULT_LINKER = "cvm_nogc"

SinglePathfinderFn: TypeAlias = Callable[[int], "PathfinderResult"]
type SinglePathfinderFn = Callable[[int], "PathfinderResult"]


def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Callable:
Expand Down
11 changes: 10 additions & 1 deletion pymc_extras/linearmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import pandas as pd
import pymc as pm

from sklearn.base import BaseEstimator
# If scikit-learn is available, inherit from BaseEstimator for sklearn-pipeline interop
# (Pipeline, TransformedTargetRegressor, get_params). Without it, fall back to a stub
# so LinearModel still works as a standalone Bayesian model.
try:
from sklearn.base import BaseEstimator
except ImportError:

class BaseEstimator:
pass


from pymc_extras.model_builder import ModelBuilder

Expand Down
4 changes: 2 additions & 2 deletions pymc_extras/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def custom_transform(x):
from functools import partial
from inspect import signature
from numbers import Number
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, runtime_checkable
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

import numpy as np
import pymc as pm
Expand All @@ -111,7 +111,7 @@ def custom_transform(x):
from pytensor.tensor import TensorLike
from pytensor.xtensor.type import XTensorVariable

XTensorLike: TypeAlias = TensorLike | DataArray
type XTensorLike = TensorLike | DataArray


class UnsupportedShapeError(Exception):
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ classifiers = [
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"License :: OSI Approved :: Apache Software License",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
Expand All @@ -34,11 +34,11 @@ keywords = [
license = {file = "LICENSE"}
dynamic = ["version"] # specify the version in the __init__.py file
dependencies = [
"pymc@git+https://github.com/pymc-devs/pymc.git",
"scikit-learn",
"better-optimize>=0.3.1",
"pymc>=6.0,<7.0",
"arviz>=1.1",
"better-optimize>=0.4.1,<1.0",
"pydantic>=2.0.0",
"preliz>=0.20.0",
"preliz>=0.25.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -131,7 +131,7 @@ exclude_lines = [

[tool.ruff]
line-length = 100
target-version = "py311"
target-version = "py312"

[tool.ruff.format]
docstring-code-format = true
Expand Down
18 changes: 9 additions & 9 deletions tests/statespace/core/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
floatX = pytensor.config.floatX
nile = load_nile_test_data()
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
mock_pymc_sample = pytest.fixture(scope="module")(mock_sample_setup_and_teardown)


def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
Expand Down Expand Up @@ -393,7 +393,7 @@ def pymc_mod_time_varying(ss_mod_time_varying, rng):
return m


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata(pymc_mod, rng, mock_pymc_sample):
with pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -403,7 +403,7 @@ def idata(pymc_mod, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
with exog_pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -412,7 +412,7 @@ def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_exog_mv(exog_pymc_mod_mv, rng, mock_pymc_sample):
with exog_pymc_mod_mv:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -421,7 +421,7 @@ def idata_exog_mv(exog_pymc_mod_mv, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
with pymc_mod_no_exog:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -430,7 +430,7 @@ def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_no_exog_mv(pymc_mod_no_exog_mv, rng, mock_pymc_sample):
with pymc_mod_no_exog_mv:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -439,7 +439,7 @@ def idata_no_exog_mv(pymc_mod_no_exog_mv, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_no_exog_mv_dt(pymc_mod_no_exog_mv_dt, rng, mock_pymc_sample):
with pymc_mod_no_exog_mv_dt:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -448,7 +448,7 @@ def idata_no_exog_mv_dt(pymc_mod_no_exog_mv_dt, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
with pymc_mod_no_exog_dt:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
Expand All @@ -457,7 +457,7 @@ def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_time_varying(pymc_mod_time_varying, rng, mock_pymc_sample):
"""Inference data for time-varying model."""
with pymc_mod_time_varying:
Expand Down
6 changes: 3 additions & 3 deletions tests/statespace/core/test_statespace_JAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
nile = load_nile_test_data()
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES

mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
mock_pymc_sample = pytest.fixture(scope="module")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -68,7 +68,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
return m


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata(pymc_mod, rng, mock_pymc_sample):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -91,7 +91,7 @@ def idata(pymc_mod, rng, mock_pymc_sample):
return idata


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down
Loading