Skip to content

Commit 03fabcc

Browse files
committed
Added test for #533
1 parent 708f287 commit 03fabcc

1 file changed

Lines changed: 22 additions & 32 deletions

File tree

pints/tests/test_diagnostics.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
# copyright notice and full license details.
88
#
99
import unittest
10-
import pints
10+
import warnings
11+
1112
import numpy as np
13+
14+
import pints
1215
import pints._diagnostics
1316

1417

@@ -58,7 +61,7 @@ def test_effective_sample_size(self):
5861
# matrix with two columns of samples
5962
x = np.transpose(np.array([[1.0, 1.1, 1.4, 1.3, 1.3],
6063
[1.0, 2.0, 3.0, 4.0, 5.0]]))
61-
y = pints._diagnostics.effective_sample_size(x)
64+
y = pints.effective_sample_size(x)
6265
self.assertAlmostEqual(y[0], 1.439232, 6)
6366
self.assertAlmostEqual(y[1], 1.315789, 6)
6467

@@ -91,7 +94,7 @@ def test_rhat(self):
9194
chains = np.array([[1.0, 1.1, 1.4, 1.3],
9295
[1.0, 2.0, 3.0, 4.0]])
9396
self.assertAlmostEqual(
94-
pints._diagnostics.rhat(chains), 2.3303847470550716, 6)
97+
pints.rhat(chains), 2.3303847470550716, 6)
9598

9699
# Test Rhat computation for two parameters, chains.shape=(3, 4, 2)
97100
chains = np.array([
@@ -114,7 +117,7 @@ def test_rhat(self):
114117
[0.89531238, 0.63207977]
115118
]])
116119

117-
y = pints._diagnostics.rhat(chains)
120+
y = pints.rhat(chains)
118121
d = np.array(y) - np.array([0.84735944450487122, 1.1712652416950846])
119122
self.assertLess(np.linalg.norm(d), 0.01)
120123

@@ -124,40 +127,27 @@ def test_bad_rhat_inputs(self):
124127

125128
# Pass chain of dimension 1
126129
chains = np.empty(shape=1)
127-
message = (
128-
'Dimension of chains is 1. '
129-
+ 'Method computes Rhat for one '
130-
'or multiple parameters and therefore only accepts 2 or 3 '
131-
'dimensional arrays.')
132130
self.assertRaisesRegex(
133-
ValueError, message[0], pints.rhat, chains)
131+
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)
134132

135133
# Pass chain of dimension 4
136134
chains = np.empty(shape=(1, 1, 1, 1))
137-
message = (
138-
'Dimension of chains is 4. '
139-
+ 'Method computes Rhat for one '
140-
'or multiple parameters and therefore only accepts 2 or 3 '
141-
'dimensional arrays.')
142135
self.assertRaisesRegex(
143-
ValueError, message[0], pints.rhat, chains)
136+
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)
137+
138+
# Pass only a single chain
139+
chains = np.empty(shape=(1, 5))
140+
self.assertRaisesRegex(
141+
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)
144142

145143
# Pass bad warm-up arguments
146144
chains = np.empty(shape=(2, 4))
147145

148-
# warm-up greater than 100%
149-
warm_up = 1.1
150-
message = (
151-
'`warm_up` is set to 1.1. `warm_up` only takes values in [0,1].')
146+
# warm-up greater than 100% or negative
152147
self.assertRaisesRegex(
153-
ValueError, message[0], pints.rhat, chains, warm_up)
154-
155-
# Negative warm-up
156-
warm_up = -0.1
157-
message = (
158-
'`warm_up` is set to -0.1. `warm_up` only takes values in [0,1].')
148+
ValueError, r'takes values in \[0,1\]', pints.rhat, chains, 1.1)
159149
self.assertRaisesRegex(
160-
ValueError, message[0], pints.rhat, chains, warm_up)
150+
ValueError, r'takes values in \[0,1\]', pints.rhat, chains, -0.1)
161151

162152
# Pass chains with too little samples (n<4)
163153
chains = np.empty(shape=(1, 4))
@@ -168,8 +158,7 @@ def test_bad_rhat_inputs(self):
168158
self.assertRaisesRegex(
169159
ValueError, message[0], pints.rhat, chains, warm_up)
170160

171-
def test_rhat_all_params(self):
172-
# Tests that rhat_all works
161+
def test_rhat_deprecated_alias(self):
173162

174163
x = np.array([[[-1.10580535, 2.26589882],
175164
[0.35604827, 1.03523364],
@@ -184,9 +173,10 @@ def test_rhat_all_params(self):
184173
[0.92272047, -1.49997615],
185174
[0.89531238, 0.63207977]]])
186175

187-
y = pints._diagnostics.rhat_all_params(x)
188-
d = np.array(y) - np.array([0.84735944450487122, 1.1712652416950846])
189-
self.assertLess(np.linalg.norm(d), 0.01)
176+
with warnings.catch_warnings(record=True) as w:
177+
z = pints.rhat_all_params(x)
178+
self.assertIn('deprecated', str(w[-1].message))
179+
self.assertEqual(list(pints.rhat(x)), list(z))
190180

191181

192182
if __name__ == '__main__':

0 commit comments

Comments
 (0)