|
| 1 | +""" |
| 2 | +Tests for fit_intercept=False across ALL fitters. |
| 3 | +
|
| 4 | +P0 bug: _fit_window_regression_numba hardcoded fit_intercept=True. |
| 5 | +These tests prevent recurrence across all code paths. |
| 6 | +
|
| 7 | +Key invariance: |
| 8 | + - All fitters with fit_intercept=False recover known polynomial coefficients |
| 9 | + - All fitters agree with each other (cross-fitter parity) |
| 10 | + - No fitter produces intercept columns when fit_intercept=False |
| 11 | +""" |
| 12 | +import numpy as np |
| 13 | +import pandas as pd |
| 14 | +import pytest |
| 15 | + |
| 16 | +try: |
| 17 | + from groupby_regression_optimized import ( |
| 18 | + make_parallel_fit_v2, |
| 19 | + make_parallel_fit_v3, |
| 20 | + make_parallel_fit_v4, |
| 21 | + ) |
| 22 | + from groupby_regression_sliding_window import make_sliding_window_fit |
| 23 | +except ImportError: |
| 24 | + from ..groupby_regression_optimized import ( |
| 25 | + make_parallel_fit_v2, |
| 26 | + make_parallel_fit_v3, |
| 27 | + make_parallel_fit_v4, |
| 28 | + ) |
| 29 | + from ..groupby_regression_sliding_window import make_sliding_window_fit |
| 30 | + |
| 31 | + |
| 32 | +# ── Fixture ── |
| 33 | + |
| 34 | +@pytest.fixture |
| 35 | +def poly_df(): |
| 36 | + """DataFrame with polynomial basis including constant term. |
| 37 | +
|
| 38 | + True model: y = 0.5 + 2*drift + 0.3*drift^2 - tgslp + noise(σ=0.05) |
| 39 | +
|
| 40 | + This is the exact pattern that triggers the bug: |
| 41 | + fit_intercept=False with a constant column in linear_columns. |
| 42 | + """ |
| 43 | + rng = np.random.RandomState(42) |
| 44 | + frames = [] |
| 45 | + for sec in range(4): |
| 46 | + for row_bin in range(5): |
| 47 | + n = 200 |
| 48 | + drift = rng.uniform(-1, 1, n) |
| 49 | + tgslp = rng.uniform(-0.5, 0.5, n) |
| 50 | + y = 0.5 + 2 * drift + 0.3 * drift ** 2 - tgslp + rng.normal(0, 0.05, n) |
| 51 | + frames.append(pd.DataFrame({ |
| 52 | + 'sec': sec, |
| 53 | + 'row_bin': row_bin, |
| 54 | + 'drift': drift, |
| 55 | + 'tgslp': tgslp, |
| 56 | + 'y': y, |
| 57 | + 'const': np.ones(n), |
| 58 | + 'drift1': drift, |
| 59 | + 'drift2': drift ** 2, |
| 60 | + 'tgslp1': tgslp, |
| 61 | + })) |
| 62 | + return pd.concat(frames, ignore_index=True) |
| 63 | + |
| 64 | + |
| 65 | +LIN_COLS = ['const', 'drift1', 'drift2', 'tgslp1'] |
| 66 | +GB_COLS = ['sec', 'row_bin'] |
| 67 | +TRUE_COEFFS = {'const': 0.5, 'drift1': 2.0, 'drift2': 0.3, 'tgslp1': -1.0} |
| 68 | + |
| 69 | + |
| 70 | +# ═══════════════════════════════════════════════════════════════ |
| 71 | +# Helper: check coefficients recovered |
| 72 | +# ═══════════════════════════════════════════════════════════════ |
| 73 | + |
| 74 | +def _check_coefficients(dfGB, suffix, fitter_name): |
| 75 | + """Verify recovered coefficients match true values.""" |
| 76 | + for col, true_val in TRUE_COEFFS.items(): |
| 77 | + col_name = f'y_slope_{col}{suffix}' |
| 78 | + if col_name not in dfGB.columns: |
| 79 | + pytest.fail(f"{fitter_name}: missing column {col_name}") |
| 80 | + mean_val = dfGB[col_name].mean() |
| 81 | + np.testing.assert_allclose( |
| 82 | + mean_val, true_val, atol=0.15, |
| 83 | + err_msg=f"{fitter_name}: {col} not recovered " |
| 84 | + f"(got {mean_val:.3f}, expected {true_val:.3f})") |
| 85 | + |
| 86 | + |
| 87 | +def _check_no_intercept_columns(dfGB, suffix, fitter_name): |
| 88 | + """Verify no intercept columns in output.""" |
| 89 | + intercept_cols = [c for c in dfGB.columns if 'intercept' in c.lower()] |
| 90 | + assert len(intercept_cols) == 0, \ |
| 91 | + f"{fitter_name}: fit_intercept=False produced intercept columns: {intercept_cols}" |
| 92 | + |
| 93 | + |
| 94 | +def _check_no_failures(dfGB, suffix, fitter_name): |
| 95 | + """Verify no fit failures.""" |
| 96 | + qf_col = f'quality_flag{suffix}' |
| 97 | + if qf_col in dfGB.columns: |
| 98 | + n_failed = dfGB[qf_col].str.contains('failed').sum() |
| 99 | + assert n_failed == 0, \ |
| 100 | + f"{fitter_name}: {n_failed}/{len(dfGB)} bins failed with fit_intercept=False" |
| 101 | + |
| 102 | + |
| 103 | +# ═══════════════════════════════════════════════════════════════ |
| 104 | +# Test 1: V4 recovers coefficients (INVARIANCE — reference) |
| 105 | +# ═══════════════════════════════════════════════════════════════ |
| 106 | + |
| 107 | +def test_v4_fit_intercept_false_recovers_coefficients(poly_df): |
| 108 | + """V4 with fit_intercept=False recovers known polynomial coefficients.""" |
| 109 | + _, dfGB = make_parallel_fit_v4( |
| 110 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 111 | + linear_columns=LIN_COLS, suffix='_test', |
| 112 | + fit_intercept=False, min_stat=10, |
| 113 | + ) |
| 114 | + _check_no_intercept_columns(dfGB, '_test', 'V4') |
| 115 | + _check_coefficients(dfGB, '_test', 'V4') |
| 116 | + |
| 117 | + |
| 118 | +# ═══════════════════════════════════════════════════════════════ |
| 119 | +# Test 2: V3 recovers coefficients (INVARIANCE) |
| 120 | +# ═══════════════════════════════════════════════════════════════ |
| 121 | + |
| 122 | +def test_v3_fit_intercept_false_recovers_coefficients(poly_df): |
| 123 | + """V3 with fit_intercept=False recovers known polynomial coefficients.""" |
| 124 | + _, dfGB = make_parallel_fit_v3( |
| 125 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 126 | + linear_columns=LIN_COLS, suffix='_test', |
| 127 | + fit_intercept=False, min_stat=10, |
| 128 | + ) |
| 129 | + _check_no_intercept_columns(dfGB, '_test', 'V3') |
| 130 | + _check_coefficients(dfGB, '_test', 'V3') |
| 131 | + |
| 132 | + |
| 133 | +# ═══════════════════════════════════════════════════════════════ |
| 134 | +# Test 3: V2 recovers coefficients (INVARIANCE) |
| 135 | +# ═══════════════════════════════════════════════════════════════ |
| 136 | + |
| 137 | +def test_v2_fit_intercept_false_recovers_coefficients(poly_df): |
| 138 | + """V2 with fit_intercept=False recovers known polynomial coefficients.""" |
| 139 | + _, dfGB = make_parallel_fit_v2( |
| 140 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 141 | + linear_columns=LIN_COLS, suffix='_test', |
| 142 | + fit_intercept=False, min_stat=10, |
| 143 | + ) |
| 144 | + _check_no_intercept_columns(dfGB, '_test', 'V2') |
| 145 | + _check_coefficients(dfGB, '_test', 'V2') |
| 146 | + |
| 147 | + |
| 148 | +# ═══════════════════════════════════════════════════════════════ |
| 149 | +# Test 4: SW fit recovers coefficients (INVARIANCE — bug target) |
| 150 | +# ═══════════════════════════════════════════════════════════════ |
| 151 | + |
| 152 | +def test_sw_fit_intercept_false_recovers_coefficients(poly_df): |
| 153 | + """SW with fit_intercept=False and window=0 recovers known coefficients. |
| 154 | +
|
| 155 | + This is the exact bug scenario: polynomial basis with constant term, |
| 156 | + fit_intercept=False, sliding window path. |
| 157 | + """ |
| 158 | + dfGB = make_sliding_window_fit( |
| 159 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 160 | + linear_columns=LIN_COLS, |
| 161 | + window_spec={'sec': 0, 'row_bin': 0}, |
| 162 | + suffix='_test', fit_intercept=False, min_stat=10, |
| 163 | + ) |
| 164 | + _check_no_failures(dfGB, '_test', 'SW') |
| 165 | + _check_no_intercept_columns(dfGB, '_test', 'SW') |
| 166 | + _check_coefficients(dfGB, '_test', 'SW') |
| 167 | + |
| 168 | + |
| 169 | +# ═══════════════════════════════════════════════════════════════ |
| 170 | +# Test 5: SW ≡ V4 with fit_intercept=False (INVARIANCE — gate) |
| 171 | +# ═══════════════════════════════════════════════════════════════ |
| 172 | + |
| 173 | +def test_sw_fit_intercept_false_matches_v4(poly_df): |
| 174 | + """SW fit with window=0 and fit_intercept=False ≡ V4 on same data.""" |
| 175 | + _, dfGB_v4 = make_parallel_fit_v4( |
| 176 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 177 | + linear_columns=LIN_COLS, suffix='_ref', |
| 178 | + fit_intercept=False, min_stat=10, |
| 179 | + ) |
| 180 | + |
| 181 | + dfGB_sw = make_sliding_window_fit( |
| 182 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 183 | + linear_columns=LIN_COLS, |
| 184 | + window_spec={'sec': 0, 'row_bin': 0}, |
| 185 | + suffix='_ref', fit_intercept=False, min_stat=10, |
| 186 | + ) |
| 187 | + |
| 188 | + v4 = dfGB_v4.sort_values(GB_COLS).reset_index(drop=True) |
| 189 | + sw = dfGB_sw.sort_values(GB_COLS).reset_index(drop=True) |
| 190 | + |
| 191 | + assert len(v4) == len(sw), f"Row count: v4={len(v4)}, sw={len(sw)}" |
| 192 | + |
| 193 | + slope_cols = [c for c in v4.columns if 'slope' in c] |
| 194 | + for col in slope_cols: |
| 195 | + if col in sw.columns: |
| 196 | + v4_vals = v4[col].values |
| 197 | + sw_vals = sw[col].values |
| 198 | + valid = np.isfinite(v4_vals) & np.isfinite(sw_vals) |
| 199 | + if valid.sum() > 0: |
| 200 | + np.testing.assert_allclose( |
| 201 | + sw_vals[valid], v4_vals[valid], |
| 202 | + rtol=1e-6, atol=1e-10, |
| 203 | + err_msg=f"SW ≠ V4 for {col} with fit_intercept=False") |
| 204 | + |
| 205 | + |
| 206 | +# ═══════════════════════════════════════════════════════════════ |
| 207 | +# Test 6: SW numba ≡ SW numpy with fit_intercept=False (INVARIANCE) |
| 208 | +# ═══════════════════════════════════════════════════════════════ |
| 209 | + |
| 210 | +def test_sw_fit_intercept_false_numba_matches_numpy(poly_df): |
| 211 | + """Numba path ≡ numpy path with fit_intercept=False in SW.""" |
| 212 | + ws = {'row_bin': 1} |
| 213 | + |
| 214 | + dfGB_numpy = make_sliding_window_fit( |
| 215 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 216 | + linear_columns=LIN_COLS, window_spec=ws, |
| 217 | + suffix='_test', fit_intercept=False, min_stat=10, |
| 218 | + backend='numpy', |
| 219 | + ) |
| 220 | + |
| 221 | + try: |
| 222 | + dfGB_numba = make_sliding_window_fit( |
| 223 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 224 | + linear_columns=LIN_COLS, window_spec=ws, |
| 225 | + suffix='_test', fit_intercept=False, min_stat=10, |
| 226 | + backend='numba', |
| 227 | + ) |
| 228 | + except Exception: |
| 229 | + pytest.skip("Numba not available") |
| 230 | + |
| 231 | + np_s = dfGB_numpy.sort_values(GB_COLS).reset_index(drop=True) |
| 232 | + nb_s = dfGB_numba.sort_values(GB_COLS).reset_index(drop=True) |
| 233 | + |
| 234 | + assert len(np_s) == len(nb_s) |
| 235 | + |
| 236 | + for name, df_check in [('numpy', np_s), ('numba', nb_s)]: |
| 237 | + _check_no_failures(df_check, '_test', f'SW-{name}') |
| 238 | + |
| 239 | + slope_cols = [c for c in np_s.columns if 'slope' in c] |
| 240 | + for col in slope_cols: |
| 241 | + if col in nb_s.columns: |
| 242 | + np_vals = np_s[col].values |
| 243 | + nb_vals = nb_s[col].values |
| 244 | + valid = np.isfinite(np_vals) & np.isfinite(nb_vals) |
| 245 | + if valid.sum() > 0: |
| 246 | + np.testing.assert_allclose( |
| 247 | + nb_vals[valid], np_vals[valid], |
| 248 | + rtol=1e-6, atol=1e-10, |
| 249 | + err_msg=f"numba ≠ numpy for {col} with fit_intercept=False") |
| 250 | + |
| 251 | + |
| 252 | +# ═══════════════════════════════════════════════════════════════ |
| 253 | +# Test 7: Cross-fitter parity V2 ≡ V3 ≡ V4 (INVARIANCE) |
| 254 | +# ═══════════════════════════════════════════════════════════════ |
| 255 | + |
| 256 | +def test_cross_fitter_parity_fit_intercept_false(poly_df): |
| 257 | + """All per-bin fitters agree with fit_intercept=False.""" |
| 258 | + results = {} |
| 259 | + |
| 260 | + for name, func in [('V2', make_parallel_fit_v2), |
| 261 | + ('V3', make_parallel_fit_v3), |
| 262 | + ('V4', make_parallel_fit_v4)]: |
| 263 | + _, dfGB = func( |
| 264 | + df=poly_df, gb_columns=GB_COLS, fit_columns=['y'], |
| 265 | + linear_columns=LIN_COLS, suffix='_test', |
| 266 | + fit_intercept=False, min_stat=10, |
| 267 | + ) |
| 268 | + results[name] = dfGB.sort_values(GB_COLS).reset_index(drop=True) |
| 269 | + |
| 270 | + # Compare V2 and V3 against V4 (reference) |
| 271 | + ref = results['V4'] |
| 272 | + slope_cols = [c for c in ref.columns if 'slope' in c] |
| 273 | + |
| 274 | + for name in ['V2', 'V3']: |
| 275 | + other = results[name] |
| 276 | + for col in slope_cols: |
| 277 | + if col in other.columns: |
| 278 | + ref_vals = ref[col].values |
| 279 | + other_vals = other[col].values |
| 280 | + valid = np.isfinite(ref_vals) & np.isfinite(other_vals) |
| 281 | + if valid.sum() > 0: |
| 282 | + np.testing.assert_allclose( |
| 283 | + other_vals[valid], ref_vals[valid], |
| 284 | + rtol=1e-4, atol=1e-8, |
| 285 | + err_msg=f"{name} ≠ V4 for {col} with fit_intercept=False") |
0 commit comments