Skip to content

Commit b0f1d30

Browse files
TimDettmers and claude
committed
Make batched MoE GEMM CUDA-graph-safe with device-side alpha
- Persist Gemm object in MoeGemmState (avoids stack-local params_ destruction) - Move gemm.initialize() to _init (triggers cudaFuncSetAttribute once) - _run rebuilds params from arguments then calls gemm.run() (graph-safe) - Change alpha from host float to device pointer (const float*); CUTLASS epilogue reads via alpha_ptr for zero host-GPU sync - Update op definition, registered kernel, functional API, and benchmark Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 830491d commit b0f1d30

File tree

6 files changed

+49
-22
lines changed

6 files changed

+49
-22
lines changed

benchmarks/bench_moe_gemm_sm100.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,14 @@ def bench_batched_nvfp4(lib, max_M, N, K, num_experts):
121121

122122
lib.cgemm_nvfp4_moe_sm100_run.restype = ct.c_int
123123

124+
alpha_dev = torch.tensor([1.0], dtype=torch.float32, device=device)
125+
124126
def run_kernel():
125127
lib.cgemm_nvfp4_moe_sm100_run(
126128
get_ptr(A_batched), get_ptr(B_batched),
127129
get_ptr(SFA), get_ptr(SFB),
128130
get_ptr(D_out),
129-
ct.c_float(1.0), stream_ptr)
131+
get_ptr(alpha_dev), stream_ptr)
130132

131133
# Warmup
132134
for _ in range(WARMUP):

