Skip to content

Commit 9cffcc7

Browse files
committed
Adding TransformedLogLikelihood class
1 parent c53b401 commit 9cffcc7

3 files changed

Lines changed: 108 additions & 5 deletions

File tree

docs/source/transformations.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,24 @@ Example::
3131
transform = pints.LogTransformation(n_parameters)
3232
mcmc = pints.MCMCController(log_posterior, n_chains, x0, transform=transform)
3333

34+
Transformation types:
35+
36+
- :class:`ComposedTransformation`
37+
- :class:`IdentityTransformation`
38+
- :class:`LogitTransformation`
39+
- :class:`LogTransformation`
40+
- :class:`RectangularBoundariesTransformation`
41+
- :class:`ScalingTransformation`
42+
- :class:`UnitCubeTransformation`
43+
44+
Transformed classes:
45+
46+
- :class:`Transformation`
47+
- :class:`TransformedBoundaries`
48+
- :class:`TransformedErrorMeasure`
49+
- :class:`TransformedLogPDF`
50+
- :class:`TransformedLogPrior`
51+
3452

3553
Transformation types
3654
********************

pints/_transformation.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,46 @@ class Transformation():
3232
"""
3333
def convert_log_pdf(self, log_pdf):
3434
"""
35-
Returns a transformed log-PDF class.
35+
Returns a transformed :class:`pints.LogPDF`.
36+
37+
If `log_pdf` is a :class:`LogPrior`, a :class:`TransformedLogPrior`
38+
will be returned, which also transforms the output of the
39+
:meth:`sample` method.
40+
41+
If `log_pdf` is a :class:`LogLikelihood`, a
42+
:class:`TransformedLogLikelihood` is returned, which is assumed to be
43+
invariant with respect to the transform (because it is a probability of
44+
the data, not the parameters). For all other types (including
45+
``LogPrior``) a non-invariant transform is used, see
46+
:class:`TransformedLogPDF` for details.
3647
"""
48+
if isinstance(log_pdf, pints.LogLikelihood):
49+
return TransformedLogLikelihood(log_pdf, self)
50+
if isinstance(log_pdf, pints.LogPrior):
51+
return TransformedLogPrior(log_pdf, self)
3752
return TransformedLogPDF(log_pdf, self)
3853

3954
def convert_log_prior(self, log_prior):
4055
"""
41-
Returns a transformed log-prior class.
56+
Deprecated function: Use :meth:`convert_log_pdf` instead.
4257
"""
58+
# Deprecated on 2026-02-06
59+
import warnings
60+
warnings.warn(
61+
'The method `convert_log_prior` is deprecated. Please use'
62+
' `convert_log_pdf` instead (which will automatically detect'
63+
' detect LogPDF subtypes).')
4364
return TransformedLogPrior(log_prior, self)
4465

4566
def convert_error_measure(self, error_measure):
4667
"""
47-
Returns a transformed error measure class.
68+
Returns a transformed :class:`pints.ErrorMeasure`.
4869
"""
4970
return TransformedErrorMeasure(error_measure, self)
5071

5172
def convert_boundaries(self, boundaries):
5273
"""
53-
Returns a transformed boundaries class.
74+
Returns a transformed :class:`pints.Boundaries` object.
5475
"""
5576
if isinstance(boundaries, pints.RectangularBoundaries):
5677
if self.elementwise():
@@ -1212,6 +1233,62 @@ def sample(self, n):
12121233
return qs
12131234

12141235

1236+
class TransformedLogLikelihood(pints.LogLikelihood):
1237+
r"""
1238+
A :class:`pints.LogLikelihood` that accepts parameters in a transformed
1239+
search space.
1240+
1241+
Unlike a :class:`TransformedLogPDF`, a likelihood (a probability of the
1242+
data, given fixed parameters) is invariant to a parameter transform (but
1243+
not to a data transform), and so no Jacobian term appears. Instead
1244+
1245+
.. math::
1246+
???
1247+
1248+
1249+
For the first order sensitivity, the transformation is done using
1250+
1251+
.. math::
1252+
???
1253+
1254+
Extends :class:`pints.LogLikelihood`.
1255+
1256+
Parameters
1257+
----------
1258+
log_likelihood
1259+
A :class:`pints.LogLikelihood`.
1260+
transformation
1261+
A :class:`pints.Transformation`.
1262+
"""
1263+
def __init__(self, log_likelihood, transformation):
1264+
self._log_likelihood = log_likelihood
1265+
self._transform = transformation
1266+
self._n_parameters = self._log_pdf.n_parameters()
1267+
if self._transform.n_parameters() != self._n_parameters:
1268+
raise ValueError('Number of parameters for log_likelihood and '
1269+
'transformation must match.')
1270+
1271+
def __call__(self, q):
1272+
# Compute LogLikelihood in the model space
1273+
return self._log_likelihood(self._transform.to_model(q))
1274+
1275+
def evaluateS1(self, q):
1276+
""" See :meth:`LogPDF.evaluateS1()`. """
1277+
1278+
# Call evaluateS1 of LogLikelihood in the model space
1279+
logl, dlogl_nojac = self._error.evaluateS1(self._transform.to_model(q))
1280+
1281+
# Calculate the S1 using change of variable (see ErrorMeasure above)
1282+
jacobian = self._transform.jacobian(q)
1283+
dlogl = np.matmul(dlogl_nojac, jacobian) # Jacobian must be 2nd term
1284+
1285+
return logl, dlogl
1286+
1287+
def n_parameters(self):
1288+
""" See :meth:`LogPDF.n_parameters()`. """
1289+
return self._n_parameters
1290+
1291+
12151292
class UnitCubeTransformation(ScalingTransformation):
12161293
"""
12171294
Maps a parameter space onto the unit (hyper)cube.

pints/tests/test_transformation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,8 @@ def test_transformed_log_prior(self):
954954
d = 2
955955
t = pints.LogTransformation(2)
956956
r = pints.UniformLogPrior([0.1, 0.1], [0.9, 0.9])
957-
tr = t.convert_log_prior(r)
957+
tr = t.convert_log_pdf(r)
958+
self.assertIsInstance(tr, pints.TransformedLogPrior)
958959

959960
# Test sample
960961
n = 1
@@ -966,6 +967,13 @@ def test_transformed_log_prior(self):
966967
self.assertEqual(x.shape, (n, d))
967968
self.assertTrue(np.all(x < 0.))
968969

970+
# Test deprecated alias
971+
with warnings.catch_warnings(record=True) as w:
972+
tr = t.convert_log_prior(r)
973+
self.assertEqual(len(w), 1)
974+
self.assertIn('deprecated', str(w[0].message))
975+
self.assertIsInstance(tr, pints.TransformedLogPrior)
976+
969977

970978
if __name__ == '__main__':
971979
unittest.main()

0 commit comments

Comments
 (0)