diff --git a/funsor/distribution.py b/funsor/distribution.py index 81dee5ca..4892f606 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -306,7 +306,10 @@ def _infer_value_domain(cls, **domains): @classmethod @functools.lru_cache(maxsize=5000) def _infer_param_domain(cls, name, raw_shape): - support = cls.dist_class.arg_constraints.get(name, None) + constraints = getattr(cls.dist_class, "arg_constraints", {}) + if isinstance(constraints, property): + constraints = {} + support = constraints.get(name, None) # XXX: if the backend does not have the same definition of constraints, we should # define backend-specific distributions and overide these `infer_value_domain`, # `infer_param_domain` methods. @@ -326,9 +329,12 @@ def _infer_param_domain(cls, name, raw_shape): # resolve the issue: logits's constraints are real (instead of real_vector) # for discrete multivariate distributions in Pyro elif support_name == "Real": + constraints = getattr(cls.dist_class, "arg_constraints", {}) + if isinstance(constraints, property): + constraints = {} if name == "logits" and ( - "probs" in cls.dist_class.arg_constraints - and type(cls.dist_class.arg_constraints["probs"]).__name__.lstrip("_") + "probs" in constraints + and type(constraints.get("probs")).__name__.lstrip("_") == "Simplex" ): output = Reals[raw_shape[-1 - event_dim :]] @@ -374,10 +380,13 @@ def make_dist( backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True ): if not param_names: + constraints = getattr(backend_dist_class, "arg_constraints", {}) + if isinstance(constraints, property): + constraints = {} param_names = tuple( name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:] - if name in backend_dist_class.arg_constraints + if name in constraints ) @makefun.with_signature( @@ -437,7 +446,7 @@ def dist_init(self, **kwargs): ("Pareto", ()), ("Poisson", ()), ("StudentT", ()), - ("Uniform", ()), + ("Uniform", ("low", "high")), ("VonMises", ()), ] @@ -587,7 +596,7 @@ def module(self): def __call__(self, cls, args, kwargs): # Check whether distribution class takes any tensor inputs. arg_constraints = getattr(cls, "arg_constraints", None) - if not arg_constraints: + if not arg_constraints or isinstance(arg_constraints, property): return # Check whether any tensor inputs are actually funsors.