Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311
Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311williambdean wants to merge 6 commits into
Conversation
These are orthogonal to having a dist argument. You can have dist with logp (or without, maybe it derives it). The only incompatible case is dist AND random, since they both represent the random path |
| from pymc.model.core import new_or_existing_block_model_access | ||
|
|
||
|
|
||
| class _DimCustomDistRV(RandomVariable): |
There was a problem hiding this comment.
It should be a subclass of XRV no?
There was a problem hiding this comment.
NVM, but maybe I wouldn't allow this and would ask users to always use dist. random was more the legacy way of defining random graphs before, but the same way users are asked to use pytensor for logp they should also be asked (and comfortable) with using pytensor/pymc.dist operations for the random?
| return func | ||
|
|
||
|
|
||
| def _default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False): |
There was a problem hiding this comment.
I think this is just rv.zeros_like() ?
| class CustomDist(DimDistribution): | ||
| """Dims-aware CustomDist for pymc.dims. | ||
|
|
||
| Supports the same ``dist=`` (symbolic) and ``logp=`` (black-box) paths as |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8311 +/- ##
==========================================
- Coverage 91.73% 91.66% -0.07%
==========================================
Files 125 126 +1
Lines 20471 20659 +188
==========================================
+ Hits 18779 18938 +159
- Misses 1692 1721 +29
🚀 New features to boost your workflow:
|
|
Good start, I think we should drop the random argument and a lot of complexity falls out of the way |
Supports both symbolic (dist=) and black-box (logp=) paths, enabling user-defined distributions with named dims. The symbolic path auto-derives logprob from inner XRV nodes; the black-box path creates a dynamic RandomVariable subclass and registers _logprob dispatches that reconstruct XTensorVariables for the value and dims-bearing params.
Covers both symbolic (dist=) and black-box (logp=/random=) paths: graph comparison against regular distributions, dim propagation, observed data, custom support points, and model variables as params.
a4c610b to
dae3580
Compare
dae3580 to
68ae365
Compare
| } | ||
| if kwargs.get("rng") is None: | ||
| kwargs["rng"] = pt.random.shared_rng(seed=None) | ||
| if cls._forward_dim_lengths and dim_lengths is not None: |
| else: | ||
| output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims, dim_lengths) | ||
| rv = xtensor_from_tensor(rv, dims=output_dims) |
There was a problem hiding this comment.
I think it will help everyone to be strict. Your RV must return XtensorVariables
| return xtensor_shared_rng(seed=None), rv | ||
| return rv | ||
|
|
||
| # Hybrid: use dist for sampling but user functions for logp/logcdf/support_point |
There was a problem hiding this comment.
Code smell. Note what we do in CustomDist, we build a SymbolicRandomVariable to clearly demarcate the input/output boundary of the CustomDist graph. We should have an equivalent subclass for dims variables. Then you dispatch on that like customdist does.
It's also what enables later having this in factory classes like Censored/Truncated/Mixture that need to recreate the dist variable with the right dimensionality.
| outputs=rv.values, | ||
| ) | ||
|
|
||
| def random_fn(*args, rng=None, size=None): |
There was a problem hiding this comment.
drop legacy random stuff, only allow dist
Adds
CustomDisttopymc.dims.distributions, a sibling topm.CustomDistthat operates onXTensorVariablewith named dims.Two construction paths:
Symbolic (
dist=kwarg): receives XTensorVariable params, returns an XTensorVariable RV (e.g., composingpmd.Normal.dist). Auto-derives logp from inner XRV nodes.Black-box (
logp=kwarg): dynamically creates aRandomVariablesubclass; dispatches_logprob,_logcdf,_support_point. Thevaluearrives asXTensorVariable; use.valuesforpt.*ops orptx.*for dim-aware ops.Key design points:
DimDistribution._as_xtensorpath aspmd.Normaletc. — identical behavior (scalars auto-convert, non-scalars require dims).logp,logcdf,support_point) captured in closures to avoid Python descriptor protocol issues.RandomVariablesubclass sets onlysignature(notndim_supp/ndims_params) to avoid deprecation warnings.