Skip to content

Commit dc4343b

Browse files
TimDettmers and claude
committed
Phase 1 inner loop opts: branchless absmax, interleaved extraction, two-tier k_splits
Three changes to the production kernel (kbit_gemm_prod): 1. Branchless absmax decode: new decode_e4m4_absmax_branchless() eliminates BSSY/BSYNC divergence-handling pairs in SASS. Subnormals treated as normal-path (acceptable since no real weight block has absmax < 2^-10). 2. Interleaved bit extraction: all 4 fragment elements' bit extractions interleaved in a single loop over K_BITS, giving the compiler more ILP across elements and bit-planes. 3. Two-tier k_splits heuristic: Tier 1 (severe underutil < 25%) splits aggressively. Tier 2 (new) splits conservatively (cap 2) when data exceeds L2 cache (> 24 MB) and SM utilization is moderate. Llama3-8B improves ~25% from k_splits=2. MoE shapes remain at 0.3-0.4x vs cuBLAS — the bottleneck is structural (1264 SASS instructions per k_tile, 1.3% tensor core utilization). Phase 2 restructuring (dequant-during-fetch) needed. Also adds optimization2.md documenting root cause analysis and the dequant-during-fetch restructuring plan. 195/195 tests pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 90cd7cf commit dc4343b

File tree

2 files changed

+627
-12
lines changed

2 files changed

+627
-12
lines changed

csrc/ops.cu

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,24 @@ __device__ __forceinline__ float decode_e4m4_absmax(unsigned char raw) {
736736
return __uint_as_float(ieee);
737737
}
738738

