33
44import numpy as np
55
6+ from ..compat .utils import array_module
67from ..core .likelihood import Likelihood , _fallback_to_parameters
78from .model import Model
89from ..core .prior import PriorDict
@@ -29,11 +30,13 @@ class HyperparameterLikelihood(Likelihood):
2930 the sampling prior and the hyperparameterised model.
3031 max_samples: int, optional
3132 Maximum number of samples to use from each set.
33+ xp: module
34+ The array backend to use for the data.
3235
3336 """
3437
3538 def __init__ (self , posteriors , hyper_prior , sampling_prior = None ,
36- log_evidences = None , max_samples = 1e100 ):
39+ log_evidences = None , max_samples = 1e100 , xp = np ):
3740 if not isinstance (hyper_prior , Model ):
3841 hyper_prior = Model ([hyper_prior ])
3942 if sampling_prior is None :
@@ -53,26 +56,27 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None,
5356 self .max_samples = max_samples
5457 super (HyperparameterLikelihood , self ).__init__ ()
5558
56- self .data = self .resample_posteriors ()
59+ self .data = self .resample_posteriors (xp = xp )
5760 self .n_posteriors = len (self .posteriors )
5861 self .samples_per_posterior = self .max_samples
5962 self .samples_factor = \
6063 - self .n_posteriors * np .log (self .samples_per_posterior )
6164
6265 def log_likelihood_ratio (self , parameters = None ):
6366 parameters = _fallback_to_parameters (self , parameters )
64- log_l = np .sum (np .log (np .sum (self .hyper_prior .prob (self .data , ** parameters ) /
65- self .data ['prior' ], axis = - 1 )))
67+ probs = self .hyper_prior .prob (self .data , ** parameters )
68+ xp = array_module (probs )
69+ log_l = xp .sum (xp .log (xp .sum (probs / self .data ['prior' ], axis = - 1 )))
6670 log_l += self .samples_factor
67- return np .nan_to_num (log_l )
71+ return xp .nan_to_num (log_l )
6872
6973 def noise_log_likelihood (self ):
7074 return self .evidence_factor
7175
7276 def log_likelihood (self , parameters = None ):
7377 return self .noise_log_likelihood () + self .log_likelihood_ratio (parameters = parameters )
7478
75- def resample_posteriors (self , max_samples = None ):
79+ def resample_posteriors (self , max_samples = None , xp = np ):
7680 """
7781 Convert list of pandas DataFrame object to dict of arrays.
7882
@@ -107,5 +111,5 @@ def resample_posteriors(self, max_samples=None):
107111 for key in data :
108112 data [key ].append (temp [key ])
109113 for key in data :
110- data [key ] = np .array (data [key ])
114+ data [key ] = xp .array (data [key ])
111115 return data
0 commit comments