|
6 | 6 |
|
7 | 7 | from skmatter.datasets import load_degenerate_CH4_manifold |
8 | 8 | from skmatter.metrics import ( |
| 9 | + check_global_reconstruction_measures_input, |
| 10 | + check_local_reconstruction_measures_input, |
9 | 11 | componentwise_prediction_rigidity, |
10 | 12 | global_reconstruction_distortion, |
11 | 13 | global_reconstruction_error, |
@@ -214,6 +216,46 @@ def test_local_reconstruction_error_test_idx(self): |
214 | 216 | f"size {test_size}", |
215 | 217 | ) |
216 | 218 |
|
| 219 | + def test_source_target_len(self): |
| 220 | + # tests that the source and target features have the same lenght |
| 221 | + X = np.array([[1, 2, 3], [4, 5, 6]]) |
| 222 | + Y = np.array([[1, 2, 3]]) |
| 223 | + |
| 224 | + train_idx = [0] |
| 225 | + test_idx = [1] |
| 226 | + scaler = None |
| 227 | + estimator = None |
| 228 | + |
| 229 | + with self.assertRaises(ValueError) as context: |
| 230 | + check_global_reconstruction_measures_input( |
| 231 | + X, Y, train_idx, test_idx, scaler, estimator |
| 232 | + ) |
| 233 | + |
| 234 | + expected_message = "First dimension of X (2) and Y (1) must match" |
| 235 | + self.assertEqual(str(context.exception), expected_message) |
| 236 | + |
| 237 | + def test_len_n_local_points(self): |
| 238 | + # tests that source len is greater or equal than n_local_points in LFRE |
| 239 | + X = np.array([[1, 2, 3], [4, 5, 6]]) |
| 240 | + Y = np.array([[1, 1, 1], [2, 2, 2]]) |
| 241 | + |
| 242 | + n_local_points = 10 |
| 243 | + train_idx = [0] |
| 244 | + test_idx = [1] |
| 245 | + scaler = None |
| 246 | + estimator = None |
| 247 | + |
| 248 | + with self.assertRaises(ValueError) as context: |
| 249 | + check_local_reconstruction_measures_input( |
| 250 | + X, Y, n_local_points, train_idx, test_idx, scaler, estimator |
| 251 | + ) |
| 252 | + |
| 253 | + expected_message = ( |
| 254 | + f"X has {len(X)} samples but n_local_points={n_local_points}. " |
| 255 | + "Must have at least n_local_points samples" |
| 256 | + ) |
| 257 | + self.assertEqual(str(context.exception), expected_message) |
| 258 | + |
217 | 259 |
|
218 | 260 | class DistanceTests(unittest.TestCase): |
219 | 261 | @classmethod |
|
0 commit comments