739+
// Branchless version for the GEMM inner loop. Eliminates BSSY/BSYNC
740+
// divergence-handling pairs that the branchy version generates.
741+
// Subnormals (e==0) are treated as normal-path (produces a small wrong
742+
// value, but no real weight block has absmax < 2^-10).
743+
__device__ __forceinline__ float decode_e4m4_absmax_branchless(unsigned char raw) {
744+
int e = raw >> 4;
745+
int m = raw & 0xF;
746+
// Normal path: construct IEEE 754 directly.
747+
// When raw==0 (e==0, m==0) this produces 2^(0-11+127)<<23 | 0 which
748+
// is some small positive float; we select 0.0 below via predicate.
749+
unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23
750+
| (unsigned int)m << 19;
751+
float result = __uint_as_float(ieee);
752+
// Zero-out for raw==0 using predicated select (no branch).
753+
// PTXAS emits a FSEL instruction (1 cycle, no divergence).
754+
return (raw != 0) ? result : 0.0f;
755+
}
756+
739757
// ---- E4M4 absmax encode ----
740758
// float -> uint8: inverse of decode_e4m4_absmax.
741759
// Normal (e_biased > 0): e_biased = floor(log2(val)) + BIAS, m = round((val/2^e_unbiased - 1) * 16)
@@ -1929,21 +1947,37 @@ __global__ void kbit_gemm_prod(
19291947
for (int b = 0; b < K_BITS; b++)
19301948
planes[b] = b_ptr[b_addr + b];
19311949

1932-
scalar_t scale = Ops::from_float(decode_e4m4_absmax(abs_ptr[col * KB_PER_TILE + k_block]));
1950+
scalar_t scale = Ops::from_float(decode_e4m4_absmax_branchless(abs_ptr[col * KB_PER_TILE + k_block]));
19331951

19341952
const int bit_offset = half_idx * 16;
19351953
const int rows[4] = {2 * tid, 2 * tid + 1, 2 * tid + 8, 2 * tid + 9};
1936-
scalar_t vals[4];
1937-
#pragma unroll
1938-
for (int r = 0; r < 4; r++) {
1939-
int bit_pos = bit_offset + rows[r];
1940-
int idx = 0;
1954+
1955+
// Dequantize 4 elements with interleaved bit extraction.
1956+
// Extract all bit values independently first (no serial
1957+
// dependency chain), then combine per-element.
1958+
int bp0 = bit_offset + rows[0];
1959+
int bp1 = bit_offset + rows[1];
1960+
int bp2 = bit_offset + rows[2];
1961+
int bp3 = bit_offset + rows[3];
1962+
1963+
// All 4*K_BITS extractions are independent — compiler
1964+
// can issue them in any order across ALU pipelines.
1965+
int idx0 = 0, idx1 = 0, idx2 = 0, idx3 = 0;
19411966
#pragma unroll
1942-
for (int b = 0; b < K_BITS; b++)
1943-
idx |= ((planes[b] >> bit_pos) & 1) << b;
1944-
vals[r] = Ops::mul(__shfl_sync(0xFFFFFFFF, cb_val, idx), scale);
1967+
for (int b = 0; b < K_BITS; b++) {
1968+
unsigned int p = planes[b];
1969+
idx0 |= ((p >> bp0) & 1) << b;
1970+
idx1 |= ((p >> bp1) & 1) << b;
1971+
idx2 |= ((p >> bp2) & 1) << b;
1972+
idx3 |= ((p >> bp3) & 1) << b;
19451973
}
19461974

1975+
scalar_t vals[4];
1976+
vals[0] = Ops::mul(__shfl_sync(0xFFFFFFFF, cb_val, idx0), scale);
1977+
vals[1] = Ops::mul(__shfl_sync(0xFFFFFFFF, cb_val, idx1), scale);
1978+
vals[2] = Ops::mul(__shfl_sync(0xFFFFFFFF, cb_val, idx2), scale);
1979+
vals[3] = Ops::mul(__shfl_sync(0xFFFFFFFF, cb_val, idx3), scale);
1980+
19471981
uint32_t frag_b[2];
19481982
frag_b[0] = pack_two<scalar_t>(vals[0], vals[1]);
19491983
frag_b[1] = pack_two<scalar_t>(vals[2], vals[3]);
@@ -2060,12 +2094,27 @@ static void kbitGemmProdLaunch(
20602094
int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
20612095
int mn_tiles = m_tiles * n_tiles;
20622096

2063-
// Auto-select k_splits only for severe SM underutilization (< 25%).
2064-
// The atomicAdd + workspace overhead of k_splits > 1 is significant,
2065-
// so only use it when the utilization gain clearly outweighs the cost.
2097+
// Two-tier k_splits heuristic:
2098+
//
2099+
// Tier 1: Severe underutilization (< 25% of SMs active).
2100+
// Even with L2-cached data, having 75%+ SMs idle wastes parallelism.
2101+
// Split aggressively to fill SMs.
2102+
//
2103+
// Tier 2: Moderate underutilization with DRAM-bound data.
2104+
// When data exceeds L2 cache, more SMs generate more DRAM requests.
2105+
// Split conservatively (k_splits <= 2) to avoid atomicAdd overhead.
2106+
long long b_data_bytes = (long long)N * (K_dim / BS) * K * sizeof(unsigned int)
2107+
+ (long long)N * (K_dim / BS); // packed + absmax
2108+
constexpr long long DRAM_THRESHOLD = 24LL * 1024 * 1024; // 24 MB
2109+
20662110
int k_splits = 1;
20672111
if (mn_tiles < num_sms / 4 && k_tiles > 1) {
2112+
// Tier 1: severe underutil — split aggressively
2113+
k_splits = min(k_tiles, (num_sms + mn_tiles - 1) / mn_tiles);
2114+
} else if (mn_tiles < num_sms && k_tiles > 1 && b_data_bytes > DRAM_THRESHOLD) {
2115+
// Tier 2: DRAM-bound with moderate underutil — split conservatively
20682116
k_splits = min(k_tiles, (num_sms + mn_tiles - 1) / mn_tiles);
2117+
k_splits = min(k_splits, 2);
20692118
}
20702119

20712120
int total_work = mn_tiles * k_splits;

0 commit comments

Comments
 (0)