@@ -1547,6 +1547,67 @@ def make_parallel_fit_v3(
15471547 return df_out , dfGB
15481548
15491549
1550+ def _get_batch_mad_kernel ():
1551+ """Compile numba kernel for batch MAD computation.
1552+
1553+ Phase 13.21.GB-PERF F3: replaces 1.26M per-bin np.median calls
1554+ (52s cumulative, 41μs Python dispatch per call) with a single
1555+ JIT-compiled prange loop over bins.
1556+
1557+ MAD(x) = median(|x - median(x)|)
1558+
1559+ Each bin: two sorts of a small per-bin scratch buffer (~20 elements).
1560+ Thread-safe: each prange iteration uses its own local scratch buffer.
1561+ """
1562+ import numba as nb
1563+
1564+ @nb .njit (cache = True , parallel = True )
1565+ def _batch_mad (resid_all , resid_offsets , n_groups , n_tgt , out_mad ):
1566+ """Compute MAD for all bins in parallel.
1567+
1568+ Parameters
1569+ ----------
1570+ resid_all : (total_valid_rows, n_tgt) float64
1571+ Flat array of valid residuals, packed contiguously by bin.
1572+ resid_offsets : (n_groups + 1,) int64
1573+ CSR offsets into resid_all per bin.
1574+ n_groups : int
1575+ n_tgt : int
1576+ out_mad : (n_groups, n_tgt) float64
1577+ Output array, pre-initialized to NaN.
1578+ """
1579+ for gi in nb .prange (n_groups ):
1580+ i0 = resid_offsets [gi ]
1581+ i1 = resid_offsets [gi + 1 ]
1582+ m = i1 - i0
1583+ if m == 0 :
1584+ continue # out_mad already NaN
1585+
1586+ for t in range (n_tgt ):
1587+ # Step 1: copy residuals to scratch, find median
1588+ buf = np .empty (m , dtype = np .float64 )
1589+ for r in range (m ):
1590+ buf [r ] = resid_all [i0 + r , t ]
1591+ buf .sort ()
1592+ if m % 2 == 1 :
1593+ med = buf [m // 2 ]
1594+ else :
1595+ med = (buf [m // 2 - 1 ] + buf [m // 2 ]) / 2.0
1596+
1597+ # Step 2: compute |resid - median|, find median of that
1598+ for r in range (m ):
1599+ buf [r ] = abs (resid_all [i0 + r , t ] - med )
1600+ buf .sort ()
1601+ if m % 2 == 1 :
1602+ mad_val = buf [m // 2 ]
1603+ else :
1604+ mad_val = (buf [m // 2 - 1 ] + buf [m // 2 ]) / 2.0
1605+
1606+ out_mad [gi , t ] = mad_val
1607+
1608+ return _batch_mad
1609+
1610+
15501611def make_parallel_fit_v4 (
15511612 * ,
15521613 df ,
@@ -1806,6 +1867,15 @@ def make_parallel_fit_v4(
18061867 # PROCESS EACH GROUP
18071868 # ========================================================================
18081869
1870+ # Phase 13.21.GB-PERF F3: batch MAD computation
1871+ import os
1872+ _use_batch_mad = os .environ .get ("GBAI_DISABLE_BATCH_MEDIAN" , "" ) != "1"
1873+ if _use_batch_mad :
1874+ # Pre-allocate flat residual buffer for post-loop batch MAD
1875+ _resid_flat = np .empty ((N , n_tgt ), dtype = np .float64 )
1876+ _resid_offsets = np .zeros (n_groups + 1 , dtype = np .int64 )
1877+ _resid_pos = 0
1878+
18091879 # NumPy fallback (Numba kernel would be similar but JIT-compiled)
18101880 for gi in range (n_groups ):
18111881 i0 , i1 = offsets [gi ], offsets [gi + 1 ]
@@ -1818,6 +1888,8 @@ def make_parallel_fit_v4(
18181888 n_valid_arr [gi ] = 0
18191889 n_filtered_arr [gi ] = 0
18201890 status_arr [gi ] = 'INSUFFICIENT_DATA'
1891+ if _use_batch_mad :
1892+ _resid_offsets [gi + 1 ] = _resid_pos
18211893 continue
18221894
18231895 # Extract data for this group
@@ -1848,6 +1920,8 @@ def make_parallel_fit_v4(
18481920 # Check if enough valid data remains
18491921 if n_valid < int (min_stat ):
18501922 status_arr [gi ] = 'INSUFFICIENT_DATA'
1923+ if _use_batch_mad :
1924+ _resid_offsets [gi + 1 ] = _resid_pos
18511925 continue
18521926
18531927 # Apply filter
@@ -1918,11 +1992,18 @@ def make_parallel_fit_v4(
19181992 y_pred_unweighted = X1 @ coeffs # X1 is unweighted design matrix
19191993 resid_unweighted = Yg - y_pred_unweighted
19201994
1921- # Compute MAD for each target
1922- for t_idx in range (n_tgt ):
1923- resid_t = resid_unweighted [:, t_idx ]
1924- mad_val = np .median (np .abs (resid_t - np .median (resid_t )))
1925- mad_arr [gi , t_idx ] = mad_val
1995+ if _use_batch_mad :
1996+ # Phase 13.21.GB-PERF F3: store residuals for post-loop batch
1997+ n_v = resid_unweighted .shape [0 ]
1998+ _resid_flat [_resid_pos :_resid_pos + n_v , :] = resid_unweighted
1999+ _resid_pos += n_v
2000+ _resid_offsets [gi + 1 ] = _resid_pos
2001+ else :
2002+ # Original per-bin MAD (fallback)
2003+ for t_idx in range (n_tgt ):
2004+ resid_t = resid_unweighted [:, t_idx ]
2005+ mad_val = np .median (np .abs (resid_t - np .median (resid_t )))
2006+ mad_arr [gi , t_idx ] = mad_val
19262007
19272008 # ================================================================
19282009 # COMPUTE PARAMETER ERRORS
@@ -1942,8 +2023,30 @@ def make_parallel_fit_v4(
19422023
19432024 except np .linalg .LinAlgError as e :
19442025 status_arr [gi ] = f'SINGULAR_MATRIX'
2026+ if _use_batch_mad :
2027+ _resid_offsets [gi + 1 ] = _resid_pos
19452028 continue
19462029
2030+ # ========================================================================
2031+ # BATCH MAD COMPUTATION (Phase 13.21.GB-PERF F3)
2032+ # ========================================================================
2033+ if _use_batch_mad :
2034+ # Forward-fill offsets for skipped bins
2035+ _resid_offsets = np .maximum .accumulate (_resid_offsets )
2036+ _resid_flat = _resid_flat [:_resid_pos , :] # trim to actual size
2037+ try :
2038+ _mad_kernel = _get_batch_mad_kernel ()
2039+ _mad_kernel (_resid_flat , _resid_offsets , n_groups , n_tgt , mad_arr )
2040+ except Exception :
2041+ # Fallback: per-bin numpy median (should not happen)
2042+ for gi in range (n_groups ):
2043+ r0 = _resid_offsets [gi ]
2044+ r1 = _resid_offsets [gi + 1 ]
2045+ if r1 > r0 :
2046+ for t in range (n_tgt ):
2047+ rt = _resid_flat [r0 :r1 , t ]
2048+ mad_arr [gi , t ] = np .median (np .abs (rt - np .median (rt )))
2049+
19472050 # ========================================================================
19482051 # VECTORIZED OUTPUT ASSEMBLY
19492052 # ========================================================================
0 commit comments