Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,10 @@ def _infer_value_domain(cls, **domains):
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
support = cls.dist_class.arg_constraints.get(name, None)
constraints = getattr(cls.dist_class, "arg_constraints", {})
if isinstance(constraints, property):
constraints = {}
support = constraints.get(name, None)
# XXX: if the backend does not have the same definition of constraints, we should
# define backend-specific distributions and overide these `infer_value_domain`,
# `infer_param_domain` methods.
Expand All @@ -326,9 +329,12 @@ def _infer_param_domain(cls, name, raw_shape):
# resolve the issue: logits's constraints are real (instead of real_vector)
# for discrete multivariate distributions in Pyro
elif support_name == "Real":
constraints = getattr(cls.dist_class, "arg_constraints", {})
if isinstance(constraints, property):
constraints = {}
if name == "logits" and (
"probs" in cls.dist_class.arg_constraints
and type(cls.dist_class.arg_constraints["probs"]).__name__.lstrip("_")
"probs" in constraints
and type(constraints.get("probs")).__name__.lstrip("_")
== "Simplex"
):
output = Reals[raw_shape[-1 - event_dim :]]
Expand Down Expand Up @@ -374,10 +380,13 @@ def make_dist(
backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True
):
if not param_names:
constraints = getattr(backend_dist_class, "arg_constraints", {})
if isinstance(constraints, property):
constraints = {}
param_names = tuple(
name
for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
if name in backend_dist_class.arg_constraints
if name in constraints
)

@makefun.with_signature(
Expand Down Expand Up @@ -437,7 +446,7 @@ def dist_init(self, **kwargs):
("Pareto", ()),
("Poisson", ()),
("StudentT", ()),
("Uniform", ()),
("Uniform", ("low", "high")),
("VonMises", ()),
]

Expand Down Expand Up @@ -587,7 +596,7 @@ def module(self):
def __call__(self, cls, args, kwargs):
# Check whether distribution class takes any tensor inputs.
arg_constraints = getattr(cls, "arg_constraints", None)
if not arg_constraints:
if not arg_constraints or isinstance(arg_constraints, property):
return

# Check whether any tensor inputs are actually funsors.
Expand Down
Loading