Skip to content

Commit e0a34c8

Browse files
committed
extent tests
1 parent 3fee7e9 commit e0a34c8

1 file changed

Lines changed: 77 additions & 28 deletions

File tree

tests/test_estimators/test_pairwise_difference_classifier.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@
4747
_V2_SPARSE: sp.csr_matrix = sp.csr_matrix(_V2)
4848
_V_SPARSE: sp.csr_matrix = sp.csr_matrix(_V)
4949

50+
_SPARSE_CLASSES = [
51+
sp.csr_matrix,
52+
sp.coo_matrix,
53+
sp.bsr_matrix,
54+
sp.coo_matrix,
55+
sp.lil_matrix,
56+
]
57+
5058
# ---------------------------------------------------------------------------
5159
# Expected output constants - dual combinations (all row-pairs from V1 x V2)
5260
# ---------------------------------------------------------------------------
@@ -164,14 +172,14 @@ def test_combine_mode(self) -> None:
164172
result = dual_vector_combinations_sparse(_V1_SPARSE, _V2_SPARSE, mode="combine")
165173
self.assertTrue(sp.issparse(result))
166174
self.assertEqual(result.shape, (4, 4))
167-
self.assertTrue(np.array_equal(result.todense(), _DUAL_COMBINE_EXPECTED))
175+
self.assertTrue(np.array_equal(result.toarray(), _DUAL_COMBINE_EXPECTED))
168176

169177
def test_diff_mode(self) -> None:
170178
"""Diff returns a sparse result with correct shape and values matching dense."""
171179
result = dual_vector_combinations_sparse(_V1_SPARSE, _V2_SPARSE, mode="diff")
172180
self.assertTrue(sp.issparse(result))
173181
self.assertEqual(result.shape, (4, 2))
174-
self.assertTrue(np.array_equal(result.todense(), _DUAL_DIFF_EXPECTED))
182+
self.assertTrue(np.array_equal(result.toarray(), _DUAL_DIFF_EXPECTED))
175183

176184
def test_combine_and_diff_mode(self) -> None:
177185
"""combine_and_diff returns sparse result with correct shape matching dense."""
@@ -183,7 +191,7 @@ def test_combine_and_diff_mode(self) -> None:
183191
self.assertTrue(sp.issparse(result))
184192
self.assertEqual(result.shape, (4, 6))
185193
self.assertTrue(
186-
np.array_equal(result.todense(), _DUAL_COMBINE_AND_DIFF_EXPECTED),
194+
np.array_equal(result.toarray(), _DUAL_COMBINE_AND_DIFF_EXPECTED),
187195
)
188196

189197
def test_invalid_mode_raises(self) -> None:
@@ -204,7 +212,7 @@ def test_default_mode_is_combine(self) -> None:
204212
mode="combine",
205213
)
206214
self.assertTrue(
207-
np.array_equal(result_default.todense(), result_explicit.todense()),
215+
np.array_equal(result_default.toarray(), result_explicit.toarray()),
208216
)
209217

210218

@@ -266,9 +274,9 @@ def test_combine_and_diff_mode(self) -> None:
266274
)
267275

268276
def test_mixed_sparse_dispatch_returns_sparse(self) -> None:
269-
"""If either input is sparse, result must be sparse."""
277+
"""If either input is sparse, result must be dense."""
270278
result = dual_vector_combinations(_V1_SPARSE, _V2)
271-
self.assertTrue(sp.issparse(result))
279+
self.assertIsInstance(result, np.ndarray)
272280

