diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04ec6f2a8b..13cb535a23 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,6 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install ruff black mypy nbstripout nbformat - name: Lint run: | diff --git a/pyro/distributions/coalescent.py b/pyro/distributions/coalescent.py index 291e87b96f..6dd8f2f8c1 100644 --- a/pyro/distributions/coalescent.py +++ b/pyro/distributions/coalescent.py @@ -249,7 +249,7 @@ class CoalescentRateLikelihood: def __init__(self, leaf_times, coal_times, duration, *, validate_args=None): assert leaf_times.size(-1) == 1 + coal_times.size(-1) assert isinstance(duration, int) and duration >= 2 - if validate_args is True or validate_args is None and is_validation_enabled: + if validate_args is True or validate_args is None and is_validation_enabled(): constraint = CoalescentTimesConstraint(leaf_times, ordered=False) if not constraint.check(coal_times).all(): raise ValueError("Invalid (leaf_times, coal_times)") diff --git a/pyro/distributions/conditional.py b/pyro/distributions/conditional.py index bc3548ed99..23240a9449 100644 --- a/pyro/distributions/conditional.py +++ b/pyro/distributions/conditional.py @@ -43,12 +43,12 @@ def inv(self) -> "ConditionalTransformModule": class _ConditionalInverseTransformModule(ConditionalTransformModule): - def __init__(self, transform: ConditionalTransform): + def __init__(self, transform: ConditionalTransformModule): super().__init__() self._transform = transform @property - def inv(self) -> ConditionalTransform: + def inv(self) -> ConditionalTransformModule: return self._transform def condition(self, context: torch.Tensor): diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index 0ce8fd8cdf..f726b62a7b 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -8,6 +8,8 @@ try: from torch.distributions.constraints import ( Constraint, + _GreaterThan, + _LowerCholesky, boolean, cat, corr_cholesky, @@ -122,12 +124,12 @@ def check(self, value): return ordered_vector.check(value) & independent(positive, 1).check(value) -class _SoftplusPositive(type(positive)): +class _SoftplusPositive(_GreaterThan): def __init__(self): super().__init__(lower_bound=0.0) -class _SoftplusLowerCholesky(type(lower_cholesky)): +class _SoftplusLowerCholesky(_LowerCholesky): pass diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index ecac30645e..bdf73d272f 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -4,12 +4,13 @@ import functools import inspect from abc import ABCMeta, abstractmethod +from typing import Any, Callable, List import torch from pyro.distributions.score_parts import ScoreParts -COERCIONS = [] +COERCIONS: List = [] class DistributionMeta(ABCMeta): @@ -51,6 +52,7 @@ class Distribution(metaclass=DistributionMeta): has_rsample = False has_enumerate_support = False + rsample: Callable[..., torch.Tensor] def __call__(self, *args, **kwargs): """ @@ -65,7 +67,7 @@ def __call__(self, *args, **kwargs): return self.sample(*args, **kwargs) @abstractmethod - def sample(self, *args, **kwargs): + def sample(self, *args, **kwargs) -> torch.Tensor: """ Samples a random value. @@ -82,7 +84,7 @@ def sample(self, *args, **kwargs): raise NotImplementedError @abstractmethod - def log_prob(self, x, *args, **kwargs): + def log_prob(self, *args: Any, **kwargs: Any) -> torch.Tensor: """ Evaluates log probability densities for each of a batch of samples. diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 9f2a242682..9967bcb7f0 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -1005,7 +1005,6 @@ class LinearHMM(HiddenMarkovModel): """ arg_constraints = {} - support = constraints.independent(constraints.real, 2) has_rsample = True def __init__( diff --git a/pyro/distributions/kl.py b/pyro/distributions/kl.py index f88a0019d7..8414c9eb25 100644 --- a/pyro/distributions/kl.py +++ b/pyro/distributions/kl.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from typing import List from torch.distributions import ( Independent, @@ -53,4 +54,4 @@ def _kl_independent_mvn(p, q): raise NotImplementedError -__all__ = [] +__all__: List[str] = [] diff --git a/pyro/distributions/nanmasked.py b/pyro/distributions/nanmasked.py index 1166ec4c9c..ebd5b7b9b9 100644 --- a/pyro/distributions/nanmasked.py +++ b/pyro/distributions/nanmasked.py @@ -24,7 +24,7 @@ class NanMaskedNormal(Normal): def log_prob(self, value: torch.Tensor) -> torch.Tensor: ok = value.isfinite() if ok.all(): - return super().log_prob(value) + return super().log_prob(value) # type: ignore[no-any-return] # Broadcast all tensors. value, ok, loc, scale = torch.broadcast_tensors(value, ok, self.loc, self.scale) @@ -65,7 +65,7 @@ class NanMaskedMultivariateNormal(MultivariateNormal): def log_prob(self, value: torch.Tensor) -> torch.Tensor: ok = value.isfinite() if ok.all(): - return super().log_prob(value) + return super().log_prob(value) # type: ignore[no-any-return] # Broadcast all tensors. This might waste some computation by eagerly # broadcasting, but the optimal implementation is quite complex. diff --git a/pyro/distributions/projected_normal.py b/pyro/distributions/projected_normal.py index fcd4c29212..51cf87d050 100644 --- a/pyro/distributions/projected_normal.py +++ b/pyro/distributions/projected_normal.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from typing import Callable, Dict import torch @@ -51,7 +52,9 @@ def model(): arg_constraints = {"concentration": constraints.real_vector} support = constraints.sphere has_rsample = True - _log_prob_impls = {} # maps dim -> function(concentration, value) + _log_prob_impls: Dict[int, Callable] = ( + {} + ) # maps dim -> function(concentration, value) def __init__(self, concentration, *, validate_args=None): assert concentration.dim() >= 1 diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 2f3f255d97..6a390459d6 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -14,6 +14,38 @@ from .. import settings from . import constraints +# Additionally try to import explicitly to help mypy static analysis. +try: + from torch.distributions import ( + Bernoulli, + Cauchy, + Chi2, + ContinuousBernoulli, + Exponential, + ExponentialFamily, + FisherSnedecor, + Gumbel, + HalfCauchy, + HalfNormal, + Kumaraswamy, + Laplace, + LKJCholesky, + LogisticNormal, + MixtureSameFamily, + NegativeBinomial, + OneHotCategoricalStraightThrough, + Pareto, + RelaxedBernoulli, + RelaxedOneHotCategorical, + StudentT, + TransformedDistribution, + VonMises, + Weibull, + Wishart, + ) +except ImportError: + pass + def _clamp_by_zero(x): # works like clamp(x, min=0) but has grad at 0 is 0.5 @@ -202,7 +234,7 @@ def log_prob(self, value): return (-value - 1) * torch.nn.functional.softplus(self.logits) + self.logits -class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): +class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin): # type: ignore def __init__(self, loc, scale, validate_args=None): base_dist = Normal(loc, scale) # This differs from torch.distributions.LogNormal only in that base_dist is @@ -294,7 +326,7 @@ def log_prob(self, value): ) -class Independent(torch.distributions.Independent, TorchDistributionMixin): +class Independent(torch.distributions.Independent, TorchDistributionMixin): # type: ignore @staticmethod def infer_shapes(**kwargs): raise NotImplementedError diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index ace02da72a..7a287c4f2e 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -16,7 +16,7 @@ from .util import broadcast_shape, scale_and_mask -class TorchDistributionMixin(Distribution, Callable): +class TorchDistributionMixin(Distribution, Callable): # type: ignore[misc] """ Mixin to provide Pyro compatibility for PyTorch distributions. diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py index 3a93d72780..c643976987 100644 --- a/pyro/distributions/torch_patch.py +++ b/pyro/distributions/torch_patch.py @@ -5,6 +5,7 @@ import math import warnings import weakref +from typing import List import torch @@ -92,4 +93,4 @@ def _lazy_property__call__(self): raise NotImplementedError -__all__ = [] +__all__: List[str] = [] diff --git a/pyro/distributions/torch_transform.py b/pyro/distributions/torch_transform.py index eedd8f5936..5590bcd83e 100644 --- a/pyro/distributions/torch_transform.py +++ b/pyro/distributions/torch_transform.py @@ -34,7 +34,7 @@ def __init__(self, parts, cache_size=0): def __hash__(self): return super(torch.nn.Module, self).__hash__() - def with_cache(self, cache_size=1): + def with_cache(self, cache_size=1) -> "ComposeTransformModule": if cache_size == self._cache_size: return self return ComposeTransformModule(self.parts, cache_size=cache_size) diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index 8428ce1334..9a286630f9 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -92,7 +92,7 @@ from .power import PositivePowerTransform from .radial import ConditionalRadial, Radial, conditional_radial, radial from .simplex_to_ordered import SimplexToOrderedTransform -from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform +from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform # type: ignore[assignment] from .spline import ConditionalSpline, Spline, conditional_spline, spline from .spline_autoregressive import ( ConditionalSplineAutoregressive, diff --git a/pyro/distributions/transforms/cholesky.py b/pyro/distributions/transforms/cholesky.py index c174c40d4a..7f5d9e9c71 100644 --- a/pyro/distributions/transforms/cholesky.py +++ b/pyro/distributions/transforms/cholesky.py @@ -26,8 +26,8 @@ class CholeskyTransform(Transform): """ bijective = True - domain = constraints.positive_definite - codomain = constraints.lower_cholesky + domain: constraints.Constraint = constraints.positive_definite + codomain: constraints.Constraint = constraints.lower_cholesky def __eq__(self, other): return isinstance(other, CholeskyTransform) @@ -55,8 +55,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform): bijective = True domain = constraints.corr_matrix - # TODO: change corr_cholesky_constraint to corr_cholesky when the latter is availabler - codomain = constraints.corr_cholesky_constraint + codomain = constraints.corr_cholesky def __eq__(self, other): return isinstance(other, CorrMatrixCholeskyTransform) diff --git a/setup.cfg b/setup.cfg index 1da059e331..7756adacf1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,10 +40,6 @@ ignore_errors = True [mypy-pyro.contrib.*] ignore_errors = True -[mypy-pyro.distributions.*] -ignore_errors = True -warn_unused_ignores = True - [mypy-pyro.generic.*] ignore_errors = True warn_unused_ignores = True