Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ jobs:

- |
tests/dims/distributions/test_core.py
tests/dims/distributions/test_censored.py
tests/dims/distributions/test_scalar.py
tests/dims/distributions/test_vector.py
tests/dims/test_model.py
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
hooks:
- id: ruff
- id: ruff-check
args: [--fix, --show-fixes]
- id: ruff-format
- repo: local
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-alternative-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- pandas>=0.24.0
- pip
- polyagamma
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.25.0
- pandas>=0.24.0
- pip
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- pandas>=0.24.0
- pip
- polyagamma
- pytensor>=2.38.0,<2.39
- pytensor>=2.38.2,<2.39
- python-graphviz
- networkx
- rich>=13.7.1
Expand Down
11 changes: 11 additions & 0 deletions docs/source/api/dims/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,14 @@ Vector distributions
Categorical
MvNormal
ZeroSumNormal


Higher-Order distributions
==========================

.. currentmodule:: pymc.dims
.. autosummary::
:toctree: generated/
:template: distribution.rst

Censored
2 changes: 1 addition & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def sample_stats_to_xarray(self):
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
if name == "tune":
if name in {"tune", "in_warmup"}:
continue
if self.warmup_trace:
data_warmup[name] = np.array(
Expand Down
11 changes: 10 additions & 1 deletion pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,13 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
"""
raise NotImplementedError()

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
def record(
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
):
"""Record results of a sampling iteration.

Parameters
Expand All @@ -122,6 +128,9 @@ def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, An
Values mapped to variable names
stats: list of dicts
The diagnostic values for each sampler
in_warmup: bool
Whether this draw belongs to the warmup phase. This is a driver-owned
concept and is intended for storage/backends to persist warmup information.
"""
raise NotImplementedError()

Expand Down
28 changes: 24 additions & 4 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
BlockedStep,
CompoundStep,
StatsBijection,
check_step_emits_tune,
flat_statname,
flatten_steps,
)
Expand Down Expand Up @@ -106,16 +105,26 @@ def __init__(
{sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
for sstats in stats_bijection._stat_groups
]
if "in_warmup" in stats_dtypes and self.sampler_vars:
# Expose driver-owned warmup marker via the sampler-stats API.
self.sampler_vars[0].setdefault("in_warmup", stats_dtypes["in_warmup"])

self._chain = chain
self._point_fn = point_fn
self._statsbj = stats_bijection
super().__init__()

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
def record(
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
):
values = self._point_fn(draw)
value_dict = dict(zip(self.varnames, values))
stats_dict = self._statsbj.map(stats)
stats_dict["in_warmup"] = bool(in_warmup)
# Apply pickling to objects stats
for fname in self._statsbj.object_stats.keys():
val_bytes = pickle.dumps(stats_dict[fname])
Expand Down Expand Up @@ -148,6 +157,9 @@ def get_sampler_stats(
self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
) -> np.ndarray:
slc = slice(burn, None, thin)
if stat_name in {"in_warmup", "tune"}:
# Backwards-friendly alias for users that might try "tune".
return self._get_stats("in_warmup", slc)
# When there's just one sampler, default to remove the sampler dimension
if sampler_idx is None and self._statsbj.n_samplers == 1:
sampler_idx = 0
Expand Down Expand Up @@ -210,8 +222,6 @@ def make_runmeta_and_point_fn(
) -> tuple[mcb.RunMeta, PointFunc]:
variables, point_fn = get_variables_and_point_fn(model, initial_point)

check_step_emits_tune(step)

# In PyMC the sampler stats are grouped by the sampler.
sample_stats = []
steps = flatten_steps(step)
Expand All @@ -235,6 +245,16 @@ def make_runmeta_and_point_fn(
)
sample_stats.append(svar)

    # Driver-owned warmup marker, stored once per draw.
sample_stats.append(
mcb.Variable(
name="in_warmup",
dtype=np.dtype(bool).name,
shape=[],
undefined_ndim=False,
)
)

coordinates = [
mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
for dname, cvals in model.coords.items()
Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
new = np.zeros(draws, dtype=dtype)
data[varname] = np.concatenate([old, new])

def record(self, point, sampler_stats=None) -> None:
def record(self, point, sampler_stats=None, *, in_warmup: bool) -> None:
"""Record results of a sampling iteration.

Parameters
Expand Down Expand Up @@ -238,5 +238,5 @@ def point_fun(point):

chain.fn = point_fun
for point in point_list:
chain.record(point)
chain.record(point, in_warmup=False)
return MultiTrace([chain])
9 changes: 8 additions & 1 deletion pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def buffer(self, group, var_name, value):
buffer[var_name].append(value)

def record(
self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]
self,
draw: Mapping[str, np.ndarray],
stats: Sequence[Mapping[str, Any]],
*,
in_warmup: bool,
) -> bool | None:
"""Record the step method's returned draw and stats.

Expand All @@ -185,6 +189,7 @@ def record(
self.buffer(group="posterior", var_name=var_name, value=var_value)
for var_name, var_value in self.stats_bijection.map(stats).items():
self.buffer(group="sample_stats", var_name=var_name, value=var_value)
self.buffer(group="sample_stats", var_name="in_warmup", value=bool(in_warmup))
self._buffered_draws += 1
if self._buffered_draws == self.draws_until_flush:
self.flush()
Expand Down Expand Up @@ -525,6 +530,7 @@ def init_trace(
stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
[step] if isinstance(step, BlockedStep) else step.methods
)
stats_dtypes_shapes = {"in_warmup": (bool, [])} | stats_dtypes_shapes
self.init_group_with_empty(
group=self.root.create_group(name="sample_stats", overwrite=True),
var_dtype_and_shape=stats_dtypes_shapes,
Expand Down Expand Up @@ -683,6 +689,7 @@ def init_group_with_empty(
for i, shape_i in enumerate(shape):
dim = f"{name}_dim_{i}"
dims.append(dim)
assert shape_i is not None, f"{dim} shape is None"
group_coords[dim] = np.arange(shape_i, dtype="int")
dims = ("chain", "draw", *dims)
attrs = extra_var_attrs[name] if extra_var_attrs is not None else {}
Expand Down
1 change: 1 addition & 0 deletions pymc/dims/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc.dims.distributions.censored import Censored
from pymc.dims.distributions.scalar import *
from pymc.dims.distributions.vector import *
51 changes: 51 additions & 0 deletions pymc/dims/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2026 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from pymc.dims.distributions.core import DimDistribution, copy_docstring, expand_dist_dims
from pymc.distributions.censored import Censored as RegularCensored


@copy_docstring(RegularCensored)
class Censored(DimDistribution):
    @classmethod
    def dist(cls, dist, *, lower=None, upper=None, dim_lengths, **kwargs):
        # A missing bound means "uncensored on that side": substitute +/- infinity.
        effective_lower = -np.inf if lower is None else lower
        effective_upper = np.inf if upper is None else upper
        return super().dist(
            [dist, effective_lower, effective_upper], dim_lengths=dim_lengths, **kwargs
        )

    @classmethod
    def xrv_op(cls, dist, lower, upper, core_dims=None, extra_dims=None, rng=None):
        extra_dims = {} if extra_dims is None else extra_dims

        dist = cls._as_xtensor(dist)
        lower = cls._as_xtensor(lower)
        upper = cls._as_xtensor(upper)

        # Dimensions requested via `extra_dims`, as well as dimensions that appear
        # only on the bounds, must be propagated to the underlying distribution.
        known_dims = set(dist.dims)
        bound_only_dims = {
            name: length
            for name, length in (lower.sizes | upper.sizes).items()
            if name not in known_dims
        }
        dims_to_add = extra_dims | bound_only_dims
        if dims_to_add:
            dist = expand_dist_dims(dist, dims_to_add)

        # Probability is inferred from the clip operation
        # TODO: Make this a SymbolicRandomVariable that can itself be resized
        return dist.clip(lower, upper)
26 changes: 25 additions & 1 deletion pymc/dims/distributions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from collections.abc import Callable, Sequence
from itertools import chain
from typing import cast
from typing import Any, cast

import numpy as np

Expand All @@ -25,7 +25,9 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor
from pytensor.xtensor.shape import Transpose
from pytensor.xtensor.type import XTensorVariable
from pytensor.xtensor.vectorization import XRV

from pymc import SymbolicRandomVariable, modelcontext
from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform
Expand Down Expand Up @@ -345,3 +347,25 @@ class UnitDimDistribution(DimDistribution):
"""Base class for unit-valued distributions."""

default_transform = log_odds_transform


def expand_dist_dims(dist: XTensorVariable, extra_dims: dict[str, Any]) -> XTensorVariable:
    """Rebuild ``dist`` so it also draws along the dimensions in ``extra_dims``.

    ``extra_dims`` maps new dimension names to their lengths. Raises ``ValueError``
    if any requested dimension already exists on ``dist``, and
    ``NotImplementedError`` when the variable's op is not one we know how to expand.
    """
    duplicated = set(extra_dims) & set(dist.dims)
    if duplicated:
        raise ValueError(f"extra_dims already present in distribution: {sorted(duplicated)}")

    node = dist.owner
    op = node.op if node is not None else None

    if isinstance(op, XRV):
        # Recreate the XRV op with the new dims prepended to its extra_dims.
        props = op._props_dict()
        props["extra_dims"] = (*extra_dims.keys(), *props["extra_dims"])
        expanded_op = type(op)(**props)
        _old_rng, *params_and_dim_lengths = node.inputs
        # Pass a fresh (None) RNG so the expanded variable's draws are not
        # correlated with draws from the original variable.
        return expanded_op(None, *extra_dims.values(), *params_and_dim_lengths)

    if isinstance(op, Transpose):
        # Expand the pre-transpose variable, then restore the original dim order.
        expanded = expand_dist_dims(node.inputs[0], extra_dims=extra_dims)
        return expanded.transpose(..., *dist.dims)

    raise NotImplementedError(f"expand_dist_dims not implemented for {dist} with op {op}")
4 changes: 1 addition & 3 deletions pymc/dims/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,4 @@ def backward(self, value, *rv_inputs):
return value

def log_jac_det(self, value, *rv_inputs):
# Use following once broadcast_like is implemented
# as_xtensor(0).broadcast_like(value, exclude=self.dims)`
return value.sum(self.dims) * 0
return as_xtensor(0.0).broadcast_like(value, exclude=self.dims)
Loading