bitsandbytes/_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def _(
651651
torch.library.define(
652652
"bitsandbytes::gemm_nvfp4_moe",
653653
"(Tensor A_batched, Tensor B_batched, Tensor SFA, Tensor SFB, "
654-
"float alpha, int max_M, int N, int K, int num_experts) -> Tensor",
654+
"Tensor alpha, int max_M, int N, int K, int num_experts) -> Tensor",
655655
)
656656

657657

@@ -661,7 +661,7 @@ def _(
661661
B_batched: torch.Tensor,
662662
SFA: torch.Tensor,
663663
SFB: torch.Tensor,
664-
alpha: float,
664+
alpha: torch.Tensor,
665665
max_M: int,
666666
N: int,
667667
K: int,

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,7 +1352,7 @@ def _(
13521352
B_batched: torch.Tensor,
13531353
SFA: torch.Tensor,
13541354
SFB: torch.Tensor,
1355-
alpha: float,
1355+
alpha: torch.Tensor,
13561356
max_M: int,
13571357
N: int,
13581358
K: int,
@@ -1377,13 +1377,16 @@ def _(
13771377

13781378
_moe_batched_cache = {"key": key, "workspace": workspace}
13791379

1380+
# Ensure alpha is a float32 device tensor
1381+
alpha_dev = alpha.to(dtype=torch.float32, device=A_batched.device).contiguous()
1382+
13801383
D_out = torch.empty(num_experts * max_M * N, dtype=torch.bfloat16, device=A_batched.device)
13811384

13821385
ret = lib.cgemm_nvfp4_moe_sm100_run(
13831386
get_ptr(A_batched), get_ptr(B_batched),
13841387
get_ptr(SFA), get_ptr(SFB),
13851388
get_ptr(D_out),
1386-
ct.c_float(alpha),
1389+
get_ptr(alpha_dev),
13871390
ct.c_void_p(_get_tensor_stream(A_batched)),
13881391
)
13891392
if ret != 0:

bitsandbytes/functional.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,10 +1543,9 @@ def gemm_nvfp4_grouped(
15431543
def gemm_nvfp4_moe(
15441544
A_batched: torch.Tensor,
15451545
SFA_batched: torch.Tensor,
1546-
A_tensor_scale: float,
1546+
alpha: torch.Tensor,
15471547
B_batched: torch.Tensor,
15481548
SFB_batched: torch.Tensor,
1549-
B_tensor_scale: float,
15501549
max_M: int,
15511550
N: int,
15521551
K: int,
@@ -1561,20 +1560,18 @@ def gemm_nvfp4_moe(
15611560
Args:
15621561
A_batched: Packed FP4 activations, batched (num_experts * max_M * K // 2,).
15631562
SFA_batched: Per-expert swizzled activation scales (concatenated).
1564-
A_tensor_scale: Shared tensor scale for activations.
1563+
alpha: Device tensor (float32, 0-dim or 1-element) = act_scale * weight_scale.
15651564
B_batched: Packed FP4 weights, batched (num_experts * N * K // 2,).
15661565
SFB_batched: Per-expert swizzled weight scales (concatenated).
1567-
B_tensor_scale: Shared tensor scale for weights.
15681566
max_M: Max tokens per expert (all experts padded to this).
15691567
N: Output dimension per expert.
15701568
K: Input dimension per expert.
15711569
num_experts: Number of experts (batch dimension L).
15721570
15731571
Returns:
15741572
Output tensor (num_experts, max_M, N) in bfloat16 with tensor scales
1575-
applied via the CUTLASS epilogue alpha.
1573+
applied via the CUTLASS epilogue alpha (device-side).
15761574
"""
1577-
alpha = A_tensor_scale * B_tensor_scale
15781575
return torch.ops.bitsandbytes.gemm_nvfp4_moe(
15791576
A_batched, B_batched, SFA_batched, SFB_batched,
15801577
alpha, max_M, N, K, num_experts,

bitsandbytes/nn/modules.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -950,10 +950,14 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
950950
A_batched = torch.cat(all_act_packed)
951951
SFA_batched = torch.cat(all_act_scales)
952952

953-
# Run batched GEMM
953+
# Run batched GEMM (alpha is a device tensor for graph safety)
954+
alpha_dev = torch.tensor(
955+
[act_tensor_scale * self.weight_tensor_scale],
956+
dtype=torch.float32, device=x.device,
957+
)
954958
D = gemm_nvfp4_moe(
955-
A_batched, SFA_batched, act_tensor_scale,
956-
self.weight_packed, self.weight_scales_batched, self.weight_tensor_scale,
959+
A_batched, SFA_batched, alpha_dev,
960+
self.weight_packed, self.weight_scales_batched,
957961
max_M, N, K, num_experts,
958962
)
959963

csrc/qutlass/gemm_nvfp4_moe_sm100.cu

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ struct MoeGemmState {
151151
// Workspace
152152
void* workspace_dev = nullptr;
153153
size_t workspace_size = 0;
154+
155+
// Persistent GEMM object: avoids stack allocation per call, keeps
156+
// params_ alive for CUDA graph replay. init() triggers the one-time
157+
// cudaFuncSetAttribute call; run() reuses the object.
158+
Gemm gemm;
154159
};
155160

156161
static MoeGemmState s_state;
@@ -267,13 +272,20 @@ extern "C" int cgemm_nvfp4_moe_sm100_init(
267272
arguments.epilogue.thread.alpha = 1.0f;
268273
arguments.epilogue.thread.beta = 0.0f;
269274

270-
Gemm gemm;
271-
auto status = gemm.can_implement(arguments);
275+
auto status = st.gemm.can_implement(arguments);
272276
if (status != cutlass::Status::kSuccess) {
273277
fprintf(stderr, "MoE GEMM can_implement failed: %d\n", (int)status);
274278
return -1;
275279
}
276280

281+
// Initialize the persistent Gemm object: triggers cudaFuncSetAttribute
282+
// (one-time, not graph-safe) and fills internal params_ with dummy pointers.
283+
status = st.gemm.initialize(arguments, st.workspace_dev, stream);
284+
if (status != cutlass::Status::kSuccess) {
285+
fprintf(stderr, "MoE GEMM initial initialize failed: %d\n", (int)status);
286+
return -2;
287+
}
288+
277289
st.initialized = true;
278290
return 0;
279291

@@ -324,13 +336,17 @@ extern "C" size_t cgemm_nvfp4_moe_sm100_workspace_size(
324336
// SFA_dev: activation scale factors (batched swizzled layout)
325337
// SFB_dev: weight scale factors (batched swizzled layout)
326338
// D_dev: output (num_experts, max_M, N_output) BF16, row-major per expert
339+
// alpha_dev: device pointer to float alpha (= act_scale * weight_scale)
340+
//
341+
// Graph-safe: only host-side param building + kernel launch.
342+
// cudaFuncSetAttribute was already called during _init.
327343
extern "C" int cgemm_nvfp4_moe_sm100_run(
328344
const void* A_dev, // activations (packed FP4)
329345
const void* B_dev, // weights (packed FP4)
330346
const void* SFA_dev, // activation scale factors
331347
const void* SFB_dev, // weight scale factors
332348
void* D_dev, // output (BF16)
333-
float alpha,
349+
const float* alpha_dev, // device pointer to alpha scalar
334350
cudaStream_t stream
335351
) {
336352
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
@@ -363,18 +379,23 @@ extern "C" int cgemm_nvfp4_moe_sm100_run(
363379
static_cast<ElementD*>(D_dev), st.stride_D},
364380
st.hw_info
365381
};
366-
arguments.epilogue.thread.alpha = alpha;
382+
// Device-side alpha: if alpha_dev is non-null, kernel reads from device ptr.
383+
// alpha_ptr takes precedence over the scalar alpha value.
384+
arguments.epilogue.thread.alpha = 1.0f; // fallback (ignored when alpha_ptr set)
385+
arguments.epilogue.thread.alpha_ptr = alpha_dev;
367386
arguments.epilogue.thread.beta = 0.0f;
368387

369-
Gemm gemm;
370-
371-
auto status = gemm.initialize(arguments, st.workspace_dev, stream);
388+
// Rebuild params from arguments (host-side only, no CUDA API calls).
389+
// cudaFuncSetAttribute was already called during _init on the persistent
390+
// gemm object, so we call initialize() which is idempotent for the
391+
// attribute and only updates params_.
392+
auto status = st.gemm.initialize(arguments, st.workspace_dev, stream);
372393
if (status != cutlass::Status::kSuccess) {
373394
fprintf(stderr, "MoE GEMM initialize failed: %d\n", (int)status);
374395
return -2;
375396
}
376397

377-
status = gemm.run(stream);
398+
status = st.gemm.run(stream);
378399
if (status != cutlass::Status::kSuccess) {
379400
fprintf(stderr, "MoE GEMM run failed: %d\n", (int)status);
380401
return -3;

0 commit comments

Comments (0)