diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index df5ff8d10b..655776e963 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -139,6 +139,7 @@ jobs: - | tests/dims/distributions/test_core.py + tests/dims/distributions/test_censored.py tests/dims/distributions/test_scalar.py tests/dims/distributions/test_vector.py tests/dims/test_model.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1ef3a4227..2e69aa458a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.13 hooks: - - id: ruff + - id: ruff-check args: [--fix, --show-fixes] - id: ruff-format - repo: local diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml index 7154296e0e..2246655242 100644 --- a/conda-envs/environment-alternative-backends.yml +++ b/conda-envs/environment-alternative-backends.yml @@ -22,7 +22,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index ad1166c470..c96aa49c58 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 19b8da5352..a1dd4a1151 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -11,7 +11,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - rich>=13.7.1 - scipy>=1.4.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index b63313bbdd..e17bc802b1 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -14,7 +14,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 8411cfb603..d942055b70 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -12,7 +12,7 @@ dependencies: - numpy>=1.25.0 - pandas>=0.24.0 - pip -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 1be6609a8d..e54eb38be6 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -15,7 +15,7 @@ dependencies: - pandas>=0.24.0 - pip - polyagamma -- pytensor>=2.38.0,<2.39 +- pytensor>=2.38.2,<2.39 - python-graphviz - networkx - rich>=13.7.1 diff --git a/docs/source/api/dims/distributions.rst b/docs/source/api/dims/distributions.rst index 229980b73a..cc34d68deb 100644 --- a/docs/source/api/dims/distributions.rst +++ b/docs/source/api/dims/distributions.rst @@ -39,3 +39,14 @@ Vector distributions Categorical MvNormal ZeroSumNormal + + +Higher-Order distributions +========================== + +.. currentmodule:: pymc.dims +.. 
autosummary:: + :toctree: generated/ + :template: distribution.rst + + Censored diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..6bf7b9c8e3 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -333,7 +333,7 @@ def sample_stats_to_xarray(self): data_warmup = {} for stat in self.trace.stat_names: name = rename_key.get(stat, stat) - if name == "tune": + if name in {"tune", "in_warmup"}: continue if self.warmup_trace: data_warmup[name] = np.array( diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 993acc0df4..528552650f 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -113,7 +113,13 @@ def point(self, idx: int) -> dict[str, np.ndarray]: """ raise NotImplementedError() - def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]): + def record( + self, + draw: Mapping[str, np.ndarray], + stats: Sequence[Mapping[str, Any]], + *, + in_warmup: bool, + ): """Record results of a sampling iteration. Parameters @@ -122,6 +128,9 @@ def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, An Values mapped to variable names stats: list of dicts The diagnostic values for each sampler + in_warmup: bool + Whether this draw belongs to the warmup phase. This is a driver-owned + concept and is intended for storage/backends to persist warmup information. """ raise NotImplementedError() diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index d02a6dbebb..891afa51f9 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -34,7 +34,6 @@ BlockedStep, CompoundStep, StatsBijection, - check_step_emits_tune, flat_statname, flatten_steps, ) @@ -106,16 +105,26 @@ def __init__( {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats} for sstats in stats_bijection._stat_groups ] + if "in_warmup" in stats_dtypes and self.sampler_vars: + # Expose driver-owned warmup marker via the sampler-stats API. + self.sampler_vars[0].setdefault("in_warmup", stats_dtypes["in_warmup"]) self._chain = chain self._point_fn = point_fn self._statsbj = stats_bijection super().__init__() - def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]): + def record( + self, + draw: Mapping[str, np.ndarray], + stats: Sequence[Mapping[str, Any]], + *, + in_warmup: bool, + ): values = self._point_fn(draw) value_dict = dict(zip(self.varnames, values)) stats_dict = self._statsbj.map(stats) + stats_dict["in_warmup"] = bool(in_warmup) # Apply pickling to objects stats for fname in self._statsbj.object_stats.keys(): val_bytes = pickle.dumps(stats_dict[fname]) @@ -148,6 +157,9 @@ def get_sampler_stats( self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1 ) -> np.ndarray: slc = slice(burn, None, thin) + if stat_name in {"in_warmup", "tune"}: + # Backwards-friendly alias for users that might try "tune". + return self._get_stats("in_warmup", slc) # When there's just one sampler, default to remove the sampler dimension if sampler_idx is None and self._statsbj.n_samplers == 1: sampler_idx = 0 @@ -210,8 +222,6 @@ def make_runmeta_and_point_fn( ) -> tuple[mcb.RunMeta, PointFunc]: variables, point_fn = get_variables_and_point_fn(model, initial_point) - check_step_emits_tune(step) - # In PyMC the sampler stats are grouped by the sampler. sample_stats = [] steps = flatten_steps(step) @@ -235,6 +245,16 @@ def make_runmeta_and_point_fn( ) sample_stats.append(svar) + # driver owned warmup marker. stored once per draw. 
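+    # The marker is a scalar bool owned by the sampling driver (it is not
+    # emitted by any step method), which is why it is appended here instead
+    # of being collected from the steps' stats_dtypes_shapes.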
+ sample_stats.append( + mcb.Variable( + name="in_warmup", + dtype=np.dtype(bool).name, + shape=[], + undefined_ndim=False, + ) + ) + coordinates = [ mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals))) for dname, cvals in model.coords.items() diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index a08fc8f47e..5d8d1be62b 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None: new = np.zeros(draws, dtype=dtype) data[varname] = np.concatenate([old, new]) - def record(self, point, sampler_stats=None) -> None: + def record(self, point, sampler_stats=None, *, in_warmup: bool) -> None: """Record results of a sampling iteration. Parameters @@ -238,5 +238,5 @@ def point_fun(point): chain.fn = point_fun for point in point_list: - chain.record(point) + chain.record(point, in_warmup=False) return MultiTrace([chain]) diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index 9b7664c504..c32d01bff5 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -159,7 +159,11 @@ def buffer(self, group, var_name, value): buffer[var_name].append(value) def record( - self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]] + self, + draw: Mapping[str, np.ndarray], + stats: Sequence[Mapping[str, Any]], + *, + in_warmup: bool, ) -> bool | None: """Record the step method's returned draw and stats. @@ -185,6 +189,7 @@ def record( self.buffer(group="posterior", var_name=var_name, value=var_value) for var_name, var_value in self.stats_bijection.map(stats).items(): self.buffer(group="sample_stats", var_name=var_name, value=var_value) + self.buffer(group="sample_stats", var_name="in_warmup", value=bool(in_warmup)) self._buffered_draws += 1 if self._buffered_draws == self.draws_until_flush: self.flush() @@ -525,6 +530,7 @@ def init_trace( stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps( [step] if isinstance(step, BlockedStep) else step.methods ) + stats_dtypes_shapes = {"in_warmup": (bool, [])} | stats_dtypes_shapes self.init_group_with_empty( group=self.root.create_group(name="sample_stats", overwrite=True), var_dtype_and_shape=stats_dtypes_shapes, @@ -683,6 +689,7 @@ def init_group_with_empty( for i, shape_i in enumerate(shape): dim = f"{name}_dim_{i}" dims.append(dim) + assert shape_i is not None, f"{dim} shape is None" group_coords[dim] = np.arange(shape_i, dtype="int") dims = ("chain", "draw", *dims) attrs = extra_var_attrs[name] if extra_var_attrs is not None else {} diff --git a/pymc/dims/distributions/__init__.py b/pymc/dims/distributions/__init__.py index da85bc1463..6c49789089 100644 --- a/pymc/dims/distributions/__init__.py +++ b/pymc/dims/distributions/__init__.py @@ -11,5 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pymc.dims.distributions.censored import Censored from pymc.dims.distributions.scalar import * from pymc.dims.distributions.vector import * diff --git a/pymc/dims/distributions/censored.py b/pymc/dims/distributions/censored.py new file mode 100644 index 0000000000..29d9a56136 --- /dev/null +++ b/pymc/dims/distributions/censored.py @@ -0,0 +1,51 @@ +# Copyright 2026 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from pymc.dims.distributions.core import DimDistribution, copy_docstring, expand_dist_dims
+from pymc.distributions.censored import Censored as RegularCensored
+
+
+@copy_docstring(RegularCensored)
+class Censored(DimDistribution):
+    @classmethod
+    def dist(cls, dist, *, lower=None, upper=None, dim_lengths, **kwargs):
+        if lower is None:
+            lower = -np.inf
+        if upper is None:
+            upper = np.inf
+        return super().dist([dist, lower, upper], dim_lengths=dim_lengths, **kwargs)
+
+    @classmethod
+    def xrv_op(cls, dist, lower, upper, core_dims=None, extra_dims=None, rng=None):
+        if extra_dims is None:
+            extra_dims = {}
+
+        dist = cls._as_xtensor(dist)
+        lower = cls._as_xtensor(lower)
+        upper = cls._as_xtensor(upper)
+
+        # Any dimensions in extra_dims, or present only in lower/upper, must
+        # propagate back to the underlying dist as `extra_dims`.
+        bounds_sizes = lower.sizes | upper.sizes
+        dist_dims_set = set(dist.dims)
+        extra_dist_dims = extra_dims | {
+            dim: size for dim, size in bounds_sizes.items() if dim not in dist_dims_set
+        }
+        if extra_dist_dims:
+            dist = expand_dist_dims(dist, extra_dist_dims)
+
+        # The log-probability is inferred from the clip operation.
+        # TODO: Make this a SymbolicRandomVariable that can itself be resized
+        return dist.clip(lower, upper)
diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py
index aee8e4cf7f..918b28a794 100644
--- a/pymc/dims/distributions/core.py
+++ b/pymc/dims/distributions/core.py
@@ -13,7 +13,7 @@
 # limitations under the License.
from collections.abc import Callable, Sequence from itertools import chain -from typing import cast +from typing import Any, cast import numpy as np @@ -25,7 +25,9 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.xtensor import as_xtensor from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor +from pytensor.xtensor.shape import Transpose from pytensor.xtensor.type import XTensorVariable +from pytensor.xtensor.vectorization import XRV from pymc import SymbolicRandomVariable, modelcontext from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform @@ -345,3 +347,25 @@ class UnitDimDistribution(DimDistribution): """Base class for unit-valued distributions.""" default_transform = log_odds_transform + + +def expand_dist_dims(dist: XTensorVariable, extra_dims: dict[str, Any]) -> XTensorVariable: + if overlap := (set(extra_dims) & set(dist.dims)): + raise ValueError(f"extra_dims already present in distribution: {sorted(overlap)}") + + op = None if dist.owner is None else dist.owner.op + match op: + case XRV(): + # Recreate dist with new extra dims + dist_props = dist.owner.op._props_dict() + dist_props["extra_dims"] = (*(extra_dims.keys()), *dist_props["extra_dims"]) + new_dist_op = type(dist.owner.op)(**dist_props) + _old_rng, *params_and_dim_lengths = dist.owner.inputs + new_rng = None # We don't propagate the old RNG, because we don't want the new and old dists to be correlated + return new_dist_op(new_rng, *extra_dims.values(), *params_and_dim_lengths) + case Transpose(): + return expand_dist_dims(dist.owner.inputs[0], extra_dims=extra_dims).transpose( + ..., *dist.dims + ) + case _: + raise NotImplementedError(f"expand_dist_dims not implemented for {dist} with op {op}") diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py index 14dbb9444f..e898ab0508 100644 --- a/pymc/dims/distributions/transforms.py +++ b/pymc/dims/distributions/transforms.py @@ -203,6 +203,4 @@ def backward(self, value, *rv_inputs): return value def log_jac_det(self, value, *rv_inputs): - # Use following once broadcast_like is implemented - # as_xtensor(0).broadcast_like(value, exclude=self.dims)` - return value.sum(self.dims) * 0 + return as_xtensor(0.0).broadcast_like(value, exclude=self.dims) diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py index 7107712e67..0995f891e8 100644 --- a/pymc/dims/distributions/vector.py +++ b/pymc/dims/distributions/vector.py @@ -229,7 +229,8 @@ def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs): raise ValueError("ZeroSumNormal requires atleast 1 core_dims") support_dims = as_xtensor( - as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",) + as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), + dims=("__support_shape__",), ) sigma = cls._as_xtensor(sigma) @@ -238,16 +239,25 @@ def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs): ) @classmethod - def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None): - sigma = as_xtensor(sigma) - support_dims = as_xtensor(support_dims, dims=("_",)) - support_shape = support_dims.values - core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op + def xrv_op(cls, sigma, support_shape, core_dims, extra_dims=None, rng=None): + # ZeroSumNormal expects dummy dimensions on sigma for the support_shape + sigma = cls._as_xtensor(sigma).expand_dims(core_dims) + support_shape = as_xtensor(support_shape, 
dims=("__support_shape__",)) + core_rv = ZeroSumNormalRV.rv_op( + sigma=sigma.values, support_shape=support_shape.values + ).owner.op + core_dims_map = tuple(range(1, len(core_dims) + 1)) xop = pxr.as_xrv( core_rv, - core_inps_dims_map=[(), (0,)], - core_out_dims_map=tuple(range(1, len(core_dims) + 1)), + core_inps_dims_map=[core_dims_map, (0,)], + core_out_dims_map=core_dims_map, ) # Dummy "_" core dim to absorb the support_shape vector # If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed - return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng) + return xop( + sigma, + support_shape, + core_dims=("__support_shape__", *core_dims), + extra_dims=extra_dims, + rng=rng, + ) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index b434a6db25..402ce1d252 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -392,7 +392,7 @@ def make_node(self, *inputs): ) if size_arg_idx is not None and len(rng_arg_idxs) == 1: new_size_type = normalize_size_param(inputs[size_arg_idx]).type - if not self.input_types[size_arg_idx].in_same_class(new_size_type): + if not self.input_types[size_arg_idx].is_super(new_size_type): params = [inputs[idx] for idx in param_idxs] size = inputs[size_arg_idx] rng = inputs[rng_arg_idxs[0]] @@ -405,7 +405,7 @@ def update(self, node: Apply) -> dict[Variable, Variable]: Returns a dictionary with the symbolic expressions required for correct updating of random state input variables repeated function evaluations. This is used by - `pytensorf.compile_pymc`. + `pytensorf.compile`. """ return collect_default_updates_inner_fgraph(node) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 192fff6e30..c43b43fbe6 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -952,6 +952,8 @@ class WishartRV(RandomVariable): @classmethod def rng_fn(cls, rng, nu, V, size): scipy_size = size if size else 1 # Default size for Scipy's wishart.rvs is 1 + # Scipy doesn't accept batch nu or V + nu = _squeeze_to_ndim(nu, 0) V = _squeeze_to_ndim(V, 2) result = stats.wishart.rvs(int(nu), V, size=scipy_size, random_state=rng) if size == (1,): @@ -996,8 +998,17 @@ class Wishart(Continuous): Notes ----- - This distribution is unusable in a PyMC model. You should instead - use LKJCholeskyCov or LKJCorr. + The Wishart distribution is generally unusable as a prior distribution for + MCMC sampling. The probability of sampling a symmetric positive definite + matrix is effectively zero, since MCMC proposals in unconstrained space + almost never land exactly on the SPD manifold. + + For modeling covariance matrices, you should instead use + :class:`LKJCholeskyCov` or :class:`LKJCorr`. + + However, the Wishart distribution may be used as a likelihood with + ``observed`` in some cases, where the distribution is evaluated at + fixed observed values rather than sampled during MCMC. 
""" rv_op = wishart @@ -2659,9 +2670,9 @@ class ZeroSumNormalRV(SymbolicRandomVariable): @classmethod def rv_op(cls, sigma, support_shape, *, size=None, rng=None): + support_shape = _squeeze_to_ndim(pt.as_tensor(support_shape), ndim=1) n_zerosum_axes = pt.get_vector_length(support_shape) sigma = pt.as_tensor(sigma) - support_shape = pt.as_tensor(support_shape, ndim=1) rng = normalize_rng_param(rng) size = normalize_size_param(size) @@ -2677,8 +2688,10 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None): for axis in range(n_zerosum_axes): zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True) + # sigma has core_shape = (1, 1, ...) (as many as there are zerosum axes) + ones = ",".join("1" for _ in range(n_zerosum_axes)) support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)]) - extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})" + extended_signature = f"[rng],[size],({ones}),(s)->[rng],({support_str})" return cls( inputs=[rng, size, sigma, support_shape], outputs=[next_rng, zerosum_rv], @@ -2774,7 +2787,7 @@ def __new__(cls, *args, n_zerosum_axes=None, support_shape=None, dims=None, **kw def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs): n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes) - sigma = pt.as_tensor(sigma) + sigma = pt.atleast_Nd(pt.as_tensor(sigma), n=n_zerosum_axes) if not all(sigma.type.broadcastable[-n_zerosum_axes:]): raise ValueError("sigma must have length one across the zero-sum axes") diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 5e3f934953..c56cba8a55 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -309,7 +309,7 @@ def backward(self, value, *rv_inputs): return value def log_jac_det(self, value, *rv_inputs): - return pt.constant(0.0) + return value.sum(self.zerosum_axes).zeros_like() log_exp_m1 = LogExpM1() diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 4f6eef2934..7ba9618584 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -144,7 +144,7 @@ def logp(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) -> print(rv_logp.eval({value: 0.9, mu: 0.0})) # -1.32393853 # Compile a function for repeated evaluations - rv_logp_fn = pm.compile_pymc([value, mu], rv_logp) + rv_logp_fn = pm.compile([value, mu], rv_logp) print(rv_logp_fn(value=0.9, mu=0.0)) # -1.32393853 @@ -166,7 +166,7 @@ def logp(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) -> print(exp_rv_logp.eval({value: 0.9, mu: 0.0})) # -0.81912844 # Compile a function for repeated evaluations - exp_rv_logp_fn = pm.compile_pymc([value, mu], exp_rv_logp) + exp_rv_logp_fn = pm.compile([value, mu], exp_rv_logp) print(exp_rv_logp_fn(value=0.9, mu=0.0)) # -0.81912844 @@ -244,7 +244,7 @@ def logcdf(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) print(rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.2034146 # Compile a function for repeated evaluations - rv_logcdf_fn = pm.compile_pymc([value, mu], rv_logcdf) + rv_logcdf_fn = pm.compile([value, mu], rv_logcdf) print(rv_logcdf_fn(value=0.9, mu=0.0)) # -0.2034146 @@ -266,7 +266,7 @@ def logcdf(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) print(exp_rv_logcdf.eval({value: 0.9, mu: 0.0})) # -0.78078813 # Compile a function for repeated evaluations - exp_rv_logcdf_fn = pm.compile_pymc([value, mu], exp_rv_logcdf) + exp_rv_logcdf_fn = pm.compile([value, mu], exp_rv_logcdf) print(exp_rv_logcdf_fn(value=0.9, mu=0.0)) # -0.78078813 @@ 
-349,7 +349,7 @@ def logccdf(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) print(rv_logccdf.eval({value: 0.9, mu: 0.0})) # -1.5272506 # Compile a function for repeated evaluations - rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf) + rv_logccdf_fn = pm.compile([value, mu], rv_logccdf) print(rv_logccdf_fn(value=0.9, mu=0.0)) # -1.5272506 """ @@ -410,7 +410,7 @@ def icdf(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) -> print(rv_icdf.eval({value: 0.9, mu: 0.0})) # 1.28155157 # Compile a function for repeated evaluations - rv_icdf_fn = pm.compile_pymc([value, mu], rv_icdf) + rv_icdf_fn = pm.compile([value, mu], rv_icdf) print(rv_icdf_fn(value=0.9, mu=0.0)) # 1.28155157 @@ -432,7 +432,7 @@ def icdf(rv: Variable, value: Variable | TensorLike, warn_rvs=True, **kwargs) -> print(exp_rv_icdf.eval({value: 0.9, mu: 0.0})) # 3.60222448 # Compile a function for repeated evaluations - exp_rv_icdf_fn = pm.compile_pymc([value, mu], exp_rv_icdf) + exp_rv_icdf_fn = pm.compile([value, mu], exp_rv_icdf) print(exp_rv_icdf_fn(value=0.9, mu=0.0)) # 3.60222448 """ diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 2e1b96d343..8e28076406 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -16,6 +16,7 @@ from collections.abc import Sequence import numpy as np +import pytensor.tensor as pt from pytensor.graph import Apply, Op from pytensor.graph.features import AlreadyThere, Feature @@ -113,10 +114,18 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) ) # Check there is no broadcasting between logp and jacobian if logp.type.broadcastable != log_jac_det.type.broadcastable: - raise ValueError( - f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " - "There is a bug in the implementation of either one." - ) + lb, jb = logp.type.broadcastable, log_jac_det.type.broadcastable + broadcastable_axes = [ + i for i, (ai, bi) in enumerate(zip(lb, jb, strict=True)) if ai or bi + ] + try: + logp = pt.specify_broadcastable(logp, *broadcastable_axes) + log_jac_det = pt.specify_broadcastable(log_jac_det, *broadcastable_axes) + except ValueError as err: + raise ValueError( + f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. " + "There is a bug in the implementation of either one." + ) from err if use_jacobian: if value.name: diff --git a/pymc/model/core.py b/pymc/model/core.py index 38e9f2711f..729752abeb 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1638,7 +1638,7 @@ def compile_fn( Whether to wrap the compiled function in a PointFunc, which takes a Point dictionary with model variable names and values as input. Other keyword arguments : - Any other keyword argument is sent to :py:func:`pymc.pytensorf.compile_pymc`. + Any other keyword argument is sent to :py:func:`pymc.pytensorf.compile`. 
Returns ------- diff --git a/pymc/model_graph.py b/pymc/model_graph.py index d673002c28..7830557262 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -203,14 +203,25 @@ def get_node_type(var_name: str, model) -> NodeType: def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTypeFormatterMapping: node_formatters = {**DEFAULT_NODE_FORMATTERS, **node_formatters} - unknown_keys = set(node_formatters.keys()) - set(NodeType) + # Allow both NodeType enum members and their string values + valid_keys = set(NodeType) | {node_type.value for node_type in NodeType} + unknown_keys = set(node_formatters.keys()) - valid_keys if unknown_keys: raise ValueError( f"Node formatters must be of type NodeType. Found: {list(unknown_keys)}." f" Please use one of {[node_type.value for node_type in NodeType]}." ) - return node_formatters + # Convert string keys to enum keys for consistent handling + normalized_formatters = {} + for key, formatter in node_formatters.items(): + if isinstance(key, str): + # Convert string to enum member + normalized_formatters[NodeType(key)] = formatter + else: + normalized_formatters[key] = formatter + + return normalized_formatters AddNode = Callable[[str, GraphvizNodeKwargs], None] diff --git a/pymc/printing.py b/pymc/printing.py index 7969799cc6..d253c588ce 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -21,11 +21,11 @@ from pytensor.graph.basic import Constant, Variable from pytensor.graph.traversal import walk from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.random.basic import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable +from pymc.logprob.abstract import MeasurableOp from pymc.model import Model __all__ = [ @@ -41,21 +41,24 @@ def str_for_dist(dist: Variable, formatting: str = "plain", include_params: bool This can be either LaTeX or plain, optionally with distribution parameter values included. 
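+
+    For example (a sketch; the exact rendering of constant arguments may
+    vary)::
+
+        with pm.Model():
+            x = pm.Normal("x", 0, 1)
+        str_for_dist(x)  # plain format -> "x ~ Normal(0, 1)"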
""" + dist_op = dist.owner.op + if include_params: - if isinstance(dist.owner.op, RandomVariable) or getattr( - dist.owner.op, "extended_signature", None - ): - dist_args = [ - _str_for_input_var(x, formatting=formatting) - for x in dist.owner.op.dist_params(dist.owner) - ] - else: + try: + dist_args = dist.owner.op.dist_params(dist.owner) + except Exception: + # Can happen with SymbolicRandomVariable without extended_signature dist_args = [ - _str_for_input_var(x, formatting=formatting) - for x in dist.owner.inputs - if not isinstance(x.type, RandomType | NoneTypeT) + x for x in dist.owner.inputs if not isinstance(x.type, RandomType | NoneTypeT) ] + dist_args_str = [_str_for_input_var(a, formatting=formatting) for a in dist_args] + + if (print_name := getattr(dist_op, "_print_name", None)) is not None: + dist_name = print_name[formatting == "latex"] + else: + dist_name = "Unknown" + print_name = dist.name if "latex" in formatting: @@ -63,30 +66,22 @@ def str_for_dist(dist: Variable, formatting: str = "plain", include_params: bool print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}" print_name = _format_underscore(print_name) - op_name = ( - dist.owner.op._print_name[1] - if hasattr(dist.owner.op, "_print_name") - else r"\\operatorname{Unknown}" - ) if include_params: - params = ",~".join([d.strip("$") for d in dist_args]) + params = ",~".join([d.strip("$") for d in dist_args_str]) if print_name: - return rf"${print_name} \sim {op_name}({params})$" + return rf"${print_name} \sim {dist_name}({params})$" else: - return rf"${op_name}({params})$" + return rf"${dist_name}({params})$" else: if print_name: - return rf"${print_name} \sim {op_name}$" + return rf"${print_name} \sim {dist_name}$" else: - return rf"${op_name}$" + return rf"${dist_name}$" else: # plain - dist_name = ( - dist.owner.op._print_name[0] if hasattr(dist.owner.op, "_print_name") else "Unknown" - ) if include_params: - params = ", ".join(dist_args) + params = ", ".join(dist_args_str) if print_name: return rf"{print_name} ~ {dist_name}({params})" else: @@ -169,10 +164,9 @@ def str_for_potential_or_deterministic( def _str_for_input_var(var: Variable, formatting: str) -> str: - # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable - def _is_potential_or_deterministic(var: Variable) -> bool: + # FIXME: This is an (insufficient) hack. 
For model_repr we know which nodes are named variables + # and we should propagate that information instead of guessing based on whether something was monkey-patched if not hasattr(var, "str_repr"): return False try: @@ -183,9 +177,7 @@ def _is_potential_or_deterministic(var: Variable) -> bool: if isinstance(var, Constant | SharedVariable): return _str_for_constant(var, formatting) - elif isinstance( - var.owner.op, RandomVariable | SymbolicRandomVariable - ) or _is_potential_or_deterministic(var): + elif isinstance(var.owner.op, MeasurableOp) or _is_potential_or_deterministic(var): # show the names for RandomVariables, Deterministics, and Potentials, rather # than the full expression return _str_for_input_rv(var, formatting) @@ -227,25 +219,23 @@ def _str_for_constant(var: Constant | SharedVariable, formatting: str) -> str: def _str_for_expression(var: Variable, formatting: str) -> str: # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable # construct a string like f(a1, ..., aN) listing all random variables a as arguments def _expand(x): - if x.owner and (not isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)): + if x.owner and not isinstance(x.owner.op, MeasurableOp): return reversed(x.owner.inputs) parents = [] names = [] for x in walk(nodes=var.owner.inputs, expand=_expand): assert isinstance(x, Variable) - if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable): + if x.owner and isinstance(x.owner.op, MeasurableOp): parents.append(x) xname = x.name if xname is None: # If the variable is unnamed, we show the op's name as we do # with constants - opname = x.owner.op.name - if opname is not None: + if (opname := getattr(x.owner.op, "name", None)) is not None: xname = rf"<{opname}>" assert xname is not None names.append(xname) diff --git a/pymc/progress_bar/rich_progress.py b/pymc/progress_bar/rich_progress.py index 6c980221ee..09607ebc1f 100644 --- a/pymc/progress_bar/rich_progress.py +++ b/pymc/progress_bar/rich_progress.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Iterable +from sys import stderr from typing import Any, Self from rich.box import SIMPLE_HEAD @@ -212,7 +213,7 @@ def _create_progress_bar( finished_style=Style.parse("rgb(31,119,180)"), ), *columns, - console=Console(theme=theme), + console=Console(file=stderr, theme=theme), include_headers=True, ) @@ -323,5 +324,5 @@ def RichSimpleProgress(theme: Theme | None): TimeRemainingColumn(), TextColumn("/"), TimeElapsedColumn(), - console=Console(theme=default_progress_theme if theme is None else theme), + console=Console(file=stderr, theme=default_progress_theme if theme is None else theme), ) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c9f3eb00f1..9db0db8ade 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -294,7 +294,7 @@ def draw( random_seed : int, RandomState or numpy_Generator, optional Seed for the random number generator. **kwargs : dict, optional - Keyword arguments for :func:`pymc.pytensorf.compile_pymc`. + Keyword arguments for :func:`pymc.pytensorf.compile`. Returns ------- @@ -410,7 +410,7 @@ def sample_prior_predictive( idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` compile_kwargs: dict, optional - Keyword arguments for :func:`pymc.pytensorf.compile_pymc`. + Keyword arguments for :func:`pymc.pytensorf.compile`. samples : int Number of samples from the prior predictive to generate. Deprecated in favor of `draws`. 
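All of the ``compile_kwargs``/``**kwargs`` docstrings touched in this file point
at the same target. As a minimal sketch (model and keyword values are
illustrative only), the forwarded kwargs end up in
:func:`pymc.pytensorf.compile`::

    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")
        # `mode` travels through **kwargs down to pm.compile
        samples = pm.draw(x, draws=10, mode="FAST_COMPILE")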
@@ -580,7 +580,7 @@ def sample_posterior_predictive( Keyword arguments for :func:`pymc.to_inference_data` if ``predictions=False`` or to :func:`pymc.predictions_to_inference_data` otherwise. compile_kwargs: dict, optional - Keyword arguments for :func:`pymc.pytensorf.compile_pymc`. + Keyword arguments for :func:`pymc.pytensorf.compile`. Returns ------- diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index b063fe8846..acf43b7b75 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1124,18 +1124,10 @@ def _sample_return( else: traces, length = _choose_chains(traces, 0) mtrace = MultiTrace(traces)[:length] - # count the number of tune/draw iterations that happened - # ideally via the "tune" statistic, but not all samplers record it! - if "tune" in mtrace.stat_names: - # Get the tune stat directly from chain 0, sampler 0 - stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0) - stat = tuple(stat) - n_tune = stat.count(True) - n_draws = stat.count(False) - else: - # these may be wrong when KeyboardInterrupt happened, but they're better than nothing - n_tune = min(tune, len(mtrace)) - n_draws = max(0, len(mtrace) - n_tune) + # Count the number of tune/draw iterations that happened. + # The warmup/draw boundary is owned by the sampling driver. + n_tune = min(tune, len(mtrace)) + n_draws = max(0, len(mtrace) - n_tune) if discard_tuned_samples: mtrace = mtrace[n_tune:] @@ -1304,7 +1296,7 @@ def _sample( try: for it, stats in enumerate(sampling_gen): progress_manager.update( - chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune + chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it < tune ) if not progress_manager.combined_progress or chain == progress_manager.chains - 1: @@ -1375,7 +1367,7 @@ def _iter_sample( step.stop_tuning() point, stats = step.step(point) - trace.record(point, stats) + trace.record(point, stats, in_warmup=i < tune) log_warning_stats(stats) if callback is not None: @@ -1488,7 +1480,7 @@ def _mp_sample( strace = traces[draw.chain] if not zarr_recording: # Zarr recording happens in each process - strace.record(draw.point, draw.stats) + strace.record(draw.point, draw.stats, in_warmup=draw.tuning) log_warning_stats(draw.stats) if callback is not None: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index cf3fd223e5..78df40581b 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -249,7 +249,7 @@ def _start_loop(self): raise KeyboardInterrupt() elif msg[0] == "write_next": if zarr_recording: - self._zarr_chain.record(point, stats) + self._zarr_chain.record(point, stats, in_warmup=tuning) self._write_point(point) is_last = draw + 1 == self._draws + self._tune self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats)) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 5bd1771704..7d1d9902f9 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -458,7 +458,7 @@ def _iter_population( # apply the update to the points and record to the traces for c, strace in enumerate(traces): points[c], stats = updates[c] - flushed = strace.record(points[c], stats) + flushed = strace.record(points[c], stats, in_warmup=i < tune) log_warning_stats(stats) if flushed and isinstance(strace, ZarrChain): sampling_state = popstep.request_sampling_state(c) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index deec9a4576..3189d13edc 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -414,7 +414,7 @@ def 
_build_trace_from_kernel_state( var_samples = np.round(var_samples).astype(var.dtype) value.append(var_samples.reshape(shape)) size += new_size - strace.record(point=dict(zip(varnames, value))) + strace.record(point=dict(zip(varnames, value)), in_warmup=False) return strace diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index a9cae903f0..389bd3b30a 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -92,6 +92,10 @@ def infer_warn_stats_info( sds[sname] = (dtype, None) elif sds: stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()}) + + # Even when a step method does not emit any stats, downstream components still assume one stats "slot" per step method. represent that with a single empty dict. + if not stats_dtypes: + stats_dtypes.append({}) return stats_dtypes, sds @@ -351,16 +355,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: return steps -def check_step_emits_tune(step: CompoundStep | BlockedStep): - if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes: - raise TypeError(f"{type(step)} does not emit the required 'tune' stat.") - elif isinstance(step, CompoundStep): - for sstep in step.methods: - if "tune" not in sstep.stats_dtypes_shapes: - raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.") - return - - class StatsBijection: """Map between a `list` of stats to `dict` of stats.""" diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 297b095e23..c3e6d75e5c 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.iter_count += 1 stats: dict[str, Any] = { - "tune": self.tune, "diverging": diverging, "divergences": self.divergences, "perf_counter_diff": perf_end - perf_start, diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 1697341bc8..57fd5219b1 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC): stats_dtypes_shapes = { "step_size": (np.float64, []), "n_steps": (np.int64, []), - "tune": (bool, []), "step_size_bar": (np.float64, []), "accept": (np.float64, []), "diverging": (bool, []), diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index c927d57e31..f674e852ee 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -110,7 +110,6 @@ class NUTS(BaseHMC): stats_dtypes_shapes = { "depth": (np.int64, []), "step_size": (np.float64, []), - "tune": (bool, []), "mean_tree_accept": (np.float64, []), "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index c042bc1f3d..6bae4d92c6 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -146,7 +146,6 @@ class Metropolis(ArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (np.float64, []), - "tune": (bool, []), "scaling": (np.float64, []), } @@ -316,7 +315,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, "scaling": np.mean(self.scaling), "accept": np.mean(np.exp(self.accept_rate_iter)), "accepted": np.mean(self.accepted_iter), @@ -331,7 +329,6 @@ def competence(var, has_grad): @staticmethod def _progressbar_config(n_chains=1): columns = [ - 
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), TextColumn( "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) @@ -339,7 +336,6 @@ def _progressbar_config(n_chains=1): ] stats = { - "tune": [True] * n_chains, "scaling": [0] * n_chains, "accept_rate": [0.0] * n_chains, } @@ -351,7 +347,7 @@ def _make_progressbar_update_functions(): def update_stats(step_stats): return { "accept_rate" if key == "accept" else key: step_stats[key] - for key in ("tune", "accept", "scaling") + for key in ("accept", "scaling") } return (update_stats,) @@ -448,7 +444,6 @@ class BinaryMetropolis(ArrayStep): stats_dtypes_shapes = { "accept": (np.float64, []), - "tune": (bool, []), "p_jump": (np.float64, []), } @@ -505,7 +500,6 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: self.accepted += accepted stats = { - "tune": self.tune, "accept": np.exp(accept), "p_jump": p_jump, } @@ -537,7 +531,6 @@ def competence(var): @dataclass_state class BinaryGibbsMetropolisState(StepMethodState): - tune: bool transit_p: int shuffle_dims: bool order: list @@ -574,9 +567,7 @@ class BinaryGibbsMetropolis(ArrayStep): name = "binary_gibbs_metropolis" - stats_dtypes_shapes = { - "tune": (bool, []), - } + stats_dtypes_shapes = {} _state_class = BinaryGibbsMetropolisState @@ -594,9 +585,6 @@ def __init__( ): model = pm.modelcontext(model) - # Doesn't actually tune, but it's required to emit a sampler stat - # that indicates whether a draw was done in a tuning phase. - self.tune = True # transition probabilities self.transit_p = transit_p @@ -649,10 +637,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: if accepted: logp_curr = logp_prop - stats = { - "tune": self.tune, - } - return q, [stats] + return q, [{}] @staticmethod def competence(var): @@ -695,9 +680,7 @@ class CategoricalGibbsMetropolis(ArrayStep): name = "categorical_gibbs_metropolis" - stats_dtypes_shapes = { - "tune": (bool, []), - } + stats_dtypes_shapes = {} _state_class = CategoricalGibbsMetropolisState @@ -793,7 +776,7 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType logp_curr = logp_prop # This step doesn't have any tunable parameters - return q, [{"tune": False}] + return q, [{}] def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -811,7 +794,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) # This step doesn't have any tunable parameters - return q, [{"tune": False}] + return q, [{}] def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: raise NotImplementedError() @@ -919,7 +902,6 @@ class DEMetropolis(PopulationArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (bool, []), - "tune": (bool, []), "scaling": (np.float64, []), "lambda": (np.float64, []), } @@ -1011,8 +993,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, - "scaling": self.scaling, + "scaling": np.mean(self.scaling), "lambda": self.lamb, "accept": np.exp(accept), "accepted": accepted, @@ -1090,7 +1071,6 @@ class DEMetropolisZ(ArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (bool, []), - "tune": (bool, []), "scaling": (np.float64, []), "lambda": 
(np.float64, []), } @@ -1213,7 +1193,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, "scaling": np.mean(self.scaling), "lambda": self.lamb, "accept": np.exp(accept), diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 180ac1c882..5ea92fc916 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -72,7 +72,6 @@ class Slice(ArrayStepShared): name = "slice" default_blocked = False stats_dtypes_shapes = { - "tune": (bool, []), "nstep_out": (int, []), "nstep_in": (int, []), } @@ -184,7 +183,6 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: self.n_tunes += 1 stats = { - "tune": self.tune, "nstep_out": nstep_out, "nstep_in": nstep_in, } @@ -202,18 +200,17 @@ def competence(var, has_grad): @staticmethod def _progressbar_config(n_chains=1): columns = [ - TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)), TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)), ] - stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} + stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} return columns, stats @staticmethod def _make_progressbar_update_functions(): def update_stats(step_stats): - return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} + return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}} return (update_stats,) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 509e285f90..aecf712caf 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1598,7 +1598,7 @@ def sample( try: trace.setup(draws=draws, chain=0) for point in points: - trace.record(point) + trace.record(point, in_warmup=False) finally: trace.close() diff --git a/requirements-dev.txt b/requirements-dev.txt index 3ca21cf6c4..e55991b2eb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,7 +16,7 @@ pandas>=0.24.0 polyagamma pre-commit>=2.8.0 pymc-sphinx-theme>=0.16.0 -pytensor>=2.38.0,<2.39 +pytensor>=2.38.2,<2.39 pytest-cov>=2.5 pytest>=3.0 rich>=13.7.1 diff --git a/requirements.txt b/requirements.txt index 8e23e80665..ebbf0b097b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cachetools>=4.2.1,<7 cloudpickle numpy>=1.25.0 pandas>=0.24.0 -pytensor>=2.38.0,<2.39 +pytensor>=2.38.2,<2.39 rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 diff --git a/tests/backends/fixtures.py b/tests/backends/fixtures.py index a4f28a1262..a6a4699fe1 100644 --- a/tests/backends/fixtures.py +++ b/tests/backends/fixtures.py @@ -195,11 +195,11 @@ def setup_class(cls): stats2 = [ {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[1] ] - strace0.record(point=point0, sampler_stats=stats1) - strace1.record(point=point1, sampler_stats=stats2) + strace0.record(point=point0, sampler_stats=stats1, in_warmup=False) + strace1.record(point=point1, sampler_stats=stats2, in_warmup=False) else: - strace0.record(point=point0) - strace1.record(point=point1) + strace0.record(point=point0, in_warmup=False) + strace1.record(point=point1, in_warmup=False) strace0.close() strace1.close() cls.mtrace = base.MultiTrace([strace0, strace1]) @@ -244,9 +244,9 @@ def record_point(self, val): } if self.sampler_vars is not None: stats = [{key: dtype(val) for key, dtype in vars.items()} for vars in self.sampler_vars] - 
self.strace.record(point=point, sampler_stats=stats) + self.strace.record(point=point, sampler_stats=stats, in_warmup=False) else: - self.strace.record(point=point) + self.strace.record(point=point, in_warmup=False) def test_standard_close(self): for idx in range(self.draws): @@ -270,7 +270,7 @@ def test_standard_close(self): def test_missing_stats(self): if self.sampler_vars is not None: with pytest.raises(ValueError): - self.strace.record(point=self.test_point) + self.strace.record(point=self.test_point, in_warmup=False) def test_clean_interrupt(self): self.record_point(0) diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 5512981edf..27316b219a 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -749,9 +749,9 @@ def test_save_warmup(self, save_warmup, chains, tune, draws): post_prefix = "" if draws > 0 else "~" test_dict = { f"{post_prefix}posterior": ["u1", "n1"], - f"{post_prefix}sample_stats": ["~tune", "accept"], + f"{post_prefix}sample_stats": ["~in_warmup", "accept"], f"{warmup_prefix}warmup_posterior": ["u1", "n1"], - f"{warmup_prefix}warmup_sample_stats": ["~tune"], + f"{warmup_prefix}warmup_sample_stats": ["~in_warmup"], "~warmup_log_likelihood": [], "~log_likelihood": [], } @@ -786,9 +786,9 @@ def test_save_warmup_issue_1208_after_3_9(self): idata = to_inference_data(trace, save_warmup=True) test_dict = { "posterior": ["u1", "n1"], - "sample_stats": ["~tune", "accept"], + "sample_stats": ["~in_warmup", "accept"], "warmup_posterior": ["u1", "n1"], - "warmup_sample_stats": ["~tune", "accept"], + "warmup_sample_stats": ["~in_warmup", "accept"], } fails = check_multiple_attrs(test_dict, idata) assert not fails @@ -800,7 +800,7 @@ def test_save_warmup_issue_1208_after_3_9(self): idata = to_inference_data(trace[-30:], save_warmup=True) test_dict = { "posterior": ["u1", "n1"], - "sample_stats": ["~tune", "accept"], + "sample_stats": ["~in_warmup", "accept"], "~warmup_posterior": [], "~warmup_sample_stats": [], } diff --git a/tests/backends/test_base.py b/tests/backends/test_base.py index 0f450119a7..bbfbedab32 100644 --- a/tests/backends/test_base.py +++ b/tests/backends/test_base.py @@ -43,7 +43,7 @@ def test_init_trace_continuation_unsupported(self): B = pm.Uniform("B") strace = pm.backends.ndarray.NDArray(vars=[A, B]) strace.setup(10, 0) - strace.record({"A": 2, "B_interval__": 0.1}) + strace.record({"A": 2, "B_interval__": 0.1}, in_warmup=False) assert len(strace) == 1 with pytest.raises(ValueError, match="Continuation of traces"): _init_trace( diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index e72731af6b..89fa2ccba0 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -119,7 +119,8 @@ def test_make_runmeta_and_point_fn(simple_model): assert not vars["vector"].is_deterministic assert not vars["vector_interval__"].is_deterministic assert vars["matrix"].is_deterministic - assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) + assert "in_warmup" in {s.name for s in rmeta.sample_stats} + assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) + 1 with simple_model: step = pm.NUTS() @@ -201,7 +202,7 @@ def test_get_sampler_stats(self): for i in range(N): draw = {"a": rng.normal(), "b_interval__": rng.normal()} stats = [{"tune": (i <= 5), "s1": i, "accepted": bool(rng.randint(0, 2))}] - cra.record(draw, stats) + cra.record(draw, stats, in_warmup=i <= 5) # Check final state of the chain assert len(cra) == N @@ -254,7 +255,7 @@ def 
test_get_sampler_stats_compound(self, caplog): {"tune": tune, "s1": i, "accepted": bool(rng.randint(0, 2))}, {"tune": tune, "s2": i, "accepted": bool(rng.randint(0, 2))}, ] - cra.record(draw, stats) + cra.record(draw, stats, in_warmup=tune) # The 'accepted' stat was emitted by both samplers assert cra.get_sampler_stats("accepted", sampler_idx=None).shape == (N, 2) @@ -293,13 +294,20 @@ def test_return_multitrace(self, simple_model, discard_warmup): return_inferencedata=False, ) assert isinstance(mtrace, pm.backends.base.MultiTrace) - tune = mtrace._straces[0].get_sampler_stats("tune") - assert isinstance(tune, np.ndarray) + in_warmup = mtrace.get_sampler_stats("in_warmup", combine=False, squeeze=False) + assert len(in_warmup) == 3 + assert all(s.dtype == np.dtype(bool) for s in in_warmup) + + # Warmup is tracked by the sampling driver and persisted via `in_warmup`. if discard_warmup: - assert tune.shape == (7, 3) + assert len(mtrace) == 7 + assert all(len(s) == 7 for s in in_warmup) + assert all(not np.any(s) for s in in_warmup) else: - assert tune.shape == (12, 3) - pass + assert len(mtrace) == 12 + assert all(len(s) == 12 for s in in_warmup) + assert all(np.all(s[:5]) for s in in_warmup) + assert all(not np.any(s[5:]) for s in in_warmup) @pytest.mark.parametrize("cores", [1, 3]) def test_return_inferencedata(self, simple_model, cores): diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py index af9c9e0a06..b90834f367 100644 --- a/tests/backends/test_zarr.py +++ b/tests/backends/test_zarr.py @@ -132,7 +132,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk): else: manually_collected_draws.append(point) manually_collected_stats.append(stats) - trace.straces[0].record(point, stats) + trace.straces[0].record(point, stats, in_warmup=tuning) trace.straces[0].record_sampling_state(model_step) assert {group_name for group_name, _ in trace.root.groups()} == expected_groups diff --git a/tests/dims/distributions/test_censored.py b/tests/dims/distributions/test_censored.py new file mode 100644 index 0000000000..eb96f7a4da --- /dev/null +++ b/tests/dims/distributions/test_censored.py @@ -0,0 +1,88 @@ +# Copyright 2026 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
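+"""Tests for the dims-aware Censored distribution in pymc.dims."""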
+import numpy as np +import pytest + +from pytensor.xtensor import as_xtensor +from pytensor.xtensor.shape import Transpose +from pytensor.xtensor.vectorization import XRV + +import pymc.distributions as regular_distributions + +from pymc.dims import Censored, Normal +from pymc.model.core import Model +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + + +@pytest.mark.parametrize("lower", [None, -1]) +@pytest.mark.parametrize("upper", [None, 1]) +def test_censored_basic(lower, upper): + coords = {"space": range(3), "time": range(4)} + + with Model(coords=coords) as model: + dist = Normal.dist(np.pi, np.e) + Censored("y", dist, lower=lower, upper=upper, dims=("space", "time")) + + with Model(coords=coords) as reference_model: + dist = regular_distributions.Normal.dist(np.pi, np.e) + regular_distributions.Censored( + "y", dist=dist, lower=lower, upper=upper, dims=("space", "time") + ) + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + +def test_censored_dims(): + """Test that both censored (and the underlying dist) have all the implied and explicit dims.""" + coords = { + "a": range(3), + "b": range(4), + "c": range(5), + "d": range(6), + } + with Model(coords=coords) as model: + dist = Normal.dist( + mu=as_xtensor([0, 1, 2], dims=("a",)), + sigma=as_xtensor([1, 2, 3], dims=("b",)), + dim_lengths={"c": model.dim_lengths["c"]}, + ) + assert set(dist.dims) == {"c", "a", "b"} + + c0 = Censored("c0", dist) + assert c0.dims == ("c", "a", "b") + c0_dist = c0.owner.inputs[0] + assert isinstance(c0_dist.owner.op, XRV) + assert c0_dist.dims == ("c", "a", "b") + + c1 = Censored("c1", dist, dims=("a", "b", "c")) + assert c1.dims == ("a", "b", "c") + assert isinstance(c1.owner.op, Transpose) + c1_dist = c1.owner.inputs[0].owner.inputs[0] + assert isinstance(c1_dist.owner.op, XRV) + assert c1_dist.dims == ("c", "a", "b") + + c2 = Censored("c2", dist, dims=(..., "d")) + assert c2.dims == ("c", "a", "b", "d") + assert isinstance(c1.owner.op, Transpose) + c2_dist = c2.owner.inputs[0].owner.inputs[0] + assert isinstance(c2_dist.owner.op, XRV) + assert c2_dist.dims == ("d", "c", "a", "b") + + lower = as_xtensor(np.zeros((6, 5)), dims=("d", "c")) + c3 = Censored("c3", dist, lower=lower) + assert c3.dims == ("d", "c", "a", "b") + c3_dist = c3.owner.inputs[0] + assert isinstance(c3_dist.owner.op, XRV) + assert c3_dist.dims == ("d", "c", "a", "b") diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py index 8cfdadb372..a71ca2306d 100644 --- a/tests/dims/distributions/test_vector.py +++ b/tests/dims/distributions/test_vector.py @@ -99,3 +99,23 @@ def test_zerosumnormal(): # Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same # Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed # assert_equivalent_logp_graph(model, reference_model) + + +def test_zerosumnormal_batch_sigma(): + coords = {"a": range(3), "b": range(5)} + sigma = np.array([1, 2, 3.0]) + with Model(coords=coords) as model: + ZeroSumNormal( + "x", + sigma=as_xtensor(sigma, dims=("a",)), + core_dims=("b",), + ) + + with Model(coords=coords) as ref_model: + regular_distributions.ZeroSumNormal("x", sigma=sigma[:, None], dims=("a", "b")) + + ip = model.initial_point() + np.testing.assert_allclose( + model.compile_logp(sum=False)(ip), + ref_model.compile_logp(sum=False)(ip), + ) diff --git a/tests/distributions/test_multivariate.py 
index 67ac8644d0..27d145aecb 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -1763,6 +1763,13 @@ def test_batched_sigma(self):
             sigma=batch_test_sigma[None, :, None], n_zerosum_axes=2, support_shape=(3, 2)
         )
 
+    def test_batched_transformed_logp_shape(self):
+        with pm.Model() as m:
+            x = pm.ZeroSumNormal("x", sigma=np.ones(3)[:, None], support_shape=(2,))
+        assert x.type.shape == (3, 2)
+        assert m.logp(sum=False)[0].type.shape == (3,)
+        assert m.logp(sum=False, jacobian=False)[0].type.shape == (3,)
+
 
 class TestMvStudentTCov(BaseTestDistributionRandom):
     def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index 45cef9d86e..370a57f6eb 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -657,12 +657,12 @@ def log_jac_det(self, value, *inputs):
 
     buggy_transform = BuggyTransform()
 
-    with pm.Model() as m:
-        pm.Uniform("x", shape=(4, 3), default_transform=buggy_transform)
+    import numpy as np
 
     for jacobian_val in (True, False):
-        with pytest.raises(
-            ValueError,
-            match="are not allowed to broadcast together. There is a bug in the implementation of either one",
-        ):
-            m.logp(jacobian=jacobian_val)
+        with pm.Model() as m:
+            pm.Uniform("x", shape=(4, 3), default_transform=buggy_transform)
+
+        logp_fn = m.compile_logp(jacobian=jacobian_val)
+        with pytest.raises(AssertionError, match="SpecifyShape"):
+            logp_fn({"x_buggy__": np.zeros((4, 3))})
diff --git a/tests/logprob/test_transform_value.py b/tests/logprob/test_transform_value.py
index 61e2a1356e..d60b92c0b1 100644
--- a/tests/logprob/test_transform_value.py
+++ b/tests/logprob/test_transform_value.py
@@ -578,6 +578,20 @@ def scan_step(prev_innov, prev_rng):
     np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))
 
 
+def test_halfstudent_t_with_frozen_dims():
+    """Regression test: log_jac_det had mismatched broadcastable dims vs logp when
+    dims were frozen to a single-element coordinate, causing a ValueError.
+    """
+    from pymc.model.transform.optimization import freeze_dims_and_data
+
+    with pm.Model(coords={"x_dim": ["only_one"]}) as model:
+        pm.HalfStudentT("x", nu=7, sigma=1, dims="x_dim")
+
+    fmodel = freeze_dims_and_data(model)
+    [x_logp] = fmodel.logp(sum=False)
+    assert x_logp.type.shape == (1,)
+
+
 def test_weakref_leak():
     """Check that the rewrite does not have a growing memory footprint.
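[Reviewer note, illustration only, not part of the patch: the regression test
above guards graph construction; a quick end-to-end check would also evaluate
the frozen model's logp. A minimal sketch reusing the exact model from the test
(compile_logp and initial_point are standard pm.Model methods):

    import pymc as pm
    from pymc.model.transform.optimization import freeze_dims_and_data

    with pm.Model(coords={"x_dim": ["only_one"]}) as model:
        pm.HalfStudentT("x", nu=7, sigma=1, dims="x_dim")

    fmodel = freeze_dims_and_data(model)
    logp_fn = fmodel.compile_logp()
    print(logp_fn(fmodel.initial_point()))  # evaluates without the old ValueError
]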
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 8a3a99dbe8..c781204a71 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -414,6 +414,33 @@ def test_sample_return_lengths(self):
         assert idata.posterior.sizes["draw"] == 100
         assert idata.posterior.sizes["chain"] == 3
 
+    def test_categorical_gibbs_respects_driver_tune_boundary(self):
+        with pm.Model():
+            pm.Categorical("x", p=np.array([0.2, 0.3, 0.5]))
+            sample_kwargs = {
+                "tune": 5,
+                "draws": 7,
+                "chains": 1,
+                "cores": 1,
+                "return_inferencedata": False,
+                "compute_convergence_checks": False,
+                "progressbar": False,
+                "random_seed": 123,
+            }
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                mtrace = pm.sample(discard_tuned_samples=True, **sample_kwargs)
+            assert len(mtrace) == 7
+            assert mtrace.report.n_tune == 5
+            assert mtrace.report.n_draws == 7
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                with pytest.warns(UserWarning, match="will be included"):
+                    mtrace_warmup = pm.sample(discard_tuned_samples=False, **sample_kwargs)
+            assert len(mtrace_warmup) == 12
+            assert mtrace_warmup.report.n_tune == 5
+            assert mtrace_warmup.report.n_draws == 7
+
     @pytest.mark.parametrize("cores", [1, 2])
     def test_logs_sampler_warnings(self, caplog, cores):
         """Asserts that "warning" sampler stats are logged during sampling."""
diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py
index 8d497f3011..c453652079 100644
--- a/tests/step_methods/hmc/test_nuts.py
+++ b/tests/step_methods/hmc/test_nuts.py
@@ -157,7 +157,6 @@ def test_sampler_stats(self):
             "step_size",
             "step_size_bar",
             "tree_size",
-            "tune",
             "perf_counter_diff",
             "perf_counter_start",
             "process_time_diff",
diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py
index 6c8957f9b3..a7a08bd500 100644
--- a/tests/step_methods/test_compound.py
+++ b/tests/step_methods/test_compound.py
@@ -36,11 +36,11 @@
 from tests.models import simple_2model_continuous
 
 
-def test_all_stepmethods_emit_tune_stat():
+def test_stepmethods_do_not_require_tune_stat():
     step_types = pm.step_methods.STEP_METHODS
     assert len(step_types) > 5
     for cls in step_types:
-        assert "tune" in cls.stats_dtypes_shapes
+        assert "tune" not in cls.stats_dtypes_shapes
 
 
 class TestCompoundStep:
diff --git a/tests/test_printing.py b/tests/test_printing.py
index 7efba2a81e..bf228eb6dd 100644
--- a/tests/test_printing.py
+++ b/tests/test_printing.py
@@ -319,6 +319,57 @@ def random(rng, mu, size):
     assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])
 
 
+class TestDimsDist:
+    def setup_class(self):
+        from pymc.dims.distributions import Normal as DimsNormal
+        from pymc.dims.distributions import ZeroSumNormal as DimsZeroSumNormal
+
+        with Model(coords={"group": range(3), "obs": range(5)}) as self.model:
+            mu = DimsNormal("mu", 0, 10, dims=("group",))
+            sigma = DimsNormal("sigma", 0, 1)
+            zsn = DimsZeroSumNormal("zsn", sigma=1, core_dims="group")
+            DimsNormal("y", mu + zsn, sigma, dims=("obs", "group"))
+
+        self.expected = {
+            ("plain", True): [
+                r"mu ~ Normal(0, 10)",
+                r"sigma ~ Normal(0, 1)",
+                r"zsn ~ ZeroSumNormal(f(), f())",
+                r"y ~ Normal(f(mu, zsn), sigma)",
+            ],
+            ("plain", False): [
+                r"mu ~ Normal",
+                r"sigma ~ Normal",
+                r"zsn ~ ZeroSumNormal",
+                r"y ~ Normal",
+            ],
+            ("latex", True): [
+                r"\text{mu} &\sim & \operatorname{Normal}(0,~10)",
+                r"\text{sigma} &\sim & \operatorname{Normal}(0,~1)",
+                r"\text{zsn} &\sim & \operatorname{ZeroSumNormal}(f(),~f())",
+                r"\text{y} &\sim & \operatorname{Normal}(f(\text{mu},~\text{zsn}),~\text{sigma})",
+            ],
+            ("latex", False): [
+                r"\text{mu} &\sim & \operatorname{Normal}",
+                r"\text{sigma} &\sim & \operatorname{Normal}",
+                r"\text{zsn} &\sim & \operatorname{ZeroSumNormal}",
+                r"\text{y} &\sim & \operatorname{Normal}",
+            ],
+        }
+
+    def test_str_repr(self):
+        for formatting, include_params in [("plain", True), ("plain", False)]:
+            model_text = self.model.str_repr(formatting=formatting, include_params=include_params)
+            for text in self.expected[(formatting, include_params)]:
+                assert text in model_text
+
+    def test_latex_repr(self):
+        for formatting, include_params in [("latex", True), ("latex", False)]:
+            model_text = self.model.str_repr(formatting=formatting, include_params=include_params)
+            for text in self.expected[(formatting, include_params)]:
+                assert text in model_text
+
+
 class TestLatexRepr:
     @staticmethod
     def simple_model() -> Model:
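[Reviewer note, illustration only, not part of the patch: taken together, these
tests pin down the new contract that warmup membership is owned by the sampling
driver rather than emitted as a per-step "tune" stat. A minimal sketch of how it
can be read back, using only calls exercised in test_return_multitrace above
(the model here is hypothetical):

    import pymc as pm

    with pm.Model():
        pm.Normal("x")
        mtrace = pm.sample(
            tune=5, draws=7, discard_tuned_samples=False, return_inferencedata=False
        )

    in_warmup = mtrace.get_sampler_stats("in_warmup", combine=False, squeeze=False)
    # One boolean array per chain; True marks draws recorded during tuning.
]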