Skip to content

Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311

Open
williambdean wants to merge 6 commits into
pymc-devs:mainfrom
williambdean:customdist-dims
Open

Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311
williambdean wants to merge 6 commits into
pymc-devs:mainfrom
williambdean:customdist-dims

Conversation

@williambdean
Copy link
Copy Markdown
Contributor

Adds CustomDist to pymc.dims.distributions, a sibling to pm.CustomDist that operates on XTensorVariable with named dims.

Two construction paths:

  • Symbolic (dist= kwarg): receives XTensorVariable params, returns an XTensorVariable RV (e.g., composing pmd.Normal.dist). Auto-derives logp from inner XRV nodes.

  • Black-box (logp= kwarg): dynamically creates a RandomVariable subclass; dispatches _logprob, _logcdf, _support_point. The value arrives as XTensorVariable; use .values for pt.* ops or ptx.* for dim-aware ops.

Key design points:

  • Params go through the same DimDistribution._as_xtensor path as pmd.Normal etc. — identical behavior (scalars auto-convert, non-scalars require dims).
  • User callables (logp, logcdf, support_point) captured in closures to avoid Python descriptor protocol issues.
  • Dynamic RandomVariable subclass sets only signature (not ndim_supp/ndims_params) to avoid deprecation warnings.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented May 20, 2026

Black-box (logp= kwarg),

These are orthogonal to having a dist argument. You can have dist with logp (or without, maybe it derives it). The only incompatible case is dist AND random, since they both represent the random path

from pymc.model.core import new_or_existing_block_model_access


class _DimCustomDistRV(RandomVariable):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be a subclass of XRV no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NVM, but maybe I wouldn't allow this and would ask users to always use dist. random was more the legacy way of defining random graphs before, but the same way users are asked to use pytensor for logp they should also be asked (and comfortable) with using pytensor/pymc.dist operations for the random?

Comment thread pymc/dims/distributions/custom.py Outdated
return func


def _default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is just rv.zeros_like() ?

Comment thread pymc/dims/distributions/custom.py Outdated
class CustomDist(DimDistribution):
"""Dims-aware CustomDist for pymc.dims.

Supports the same ``dist=`` (symbolic) and ``logp=`` (black-box) paths as
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is logp "black-box"?

Comment thread pymc/dims/distributions/custom.py Outdated
Comment thread pymc/dims/distributions/custom.py Outdated
Comment thread pymc/dims/distributions/custom.py Outdated
Comment thread pymc/dims/distributions/custom.py Outdated
@codecov
Copy link
Copy Markdown

codecov Bot commented May 20, 2026

Codecov Report

❌ Patch coverage is 84.57447% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.66%. Comparing base (f09f6b6) to head (b512e11).

Files with missing lines Patch % Lines
pymc/dims/distributions/custom.py 84.23% 29 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #8311      +/-   ##
==========================================
- Coverage   91.73%   91.66%   -0.07%     
==========================================
  Files         125      126       +1     
  Lines       20471    20659     +188     
==========================================
+ Hits        18779    18938     +159     
- Misses       1692     1721      +29     
Files with missing lines Coverage Δ
pymc/dims/distributions/__init__.py 100.00% <100.00%> (ø)
pymc/dims/distributions/core.py 91.83% <100.00%> (+0.12%) ⬆️
pymc/dims/distributions/custom.py 84.23% <84.23%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment thread pymc/dims/distributions/custom.py
@ricardoV94
Copy link
Copy Markdown
Member

Good start, I think we should drop the random argument and a lot of complexity falls out of the way

@read-the-docs-community
Copy link
Copy Markdown

read-the-docs-community Bot commented May 20, 2026

Documentation build overview

📚 pymc | 🛠️ Build #32779749 | 📁 Comparing b512e11 against latest (f09f6b6)

  🔍 Preview build  

1 file changed
± glossary.html

Supports both symbolic (dist=) and black-box (logp=) paths,
enabling user-defined distributions with named dims. The symbolic
path auto-derives logprob from inner XRV nodes; the black-box path
creates a dynamic RandomVariable subclass and registers _logprob
dispatches that reconstruct XTensorVariables for the value and
dims-bearing params.
Covers both symbolic (dist=) and black-box (logp=/random=) paths:
graph comparison against regular distributions, dim propagation,
observed data, custom support points, and model variables as params.
}
if kwargs.get("rng") is None:
kwargs["rng"] = pt.random.shared_rng(seed=None)
if cls._forward_dim_lengths and dim_lengths is not None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code smell, why?

Comment on lines +270 to +272
else:
output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims, dim_lengths)
rv = xtensor_from_tensor(rv, dims=output_dims)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will help everyone to be strict. Your RV must return XtensorVariables

return xtensor_shared_rng(seed=None), rv
return rv

# Hybrid: use dist for sampling but user functions for logp/logcdf/support_point
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code smell. Note what we do in CustomDist, we build a SymbolicRandomVariable to clearly demarcate the input/output boundary of the CustomDist graph. We should have an equivalent subclass for dims variables. Then you dispatch on that like customdist does.

It's also what enables later having this in factory classes like Censored/Truncated/Mixture that need to recreate the dist variable with the right dimensionality.

outputs=rv.values,
)

def random_fn(*args, rng=None, size=None):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop legacy random stuff, only allow dist

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants