77# copyright notice and full license details.
88#
99import unittest
10- import pints
10+ import warnings
11+
1112import numpy as np
13+
14+ import pints
1215import pints ._diagnostics
1316
1417
@@ -58,7 +61,7 @@ def test_effective_sample_size(self):
5861 # matrix with two columns of samples
5962 x = np .transpose (np .array ([[1.0 , 1.1 , 1.4 , 1.3 , 1.3 ],
6063 [1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]]))
61- y = pints ._diagnostics . effective_sample_size (x )
64+ y = pints .effective_sample_size (x )
6265 self .assertAlmostEqual (y [0 ], 1.439232 , 6 )
6366 self .assertAlmostEqual (y [1 ], 1.315789 , 6 )
6467
@@ -91,7 +94,7 @@ def test_rhat(self):
9194 chains = np .array ([[1.0 , 1.1 , 1.4 , 1.3 ],
9295 [1.0 , 2.0 , 3.0 , 4.0 ]])
9396 self .assertAlmostEqual (
94- pints ._diagnostics . rhat (chains ), 2.3303847470550716 , 6 )
97+ pints .rhat (chains ), 2.3303847470550716 , 6 )
9598
9699 # Test Rhat computation for two parameters, chains.shape=(3, 4, 2)
97100 chains = np .array ([
@@ -114,7 +117,7 @@ def test_rhat(self):
114117 [0.89531238 , 0.63207977 ]
115118 ]])
116119
117- y = pints ._diagnostics . rhat (chains )
120+ y = pints .rhat (chains )
118121 d = np .array (y ) - np .array ([0.84735944450487122 , 1.1712652416950846 ])
119122 self .assertLess (np .linalg .norm (d ), 0.01 )
120123
@@ -124,40 +127,27 @@ def test_bad_rhat_inputs(self):
124127
125128 # Pass chain of dimension 1
126129 chains = np .empty (shape = 1 )
127- message = (
128- 'Dimension of chains is 1. '
129- + 'Method computes Rhat for one '
130- 'or multiple parameters and therefore only accepts 2 or 3 '
131- 'dimensional arrays.' )
132130 self .assertRaisesRegex (
133- ValueError , message [ 0 ] , pints .rhat , chains )
131+ ValueError , 'only accepts 2 or 3 dimensional' , pints .rhat , chains )
134132
135133 # Pass chain of dimension 4
136134 chains = np .empty (shape = (1 , 1 , 1 , 1 ))
137- message = (
138- 'Dimension of chains is 4. '
139- + 'Method computes Rhat for one '
140- 'or multiple parameters and therefore only accepts 2 or 3 '
141- 'dimensional arrays.' )
142135 self .assertRaisesRegex (
143- ValueError , message [0 ], pints .rhat , chains )
136+ ValueError , 'only accepts 2 or 3 dimensional' , pints .rhat , chains )
137+
138+ # Pass only a single chain
139+ chains = np .empty (shape = (1 , 5 ))
140+ self .assertRaisesRegex (
141+ ValueError , 'only accepts 2 or 3 dimensional' , pints .rhat , chains )
144142
145143 # Pass bad warm-up arguments
146144 chains = np .empty (shape = (2 , 4 ))
147145
148- # warm-up greater than 100%
149- warm_up = 1.1
150- message = (
151- '`warm_up` is set to 1.1. `warm_up` only takes values in [0,1].' )
146+ # warm-up greater than 100% or negative
152147 self .assertRaisesRegex (
153- ValueError , message [0 ], pints .rhat , chains , warm_up )
154-
155- # Negative warm-up
156- warm_up = - 0.1
157- message = (
158- '`warm_up` is set to -0.1. `warm_up` only takes values in [0,1].' )
148+ ValueError , r'takes values in \[0,1\]' , pints .rhat , chains , 1.1 )
159149 self .assertRaisesRegex (
160- ValueError , message [ 0 ] , pints .rhat , chains , warm_up )
150+ ValueError , r'takes values in \[0,1\]' , pints .rhat , chains , - 0.1 )
161151
162152 # Pass chains with too little samples (n<4)
163153 chains = np .empty (shape = (1 , 4 ))
@@ -168,8 +158,7 @@ def test_bad_rhat_inputs(self):
168158 self .assertRaisesRegex (
169159 ValueError , message [0 ], pints .rhat , chains , warm_up )
170160
171- def test_rhat_all_params (self ):
172- # Tests that rhat_all works
161+ def test_rhat_deprecated_alias (self ):
173162
174163 x = np .array ([[[- 1.10580535 , 2.26589882 ],
175164 [0.35604827 , 1.03523364 ],
@@ -184,9 +173,10 @@ def test_rhat_all_params(self):
184173 [0.92272047 , - 1.49997615 ],
185174 [0.89531238 , 0.63207977 ]]])
186175
187- y = pints ._diagnostics .rhat_all_params (x )
188- d = np .array (y ) - np .array ([0.84735944450487122 , 1.1712652416950846 ])
189- self .assertLess (np .linalg .norm (d ), 0.01 )
176+ with warnings .catch_warnings (record = True ) as w :
177+ z = pints .rhat_all_params (x )
178+ self .assertIn ('deprecated' , str (w [- 1 ].message ))
179+ self .assertEqual (list (pints .rhat (x )), list (z ))
190180
191181
192182if __name__ == '__main__' :
0 commit comments