diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py index ffb37f3148..ce5cc59574 100644 --- a/pyro/distributions/conjugate.py +++ b/pyro/distributions/conjugate.py @@ -18,7 +18,7 @@ def _log_beta_1(alpha, value, is_sparse): if is_sparse: mask = value != 0 value, alpha, mask = torch.broadcast_tensors(value, alpha, mask) - result = torch.zeros_like(value) + result = torch.zeros_like(value, dtype=alpha.dtype) value = value[mask] alpha = alpha[mask] result[mask] = (