|
15 | 15 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
16 | 16 | from sklearn.ensemble import BaggingClassifier |
17 | 17 | from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge |
18 | | -from sklearn.metrics import make_scorer, roc_auc_score |
| 18 | +from sklearn.metrics import check_scoring, make_scorer, roc_auc_score |
19 | 19 | from sklearn.model_selection import cross_val_predict |
20 | 20 | from sklearn.multiclass import OneVsRestClassifier |
21 | 21 | from sklearn.pipeline import make_pipeline |
@@ -298,6 +298,78 @@ def test_generalization_light(metadata_routing): |
298 | 298 | assert_array_equal(y_preds[0], y_preds[1]) |
299 | 299 |
|
300 | 300 |
|
| 301 | +@pytest.mark.parametrize( |
| 302 | + "scoring, est_name, method", |
| 303 | + [ |
| 304 | + (None, "logreg", "predict"), |
| 305 | + ("accuracy", "logreg", "predict"), |
| 306 | + ("balanced_accuracy", "logreg", "predict"), |
| 307 | + ("roc_auc", "logreg", "decision_function"), |
| 308 | + ("neg_log_loss", "logreg", "predict_proba"), |
| 309 | + (None, "ridge", "predict"), |
| 310 | + ("roc_auc_multiclass", "logreg", "predict_proba"), |
| 311 | + ("accuracy_kwargs", "logreg", "predict"), |
| 312 | + ], |
| 313 | +) |
| 314 | +def test_gl_score_branches(scoring, est_name, method): |
| 315 | + """Test _gl_score against its own can_batch=False nested-loop fallback.""" |
| 316 | + n_trials, n_sensors, n_iter = 12, 3, 4 |
| 317 | + rng = np.random.RandomState(0) |
| 318 | + X = rng.randn(n_trials, n_sensors, n_iter) |
| 319 | + y = rng.randint(0, 3 if scoring == "roc_auc_multiclass" else 2, n_trials) |
| 320 | + per_slice = scoring in ("neg_log_loss", "roc_auc_multiclass", "accuracy_kwargs") |
| 321 | + # liblinear is binary-only, so switch to lbfgs for the multi-class case. |
| 322 | + solver = "lbfgs" if scoring == "roc_auc_multiclass" else "liblinear" |
| 323 | + if scoring == "roc_auc_multiclass": |
| 324 | + scoring = make_scorer( |
| 325 | + roc_auc_score, response_method="predict_proba", multi_class="ovr" |
| 326 | + ) |
| 327 | + elif scoring == "accuracy_kwargs": |
| 328 | + # start from the default scorer but add a kwarg to prevent batching |
| 329 | + acc_func = check_scoring(LogisticRegression(), "accuracy")._score_func |
| 330 | + scoring = make_scorer(acc_func, normalize=False) |
| 331 | + est = Ridge() if est_name == "ridge" else LogisticRegression(solver=solver) |
| 332 | + gl = GeneralizingEstimator(est, scoring=scoring).fit(X, y) |
| 333 | + |
| 334 | + # Measure batching: count prediction and scorer calls. `counting` wraps `fn` so |
| 335 | + # each call appends to `bucket`; keep __name__ because _gl_score matches on it. |
| 336 | + def counting(fn, bucket): |
| 337 | + def wrapped(*a, **k): |
| 338 | + bucket.append(1) |
| 339 | + return fn(*a, **k) |
| 340 | + |
| 341 | + wrapped.__name__ = getattr(fn, "__name__", wrapped.__name__) |
| 342 | + return wrapped |
| 343 | + |
| 344 | + # First, count calls to the scorer's underlying score function (if it has one) |
| 345 | + score_calls = [] |
| 346 | + scorer = check_scoring(est, scoring) |
| 347 | + if getattr(scorer, "_score_func", None) is not None: |
| 348 | + scorer._score_func = counting(scorer._score_func, score_calls) |
| 349 | + |
| 350 | + # Now count calls to the estimator method that _gl_score will call (hardcoded) |
| 351 | + pred_calls = [] |
| 352 | + for e in gl.estimators_: |
| 353 | + setattr(e, method, counting(getattr(e, method), pred_calls)) |
| 354 | + |
| 355 | + # Batched path: assert call counts immediately so the buckets only reflect |
| 356 | + # this run (the reference run below would otherwise add to them). |
| 357 | + gl.scoring = scorer |
| 358 | + actual = gl.score(X, y) |
| 359 | + assert len(pred_calls) == (n_iter if est_name != "ridge" else n_iter**2) |
| 360 | + assert len(score_calls) == (n_iter**2 if per_slice else 0) |
| 361 | + |
| 362 | + # Reference: force can_batch=False. A non-None _score_func bypasses the qname |
| 363 | + # coercion; the missing _response_method makes can_batch False. |
| 364 | + def force_fallback(e, X, y): |
| 365 | + return scorer(e, X, y) |
| 366 | + |
| 367 | + force_fallback._score_func = id |
| 368 | + gl.scoring = force_fallback |
| 369 | + expected = gl.score(X, y) |
| 370 | + assert_allclose(actual, expected) |
| 371 | + |
| 372 | + |
301 | 373 | @pytest.mark.parametrize( |
302 | 374 | "n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")] |
303 | 375 | ) |
|
0 commit comments