Skip to content

Commit 3f3bce2

Browse files
committed
Phase 13.21.GB-PERF Turn 2: batch MAD in make_parallel_fit_v4
F3: _get_batch_mad_kernel replaces 1.26M per-bin np.median calls (52s cumulative) with one numba prange kernel call. Per-bin residuals stored in flat CSR array during OLS loop, batch MAD computed post-loop. Env flag GBAI_DISABLE_BATCH_MEDIAN=1 for fallback. Tests: 575/3/0 new. T2 re-profile pending.
1 parent bd74974 commit 3f3bce2

1 file changed

Lines changed: 108 additions & 5 deletions

File tree

UTILS/dfextensions/groupby_regression/groupby_regression_optimized.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,6 +1547,67 @@ def make_parallel_fit_v3(
15471547
return df_out, dfGB
15481548

15491549

1550+
def _get_batch_mad_kernel():
1551+
"""Compile numba kernel for batch MAD computation.
1552+
1553+
Phase 13.21.GB-PERF F3: replaces 1.26M per-bin np.median calls
1554+
(52s cumulative, 41μs Python dispatch per call) with a single
1555+
JIT-compiled prange loop over bins.
1556+
1557+
MAD(x) = median(|x - median(x)|)
1558+
1559+
Each bin: two sorts of a small per-bin scratch buffer (~20 elements).
1560+
Thread-safe: each prange iteration uses its own local scratch buffer.
1561+
"""
1562+
import numba as nb
1563+
1564+
@nb.njit(cache=True, parallel=True)
1565+
def _batch_mad(resid_all, resid_offsets, n_groups, n_tgt, out_mad):
1566+
"""Compute MAD for all bins in parallel.
1567+
1568+
Parameters
1569+
----------
1570+
resid_all : (total_valid_rows, n_tgt) float64
1571+
Flat array of valid residuals, packed contiguously by bin.
1572+
resid_offsets : (n_groups + 1,) int64
1573+
CSR offsets into resid_all per bin.
1574+
n_groups : int
1575+
n_tgt : int
1576+
out_mad : (n_groups, n_tgt) float64
1577+
Output array, pre-initialized to NaN.
1578+
"""
1579+
for gi in nb.prange(n_groups):
1580+
i0 = resid_offsets[gi]
1581+
i1 = resid_offsets[gi + 1]
1582+
m = i1 - i0
1583+
if m == 0:
1584+
continue # out_mad already NaN
1585+
1586+
for t in range(n_tgt):
1587+
# Step 1: copy residuals to scratch, find median
1588+
buf = np.empty(m, dtype=np.float64)
1589+
for r in range(m):
1590+
buf[r] = resid_all[i0 + r, t]
1591+
buf.sort()
1592+
if m % 2 == 1:
1593+
med = buf[m // 2]
1594+
else:
1595+
med = (buf[m // 2 - 1] + buf[m // 2]) / 2.0
1596+
1597+
# Step 2: compute |resid - median|, find median of that
1598+
for r in range(m):
1599+
buf[r] = abs(resid_all[i0 + r, t] - med)
1600+
buf.sort()
1601+
if m % 2 == 1:
1602+
mad_val = buf[m // 2]
1603+
else:
1604+
mad_val = (buf[m // 2 - 1] + buf[m // 2]) / 2.0
1605+
1606+
out_mad[gi, t] = mad_val
1607+
1608+
return _batch_mad
1609+
1610+
15501611
def make_parallel_fit_v4(
15511612
*,
15521613
df,
@@ -1806,6 +1867,15 @@ def make_parallel_fit_v4(
18061867
# PROCESS EACH GROUP
18071868
# ========================================================================
18081869

1870+
# Phase 13.21.GB-PERF F3: batch MAD computation
1871+
import os
1872+
_use_batch_mad = os.environ.get("GBAI_DISABLE_BATCH_MEDIAN", "") != "1"
1873+
if _use_batch_mad:
1874+
# Pre-allocate flat residual buffer for post-loop batch MAD
1875+
_resid_flat = np.empty((N, n_tgt), dtype=np.float64)
1876+
_resid_offsets = np.zeros(n_groups + 1, dtype=np.int64)
1877+
_resid_pos = 0
1878+
18091879
# NumPy fallback (Numba kernel would be similar but JIT-compiled)
18101880
for gi in range(n_groups):
18111881
i0, i1 = offsets[gi], offsets[gi + 1]
@@ -1818,6 +1888,8 @@ def make_parallel_fit_v4(
18181888
n_valid_arr[gi] = 0
18191889
n_filtered_arr[gi] = 0
18201890
status_arr[gi] = 'INSUFFICIENT_DATA'
1891+
if _use_batch_mad:
1892+
_resid_offsets[gi + 1] = _resid_pos
18211893
continue
18221894

18231895
# Extract data for this group
@@ -1848,6 +1920,8 @@ def make_parallel_fit_v4(
18481920
# Check if enough valid data remains
18491921
if n_valid < int(min_stat):
18501922
status_arr[gi] = 'INSUFFICIENT_DATA'
1923+
if _use_batch_mad:
1924+
_resid_offsets[gi + 1] = _resid_pos
18511925
continue
18521926

18531927
# Apply filter
@@ -1918,11 +1992,18 @@ def make_parallel_fit_v4(
19181992
y_pred_unweighted = X1 @ coeffs # X1 is unweighted design matrix
19191993
resid_unweighted = Yg - y_pred_unweighted
19201994

1921-
# Compute MAD for each target
1922-
for t_idx in range(n_tgt):
1923-
resid_t = resid_unweighted[:, t_idx]
1924-
mad_val = np.median(np.abs(resid_t - np.median(resid_t)))
1925-
mad_arr[gi, t_idx] = mad_val
1995+
if _use_batch_mad:
1996+
# Phase 13.21.GB-PERF F3: store residuals for post-loop batch
1997+
n_v = resid_unweighted.shape[0]
1998+
_resid_flat[_resid_pos:_resid_pos + n_v, :] = resid_unweighted
1999+
_resid_pos += n_v
2000+
_resid_offsets[gi + 1] = _resid_pos
2001+
else:
2002+
# Original per-bin MAD (fallback)
2003+
for t_idx in range(n_tgt):
2004+
resid_t = resid_unweighted[:, t_idx]
2005+
mad_val = np.median(np.abs(resid_t - np.median(resid_t)))
2006+
mad_arr[gi, t_idx] = mad_val
19262007

19272008
# ================================================================
19282009
# COMPUTE PARAMETER ERRORS
@@ -1942,8 +2023,30 @@ def make_parallel_fit_v4(
19422023

19432024
except np.linalg.LinAlgError as e:
19442025
status_arr[gi] = f'SINGULAR_MATRIX'
2026+
if _use_batch_mad:
2027+
_resid_offsets[gi + 1] = _resid_pos
19452028
continue
19462029

2030+
# ========================================================================
2031+
# BATCH MAD COMPUTATION (Phase 13.21.GB-PERF F3)
2032+
# ========================================================================
2033+
if _use_batch_mad:
2034+
# Forward-fill offsets for skipped bins
2035+
_resid_offsets = np.maximum.accumulate(_resid_offsets)
2036+
_resid_flat = _resid_flat[:_resid_pos, :] # trim to actual size
2037+
try:
2038+
_mad_kernel = _get_batch_mad_kernel()
2039+
_mad_kernel(_resid_flat, _resid_offsets, n_groups, n_tgt, mad_arr)
2040+
except Exception:
2041+
# Fallback: per-bin numpy median (should not happen)
2042+
for gi in range(n_groups):
2043+
r0 = _resid_offsets[gi]
2044+
r1 = _resid_offsets[gi + 1]
2045+
if r1 > r0:
2046+
for t in range(n_tgt):
2047+
rt = _resid_flat[r0:r1, t]
2048+
mad_arr[gi, t] = np.median(np.abs(rt - np.median(rt)))
2049+
19472050
# ========================================================================
19482051
# VECTORIZED OUTPUT ASSEMBLY
19492052
# ========================================================================

0 commit comments

Comments
 (0)