Skip to content

Commit fd3f3af

Browse files
committed
HYPER: make hyperparameter likelihood handle array backends
1 parent 53a2e77 commit fd3f3af

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

bilby/hyper/likelihood.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55

6+
from ..compat.utils import array_module
67
from ..core.likelihood import Likelihood, _fallback_to_parameters
78
from .model import Model
89
from ..core.prior import PriorDict
@@ -29,11 +30,13 @@ class HyperparameterLikelihood(Likelihood):
2930
the sampling prior and the hyperparameterised model.
3031
max_samples: int, optional
3132
Maximum number of samples to use from each set.
33+
xp: module
34+
The array backend to use for the data.
3235
3336
"""
3437

3538
def __init__(self, posteriors, hyper_prior, sampling_prior=None,
36-
log_evidences=None, max_samples=1e100):
39+
log_evidences=None, max_samples=1e100, xp=np):
3740
if not isinstance(hyper_prior, Model):
3841
hyper_prior = Model([hyper_prior])
3942
if sampling_prior is None:
@@ -53,26 +56,27 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None,
5356
self.max_samples = max_samples
5457
super(HyperparameterLikelihood, self).__init__()
5558

56-
self.data = self.resample_posteriors()
59+
self.data = self.resample_posteriors(xp=xp)
5760
self.n_posteriors = len(self.posteriors)
5861
self.samples_per_posterior = self.max_samples
5962
self.samples_factor =\
6063
- self.n_posteriors * np.log(self.samples_per_posterior)
6164

6265
def log_likelihood_ratio(self, parameters=None):
6366
parameters = _fallback_to_parameters(self, parameters)
64-
log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) /
65-
self.data['prior'], axis=-1)))
67+
probs = self.hyper_prior.prob(self.data, **parameters)
68+
xp = array_module(probs)
69+
log_l = xp.sum(xp.log(xp.sum(probs / self.data['prior'], axis=-1)))
6670
log_l += self.samples_factor
67-
return np.nan_to_num(log_l)
71+
return xp.nan_to_num(log_l)
6872

6973
def noise_log_likelihood(self):
7074
return self.evidence_factor
7175

7276
def log_likelihood(self, parameters=None):
7377
return self.noise_log_likelihood() + self.log_likelihood_ratio(parameters=parameters)
7478

75-
def resample_posteriors(self, max_samples=None):
79+
def resample_posteriors(self, max_samples=None, xp=np):
7680
"""
7781
Convert list of pandas DataFrame object to dict of arrays.
7882
@@ -107,5 +111,5 @@ def resample_posteriors(self, max_samples=None):
107111
for key in data:
108112
data[key].append(temp[key])
109113
for key in data:
110-
data[key] = np.array(data[key])
114+
data[key] = xp.array(data[key])
111115
return data

0 commit comments

Comments
 (0)