273281
def test_invalid_mode_raises(self) -> None:
274282
"""An unsupported mode string must raise ValueError."""
@@ -343,15 +351,15 @@ def test_combine_mode(self) -> None:
343351
result = single_vector_combinations_sparse(_V_SPARSE, mode="combine")
344352
self.assertTrue(sp.issparse(result))
345353
self.assertEqual(result.shape, (n * (n - 1) // 2, 4))
346-
self.assertTrue(np.array_equal(result.todense(), _SINGLE_COMBINE_EXPECTED))
354+
self.assertTrue(np.array_equal(result.toarray(), _SINGLE_COMBINE_EXPECTED))
347355

348356
def test_diff_mode(self) -> None:
349357
"""Test diff returns sparse result with correct shape matching dense."""
350358
n = _V_SPARSE.shape[0]
351359
result = single_vector_combinations_sparse(_V_SPARSE, mode="diff")
352360
self.assertTrue(sp.issparse(result))
353361
self.assertEqual(result.shape, (n * (n - 1) // 2, 2))
354-
self.assertTrue(np.array_equal(result.todense(), _SINGLE_DIFF_EXPECTED))
362+
self.assertTrue(np.array_equal(result.toarray(), _SINGLE_DIFF_EXPECTED))
355363

356364
def test_combine_and_diff_mode(self) -> None:
357365
"""combine_and_diff returns sparse result with correct shape matching dense."""
@@ -360,7 +368,7 @@ def test_combine_and_diff_mode(self) -> None:
360368
self.assertTrue(sp.issparse(result))
361369
self.assertEqual(result.shape, (n * (n - 1) // 2, 6))
362370
self.assertTrue(
363-
np.array_equal(result.todense(), _SINGLE_COMBINE_AND_DIFF_EXPECTED),
371+
np.array_equal(result.toarray(), _SINGLE_COMBINE_AND_DIFF_EXPECTED),
364372
)
365373

366374
def test_no_self_pairs(self) -> None:
@@ -386,7 +394,7 @@ def test_default_mode_is_combine(self) -> None:
386394
result_default = single_vector_combinations_sparse(_V_SPARSE)
387395
result_explicit = single_vector_combinations_sparse(_V_SPARSE, mode="combine")
388396
self.assertTrue(
389-
np.array_equal(result_default.todense(), result_explicit.todense()),
397+
np.array_equal(result_default.toarray(), result_explicit.toarray()),
390398
)
391399

392400

@@ -407,10 +415,15 @@ def test_combine_mode(self) -> None:
407415
self.assertEqual(dense.shape, (expected_rows, 4))
408416
self.assertTrue(np.array_equal(dense, _SINGLE_COMBINE_EXPECTED))
409417

410-
sparse = single_vector_combinations(_V_SPARSE, mode="combine")
411-
self.assertTrue(sp.issparse(sparse))
412-
self.assertEqual(sparse.shape, (expected_rows, 4))
413-
self.assertTrue(np.array_equal(sparse.toarray(), _SINGLE_COMBINE_EXPECTED))
418+
for sparse_cls in _SPARSE_CLASSES:
419+
with self.subTest(sparse_cls=sparse_cls.__name__):
420+
v = sparse_cls(_V)
421+
sparse = single_vector_combinations(v, mode="combine")
422+
self.assertTrue(sp.issparse(sparse))
423+
self.assertEqual(sparse.shape, (expected_rows, 4))
424+
self.assertTrue(
425+
np.array_equal(sparse.toarray(), _SINGLE_COMBINE_EXPECTED),
426+
)
414427

415428
def test_diff_mode(self) -> None:
416429
"""Verify diff mode dispatches correctly for dense and sparse inputs.
@@ -426,10 +439,13 @@ def test_diff_mode(self) -> None:
426439
self.assertEqual(dense.shape, (expected_rows, 2))
427440
self.assertTrue(np.array_equal(dense, _SINGLE_DIFF_EXPECTED))
428441

429-
sparse = single_vector_combinations(_V_SPARSE, mode="diff")
430-
self.assertTrue(sp.issparse(sparse))
431-
self.assertEqual(sparse.shape, (expected_rows, 2))
432-
self.assertTrue(np.array_equal(sparse.toarray(), _SINGLE_DIFF_EXPECTED))
442+
for sparse_cls in _SPARSE_CLASSES:
443+
with self.subTest(sparse_cls=sparse_cls.__name__):
444+
v = sparse_cls(_V)
445+
sparse = single_vector_combinations(v, mode="diff")
446+
self.assertTrue(sp.issparse(sparse))
447+
self.assertEqual(sparse.shape, (expected_rows, 2))
448+
self.assertTrue(np.array_equal(sparse.toarray(), _SINGLE_DIFF_EXPECTED))
433449

434450
def test_combine_and_diff_mode(self) -> None:
435451
"""Verify that combine_and_diff dispatches to the correct backend.
@@ -445,12 +461,15 @@ def test_combine_and_diff_mode(self) -> None:
445461
self.assertEqual(dense.shape, (expected_rows, 6))
446462
self.assertTrue(np.array_equal(dense, _SINGLE_COMBINE_AND_DIFF_EXPECTED))
447463

448-
sparse = single_vector_combinations(_V_SPARSE, mode="combine_and_diff")
449-
self.assertTrue(sp.issparse(sparse))
450-
self.assertEqual(sparse.shape, (expected_rows, 6))
451-
self.assertTrue(
452-
np.array_equal(sparse.toarray(), _SINGLE_COMBINE_AND_DIFF_EXPECTED),
453-
)
464+
for sparse_cls in _SPARSE_CLASSES:
465+
with self.subTest(sparse_cls=sparse_cls.__name__):
466+
v = sparse_cls(_V)
467+
sparse = single_vector_combinations(v, mode="combine_and_diff")
468+
self.assertTrue(sp.issparse(sparse))
469+
self.assertEqual(sparse.shape, (expected_rows, 6))
470+
self.assertTrue(
471+
np.array_equal(sparse.toarray(), _SINGLE_COMBINE_AND_DIFF_EXPECTED),
472+
)
454473

455474
def test_no_self_pairs(self) -> None:
456475
"""Result must contain fewer rows than dual_vector_combinations(V, V).
@@ -461,10 +480,13 @@ def test_no_self_pairs(self) -> None:
461480
dual_vector_combinations(_V, _V, mode="combine").shape[0],
462481
single_vector_combinations(_V, mode="combine").shape[0],
463482
)
464-
self.assertGreater(
465-
dual_vector_combinations(_V_SPARSE, _V_SPARSE, mode="combine").shape[0],
466-
single_vector_combinations(_V_SPARSE, mode="combine").shape[0],
467-
)
483+
for sparse_cls in _SPARSE_CLASSES:
484+
with self.subTest(sparse_cls=sparse_cls.__name__):
485+
v = sparse_cls(_V)
486+
self.assertGreater(
487+
dual_vector_combinations(v, v, mode="combine").shape[0],
488+
single_vector_combinations(v, mode="combine").shape[0],
489+
)
468490

469491
def test_invalid_mode_raises(self) -> None:
470492
"""An unsupported mode string must raise ValueError."""
@@ -733,6 +755,33 @@ def test_predict_proba_fallback_no_predict_proba(self) -> None:
733755
atol=1e-6,
734756
)
735757

758+
def test_multiclass_fallback_no_predict_proba(self) -> None:
759+
"""Fallback branch must also work for multiclass data (no predict_proba).
760+
761+
Uses RidgeClassifier (no predict_proba) with a 3-class dataset and
762+
confirms output shape is (n_samples, 3) with rows summing to 1.
763+
"""
764+
model = PairwiseDifferenceClassifier(estimator=RidgeClassifier())
765+
model.fit(self.X_multi, self.y_multi)
766+
proba = model.predict_proba(self.X_multi)
767+
self.assertEqual(proba.shape, (len(self.X_multi), 3))
768+
np.testing.assert_allclose(
769+
proba.sum(axis=1),
770+
np.ones(len(self.X_multi)),
771+
atol=1e-6,
772+
)
773+
774+
def test_multiclass_fallback_predict(self) -> None:
775+
"""Predict must return correct-length output with valid class labels.
776+
777+
Uses RidgeClassifier (no predict_proba) with a 3-class dataset.
778+
"""
779+
model = PairwiseDifferenceClassifier(estimator=RidgeClassifier())
780+
model.fit(self.X_multi, self.y_multi)
781+
y_pred = model.predict(self.X_multi)
782+
self.assertEqual(len(y_pred), len(self.X_multi))
783+
self.assertTrue(set(y_pred).issubset(set(self.y_multi)))
784+
736785
def test_predict_proba_normalised_multiclass(self) -> None:
737786
"""predict_proba rows must always sum to 1, even for multiclass.
738787

0 commit comments

Comments
 (0)