Skip to content

Commit 98003f6

Browse files
committed
Adds tests for the batched/vectorized _gl_score
1 parent ed3ada8 commit 98003f6

1 file changed

Lines changed: 73 additions & 1 deletion

File tree

mne/decoding/tests/test_search_light.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
1616
from sklearn.ensemble import BaggingClassifier
1717
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
18-
from sklearn.metrics import make_scorer, roc_auc_score
18+
from sklearn.metrics import check_scoring, make_scorer, roc_auc_score
1919
from sklearn.model_selection import cross_val_predict
2020
from sklearn.multiclass import OneVsRestClassifier
2121
from sklearn.pipeline import make_pipeline
@@ -298,6 +298,78 @@ def test_generalization_light(metadata_routing):
298298
assert_array_equal(y_preds[0], y_preds[1])
299299

300300

301+
@pytest.mark.parametrize(
302+
"scoring, est_name, method",
303+
[
304+
(None, "logreg", "predict"),
305+
("accuracy", "logreg", "predict"),
306+
("balanced_accuracy", "logreg", "predict"),
307+
("roc_auc", "logreg", "decision_function"),
308+
("neg_log_loss", "logreg", "predict_proba"),
309+
(None, "ridge", "predict"),
310+
("roc_auc_multiclass", "logreg", "predict_proba"),
311+
("accuracy_kwargs", "logreg", "predict"),
312+
],
313+
)
314+
def test_gl_score_branches(scoring, est_name, method):
315+
"""Test _gl_score against its own can_batch=False nested-loop fallback."""
316+
n_trials, n_sensors, n_iter = 12, 3, 4
317+
rng = np.random.RandomState(0)
318+
X = rng.randn(n_trials, n_sensors, n_iter)
319+
y = rng.randint(0, 3 if scoring == "roc_auc_multiclass" else 2, n_trials)
320+
per_slice = scoring in ("neg_log_loss", "roc_auc_multiclass", "accuracy_kwargs")
321+
# liblinear is binary-only, switch to lbfgs for the multi-class case.
322+
solver = "lbfgs" if scoring == "roc_auc_multiclass" else "liblinear"
323+
if scoring == "roc_auc_multiclass":
324+
scoring = make_scorer(
325+
roc_auc_score, response_method="predict_proba", multi_class="ovr"
326+
)
327+
elif scoring == "accuracy_kwargs":
328+
# start from the default scorer but add a kwarg to prevent batching
329+
acc_func = check_scoring(LogisticRegression(), "accuracy")._score_func
330+
scoring = make_scorer(acc_func, normalize=False)
331+
est = Ridge() if est_name == "ridge" else LogisticRegression(solver=solver)
332+
gl = GeneralizingEstimator(est, scoring=scoring).fit(X, y)
333+
334+
# Measure batching: count pred and call scores. Wraps `fn` calls so they
335+
# append to a bucket; preserve __name__ because _gl_score matches it
336+
def counting(fn, bucket):
337+
def wrapped(*a, **k):
338+
bucket.append(1)
339+
return fn(*a, **k)
340+
341+
wrapped.__name__ = getattr(fn, "__name__", wrapped.__name__)
342+
return wrapped
343+
344+
# First we count calls to scorer
345+
score_calls = []
346+
scorer = check_scoring(est, scoring)
347+
if getattr(scorer, "_score_func", None) is not None:
348+
scorer._score_func = counting(scorer._score_func, score_calls)
349+
350+
# Now we count calls to the estimator that _gl_score will call (hardcoded)
351+
pred_calls = []
352+
for e in gl.estimators_:
353+
setattr(e, method, counting(getattr(e, method), pred_calls))
354+
355+
# Batched path: assert call counts immediately so the buckets only reflect
356+
# this run (the reference run below would otherwise add to them).
357+
gl.scoring = scorer
358+
actual = gl.score(X, y)
359+
assert len(pred_calls) == (n_iter if est_name != "ridge" else n_iter**2)
360+
assert len(score_calls) == (n_iter**2 if per_slice else 0)
361+
362+
# Reference: force can_batch=False. _score_func set (non-None) bypasses the
363+
# qname coercion; missing _response_method makes can_batch False.
364+
def force_fallback(e, X, y):
365+
return scorer(e, X, y)
366+
367+
force_fallback._score_func = id
368+
gl.scoring = force_fallback
369+
expected = gl.score(X, y)
370+
assert_allclose(actual, expected)
371+
372+
301373
@pytest.mark.parametrize(
302374
"n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")]
303375
)

0 commit comments

Comments
 (0)