Skip to content

Commit 1f3b45b

Browse files
committed
Phase 13.20.GB-PERF: numba-ize _aggregate_window_dense inner loop
F1: _get_gather_window_rows_kernel replaces per-bin Python loop with JIT-compiled two-pass CSR kernel (count → prefix sum → fill). F1a: removed redundant flat_indices.clip (16s, bounds mask suffices) F1c: removed np.unique(concatenate) — rows from different bins are disjoint (counting-sort guarantee), concatenation alone suffices (37s) Stats computation factored into _compute_window_stats (shared by both numba and numpy paths). Numpy fallback retained via GBAI_DISABLE_AGG_DENSE_NUMBA=1 env flag. T1: 6 new invariance tests (numba kernel vs numpy fallback, parametrized window/fit_intercept/weights). Total: 20 pytest items. All 20 passed on alma2.
1 parent 627feda commit 1f3b45b

2 files changed

Lines changed: 309 additions & 63 deletions

File tree

UTILS/dfextensions/groupby_regression/groupby_regression_sliding_window.py

Lines changed: 238 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,107 @@ def _assign_bin_ids_fast(
685685
return bin_ids, n_bins, bin_coords, bounds
686686

687687

688+
def _get_gather_window_rows_kernel():
689+
"""Compile numba kernel for gathering window row indices.
690+
691+
Phase 13.20.GB-PERF: replaces the per-bin Python loop in
692+
_aggregate_window_dense (159s self-time, 240s cumulative on 82M rows)
693+
with a JIT-compiled two-pass kernel.
694+
695+
Key insight: rows from different neighbor bins are disjoint (each row
696+
belongs to exactly one bin via counting sort), so np.unique is
697+
unnecessary — pure concatenation suffices.
698+
699+
Pass 1: count rows per center bin (neighbor lookup + range sum).
700+
Pass 2: fill row indices into preallocated flat array (CSR layout).
701+
"""
702+
import numba as nb
703+
704+
@nb.njit(cache=True)
705+
def _gather_rows(
706+
bin_coords, # (n_bins, n_dims) int64
707+
neighbor_offsets, # (K, n_dims) int64
708+
bounds_lo, # (n_dims,) int64
709+
bounds_hi, # (n_dims,) int64
710+
lookup, # (lookup_len,) int64
711+
lookup_mins, # (n_dims,) int64
712+
lookup_strides, # (n_dims,) int64
713+
order, # (n_sorted_rows,) int64
714+
cs_offsets, # (n_compact_bins+1,) int64
715+
# outputs (preallocated by caller)
716+
out_row_offsets, # (n_bins+1,) int64 — CSR offsets
717+
out_n_neighbors, # (n_bins,) int64
718+
out_eff_frac, # (n_bins,) float64
719+
):
720+
n_bins = bin_coords.shape[0]
721+
n_offsets = neighbor_offsets.shape[0]
722+
n_dims = bin_coords.shape[1]
723+
lookup_len = lookup.shape[0]
724+
expected_neighbors = n_offsets if n_offsets > 0 else 1
725+
726+
# ---- Pass 1: count rows per bin ----
727+
for bi in range(n_bins):
728+
count = np.int64(0)
729+
n_nbrs = np.int64(0)
730+
for ni in range(n_offsets):
731+
valid = True
732+
flat_idx = np.int64(0)
733+
for d in range(n_dims):
734+
nb_d = bin_coords[bi, d] + neighbor_offsets[ni, d]
735+
if nb_d < bounds_lo[d] or nb_d > bounds_hi[d]:
736+
valid = False
737+
break
738+
flat_idx += (nb_d - lookup_mins[d]) * lookup_strides[d]
739+
if not valid:
740+
continue
741+
if flat_idx < 0 or flat_idx >= lookup_len:
742+
continue
743+
cid = lookup[flat_idx]
744+
if cid < 0:
745+
continue
746+
n_nbrs += 1
747+
count += cs_offsets[cid + 1] - cs_offsets[cid]
748+
out_row_offsets[bi + 1] = count
749+
out_n_neighbors[bi] = n_nbrs
750+
out_eff_frac[bi] = n_nbrs / expected_neighbors
751+
752+
# ---- Prefix sum ----
753+
for bi in range(n_bins):
754+
out_row_offsets[bi + 1] += out_row_offsets[bi]
755+
756+
total_rows = out_row_offsets[n_bins]
757+
out_rows = np.empty(total_rows, dtype=np.int64)
758+
759+
# ---- Pass 2: fill row indices ----
760+
for bi in range(n_bins):
761+
pos = out_row_offsets[bi]
762+
for ni in range(n_offsets):
763+
valid = True
764+
flat_idx = np.int64(0)
765+
for d in range(n_dims):
766+
nb_d = bin_coords[bi, d] + neighbor_offsets[ni, d]
767+
if nb_d < bounds_lo[d] or nb_d > bounds_hi[d]:
768+
valid = False
769+
break
770+
flat_idx += (nb_d - lookup_mins[d]) * lookup_strides[d]
771+
if not valid:
772+
continue
773+
if flat_idx < 0 or flat_idx >= lookup_len:
774+
continue
775+
cid = lookup[flat_idx]
776+
if cid < 0:
777+
continue
778+
start = cs_offsets[cid]
779+
end = cs_offsets[cid + 1]
780+
for ri in range(start, end):
781+
out_rows[pos] = order[ri]
782+
pos += 1
783+
784+
return out_rows
785+
786+
return _gather_rows
787+
788+
688789
def _aggregate_window_dense(
689790
df: pd.DataFrame,
690791
bin_ids: np.ndarray,
@@ -704,15 +805,20 @@ def _aggregate_window_dense(
704805
agg_columns: Optional[List[str]] = None,
705806
agg_median: bool = False,
706807
) -> List[_AggResult]:
707-
"""Dense-lookup replacement for _aggregate_window_zerocopy.
808+
"""Dense-lookup window aggregation for the V1/V2 recompute path.
809+
810+
Phase 13.19.GB-PERF: initial implementation (dense lookup replaces
811+
_build_bin_index_map + _get_neighbor_bins V3a).
708812
709-
Phase 13.19.GB-PERF: eliminates two profile bottlenecks:
710-
- _build_bin_index_map (205s) → replaced by vectorized _assign_bin_ids_fast + _counting_sort_indices
711-
- _get_neighbor_bins V3a (152s) → replaced by inline vectorized offset + dense lookup
813+
Phase 13.20.GB-PERF: inner per-bin loop moved to numba kernel
814+
(_gather_window_rows_numba). Eliminates 2.35M Python calls to
815+
np.unique/clip/ones. Row indices from different neighbor bins are
816+
disjoint (counting-sort guarantee), so np.unique is unnecessary.
712817
713-
Same output contract as _aggregate_window_zerocopy: returns List[_AggResult]
714-
consumed by _fit_window_regression_numba/_numpy and _assemble_results.
818+
Set env GBAI_DISABLE_AGG_DENSE_NUMBA=1 to force numpy fallback
819+
(for testing invariance between JIT and Python paths).
715820
"""
821+
import os
716822
results: List[_AggResult] = []
717823
expected_neighbors = int(neighbor_offsets.shape[0]) if neighbor_offsets.size else 1
718824
n_dims = len(gb_columns)
@@ -723,34 +829,86 @@ def _aggregate_window_dense(
723829
_agg_cols = agg_columns or []
724830
agg_arrays = {c: df[c].to_numpy(dtype=np.float64) for c in _agg_cols}
725831

726-
# Bounds as arrays for vectorized checks
832+
# Bounds as arrays
727833
bounds_lo = np.array([bounds[dim][0] for dim in gb_columns], dtype=np.int64)
728834
bounds_hi = np.array([bounds[dim][1] for dim in gb_columns], dtype=np.int64)
729-
lookup_len = len(lookup)
730835

836+
# ---- Dispatch: numba kernel or numpy fallback ----
837+
use_numba = (
838+
neighbor_offsets.size > 0
839+
and os.environ.get("GBAI_DISABLE_AGG_DENSE_NUMBA", "") != "1"
840+
)
841+
842+
if use_numba:
843+
try:
844+
_gather_kernel = _get_gather_window_rows_kernel()
845+
846+
out_row_offsets = np.zeros(n_bins + 1, dtype=np.int64)
847+
out_n_neighbors = np.zeros(n_bins, dtype=np.int64)
848+
out_eff_frac = np.zeros(n_bins, dtype=np.float64)
849+
850+
out_rows = _gather_kernel(
851+
bin_coords, neighbor_offsets,
852+
bounds_lo, bounds_hi,
853+
lookup, lookup_mins, lookup_strides,
854+
order, offsets,
855+
out_row_offsets, out_n_neighbors, out_eff_frac,
856+
)
857+
858+
# Assemble _AggResult from CSR output
859+
for bi in range(n_bins):
860+
center = tuple(int(bin_coords[bi, d]) for d in range(n_dims))
861+
r_start = out_row_offsets[bi]
862+
r_end = out_row_offsets[bi + 1]
863+
idx_unique = out_rows[r_start:r_end]
864+
n_used = int(out_n_neighbors[bi])
865+
eff_frac = float(out_eff_frac[bi])
866+
n_rows = int(idx_unique.size)
867+
868+
stats, agg_st = _compute_window_stats(
869+
idx_unique, n_rows, fit_columns, target_arrays,
870+
w_array, weights, _agg_cols, agg_arrays, agg_median,
871+
)
872+
873+
results.append(_AggResult(
874+
center=center,
875+
n_neighbors_used=n_used,
876+
n_rows_aggregated=n_rows,
877+
effective_window_fraction=eff_frac,
878+
stats=stats,
879+
row_indices=idx_unique,
880+
agg_stats=agg_st,
881+
))
882+
return results
883+
884+
except Exception:
885+
# Fall through to numpy path on JIT failure
886+
pass
887+
888+
# ---- Numpy fallback (original Phase 13.19 code) ----
889+
lookup_len = len(lookup)
731890
for bi in range(n_bins):
732891
center = tuple(int(bin_coords[bi, d]) for d in range(n_dims))
733-
center_arr = bin_coords[bi] # (n_dims,) int64
892+
center_arr = bin_coords[bi]
734893

735-
# Vectorized neighbor computation (replaces _get_neighbor_bins V3a)
736894
if neighbor_offsets.size > 0:
737-
cand = center_arr + neighbor_offsets # (K, D)
895+
cand = center_arr + neighbor_offsets
738896
mask = np.ones(len(cand), dtype=bool)
739897
for j in range(n_dims):
740898
mask &= (cand[:, j] >= bounds_lo[j]) & (cand[:, j] <= bounds_hi[j])
741-
valid_cand = cand[mask] # (K', D)
899+
valid_cand = cand[mask]
742900
else:
743901
valid_cand = center_arr.reshape(1, -1)
744902

745-
# Vectorized dense-lookup: neighbor coords → compact bin indices
746-
shifted = valid_cand - lookup_mins # (K', D)
747-
flat_indices = (shifted * lookup_strides).sum(axis=1) # (K',)
903+
shifted = valid_cand - lookup_mins
904+
flat_indices = (shifted * lookup_strides).sum(axis=1)
748905
in_range = (flat_indices >= 0) & (flat_indices < lookup_len)
749-
compact_ids = np.where(in_range, lookup[flat_indices.clip(0, lookup_len - 1)], -1)
906+
# F1a: remove redundant clip — bounds mask guarantees in-range
907+
safe_idx = np.where(in_range, flat_indices, 0)
908+
compact_ids = np.where(in_range, lookup[safe_idx], -1)
750909
populated = compact_ids[compact_ids >= 0]
751910
n_used = len(populated)
752911

753-
# Gather row indices from counting-sort output
754912
idx_parts = []
755913
for cid in populated:
756914
start = offsets[cid]
@@ -759,64 +917,81 @@ def _aggregate_window_dense(
759917
idx_parts.append(order[start:end])
760918

761919
if idx_parts:
762-
idx_unique = np.unique(np.concatenate(idx_parts))
920+
# Rows from different bins are disjoint (counting-sort guarantee),
921+
# so np.unique is unnecessary — concatenation suffices.
922+
idx_unique = np.concatenate(idx_parts)
763923
else:
764924
idx_unique = np.array([], dtype=np.int64)
765925

766926
eff_frac = (n_used / expected_neighbors) if expected_neighbors > 0 else np.nan
767927
n_rows = int(idx_unique.size)
768928

769-
stats: Dict[str, Dict[str, float]] = {}
770-
agg_st: Optional[Dict[str, Dict[str, float]]] = None
929+
stats, agg_st = _compute_window_stats(
930+
idx_unique, n_rows, fit_columns, target_arrays,
931+
w_array, weights, _agg_cols, agg_arrays, agg_median,
932+
)
771933

772-
if n_rows > 0:
773-
if w_array is not None:
774-
w_win = w_array[idx_unique]
775-
w_valid = np.isfinite(w_win) & (w_win >= 0)
776-
else:
777-
w_win = None
778-
w_valid = None
934+
results.append(_AggResult(
935+
center=center,
936+
n_neighbors_used=n_used,
937+
n_rows_aggregated=n_rows,
938+
effective_window_fraction=eff_frac,
939+
stats=stats,
940+
row_indices=idx_unique,
941+
agg_stats=agg_st,
942+
))
779943

780-
for t in fit_columns:
781-
stats[t] = {}
944+
return results
782945

783-
if _agg_cols:
784-
agg_st = {}
785-
for c in _agg_cols:
786-
y = agg_arrays[c][idx_unique]
787-
y_finite = np.isfinite(y)
788-
if weights is None:
789-
x = y[y_finite]
790-
mean, std = _weighted_mean_std(x, None)
791-
else:
792-
valid = y_finite & w_valid
793-
x = y[valid]
794-
ww = w_win[valid]
795-
mean, std = _weighted_mean_std(x, ww)
796-
if agg_median and int(np.sum(y_finite)) > 0:
797-
median = float(np.median(y[y_finite]))
798-
else:
799-
median = np.nan
800-
agg_st[c] = {"mean": mean, "std": std, "median": median}
946+
947+
def _compute_window_stats(
948+
idx_unique, n_rows, fit_columns, target_arrays,
949+
w_array, weights, agg_cols, agg_arrays, agg_median,
950+
):
951+
"""Compute per-window statistics from gathered row indices.
952+
953+
Factored out of _aggregate_window_dense so both the numba and
954+
numpy paths share the same stats logic.
955+
"""
956+
stats: Dict[str, Dict[str, float]] = {}
957+
agg_st: Optional[Dict[str, Dict[str, float]]] = None
958+
959+
if n_rows > 0:
960+
if w_array is not None:
961+
w_win = w_array[idx_unique]
962+
w_valid = np.isfinite(w_win) & (w_win >= 0)
801963
else:
802-
for t in fit_columns:
803-
stats[t] = {}
804-
if _agg_cols:
805-
agg_st = {c: {"mean": np.nan, "std": np.nan, "median": np.nan} for c in _agg_cols}
964+
w_win = None
965+
w_valid = None
806966

807-
results.append(
808-
_AggResult(
809-
center=center,
810-
n_neighbors_used=n_used,
811-
n_rows_aggregated=n_rows,
812-
effective_window_fraction=eff_frac,
813-
stats=stats,
814-
row_indices=idx_unique,
815-
agg_stats=agg_st,
816-
)
817-
)
967+
for t in fit_columns:
968+
stats[t] = {}
818969

819-
return results
970+
if agg_cols:
971+
agg_st = {}
972+
for c in agg_cols:
973+
y = agg_arrays[c][idx_unique]
974+
y_finite = np.isfinite(y)
975+
if weights is None:
976+
x = y[y_finite]
977+
mean, std = _weighted_mean_std(x, None)
978+
else:
979+
valid = y_finite & w_valid
980+
x = y[valid]
981+
ww = w_win[valid]
982+
mean, std = _weighted_mean_std(x, ww)
983+
if agg_median and int(np.sum(y_finite)) > 0:
984+
median = float(np.median(y[y_finite]))
985+
else:
986+
median = np.nan
987+
agg_st[c] = {"mean": mean, "std": std, "median": median}
988+
else:
989+
for t in fit_columns:
990+
stats[t] = {}
991+
if agg_cols:
992+
agg_st = {c: {"mean": np.nan, "std": np.nan, "median": np.nan} for c in agg_cols}
993+
994+
return stats, agg_st
820995

821996
def _sanitize_suffix(name: str) -> str:
822997
return "".join(ch if ch.isalnum() else "_" for ch in str(name))

0 commit comments

Comments
 (0)