Skip to content

Commit a4c610b

Browse files
committed
feat: add dist+logp hybrid path to pmd.CustomDist, remove random arg
1 parent 059bb56 commit a4c610b

3 files changed

Lines changed: 226 additions & 83 deletions

File tree

pymc/dims/distributions/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class DimDistribution:
191191

192192
xrv_op: Callable
193193
default_transform: DimTransform | None = None
194+
_forward_dim_lengths: bool = False
194195

195196
@staticmethod
196197
def _as_xtensor(x):
@@ -325,6 +326,8 @@ def dist(
325326
}
326327
if kwargs.get("rng") is None:
327328
kwargs["rng"] = pt.random.shared_rng(seed=None)
329+
if cls._forward_dim_lengths and dim_lengths is not None:
330+
kwargs["dim_lengths"] = dim_lengths
328331
_, rv = cls.xrv_op(
329332
*dist_params,
330333
extra_dims=extra_dims,

0 commit comments

Comments
 (0)