Skip to content

Commit 2fe1032

Browse files
committed
Add debug script for CUDA graph capture
1 parent 945b583 commit 2fe1032

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

bench_debug.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Debug CUDA graph capture for NVFP4 GEMM."""
2+
import ctypes as ct
3+
import torch
4+
5+
def get_ptr(t):
    """Wrap a tensor's raw data pointer in a ctypes ``c_void_p``."""
    addr = t.data_ptr()
    return ct.c_void_p(addr)
def main():
    """Debug harness for the NVFP4 MoE GEMM kernel.

    Runs the kernel eagerly, then under CUDA graph capture/replay, and
    finally times 100 graph replays.  Requires a CUDA device and a
    bitsandbytes build that exports the ``cgemm_nvfp4_moe_sm100_*``
    symbols; all sizing/init/run calls go through ctypes into that
    shared library.

    NOTE(review): indentation was reconstructed from a mangled paste;
    statement nesting (especially around the capture block) should be
    confirmed against the original file.
    """
    # Imported lazily so the script fails here (not at import time) when
    # bitsandbytes or its native extension is unavailable.
    from bitsandbytes.cextension import lib

    device = torch.device("cuda")
    gpu = torch.cuda.get_device_name(0)
    cap = torch.cuda.get_device_capability(0)
    print(f"GPU: {gpu} (SM {cap[0]}.{cap[1]})")

    # Problem shape: 8 experts, up to 128 rows each, N x K per expert.
    # half_K: presumably two 4-bit (FP4) values are packed per byte,
    # so K elements occupy K // 2 bytes — TODO confirm against kernel.
    num_experts, max_M, N, K = 8, 128, 13696, 4096
    half_K = K // 2

    # Random byte payloads standing in for packed activations (A) and
    # weights (B); real values don't matter for a capture/perf test.
    A_bat = torch.randint(0, 255, (num_experts * max_M * half_K,),
                          dtype=torch.uint8, device=device)
    B_all = torch.randint(0, 255, (num_experts * N * half_K,),
                          dtype=torch.uint8, device=device)

    # Scale-factor buffer sizes are computed by the library; declare the
    # ctypes return types before calling (default would be c_int).
    lib.cgemm_nvfp4_moe_sm100_sfa_size.restype = ct.c_size_t
    lib.cgemm_nvfp4_moe_sm100_sfb_size.restype = ct.c_size_t
    sfa_bytes = lib.cgemm_nvfp4_moe_sm100_sfa_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts))
    sfb_bytes = lib.cgemm_nvfp4_moe_sm100_sfb_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts))
    # max(..., 1) guards against a zero-byte allocation if the library
    # reports size 0.
    SFA = torch.randint(0, 255, (max(sfa_bytes, 1),), dtype=torch.uint8, device=device)
    SFB = torch.randint(0, 255, (max(sfb_bytes, 1),), dtype=torch.uint8, device=device)

    # Output buffer (rows for all experts stacked) and scalar alpha.
    D_out = torch.empty(num_experts * max_M, N, dtype=torch.bfloat16, device=device)
    alpha = torch.tensor([1.0], dtype=torch.float32, device=device)

    # Library-managed scratch workspace, sized by the library.
    lib.cgemm_nvfp4_moe_sm100_workspace_size.restype = ct.c_size_t
    ws_size = lib.cgemm_nvfp4_moe_sm100_workspace_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts))
    workspace = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=device)

    # Pass the current CUDA stream down to the C API as a raw pointer.
    stream = torch.cuda.current_stream()
    stream_ptr = ct.c_void_p(stream.cuda_stream)

    print(f"Stream ptr: {stream.cuda_stream}")
    print(f"Workspace size: {ws_size}")

    # Init: hand the library every buffer pointer once; subsequent
    # ..._run calls presumably reuse this state — TODO confirm API contract.
    lib.cgemm_nvfp4_moe_sm100_init.restype = ct.c_int
    ret = lib.cgemm_nvfp4_moe_sm100_init(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
        get_ptr(A_bat), get_ptr(B_all),
        get_ptr(SFA), get_ptr(SFB),
        get_ptr(D_out), get_ptr(alpha),
        get_ptr(workspace), ct.c_size_t(ws_size), stream_ptr,
    )
    print(f"Init returned: {ret}")
    if ret != 0:
        # Nonzero return is treated as failure; nothing to clean up.
        print("Init failed!")
        return

    # Warmup eager run (3 launches outside any graph).
    lib.cgemm_nvfp4_moe_sm100_run.restype = ct.c_int
    for i in range(3):
        ret = lib.cgemm_nvfp4_moe_sm100_run(stream_ptr)
        print(f"Eager run {i}: {ret}")
    torch.cuda.synchronize()
    print("Eager runs OK")

    # Test 1: BF16 bmm graph (sanity check graph capture works at all
    # on this setup, independent of the custom kernel).
    A_bf = torch.randn(8, 128, 4096, dtype=torch.bfloat16, device=device)
    B_bf = torch.randn(8, 4096, 13696, dtype=torch.bfloat16, device=device)
    C_bf = torch.empty(8, 128, 13696, dtype=torch.bfloat16, device=device)
    torch.bmm(A_bf, B_bf, out=C_bf)  # warm up kernels before capture
    torch.cuda.synchronize()

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1):
        torch.bmm(A_bf, B_bf, out=C_bf)
    g1.replay()
    torch.cuda.synchronize()
    print("BF16 graph capture: OK")

    # Test 2: NVFP4 GEMM graph capture
    print("Attempting NVFP4 graph capture...")

    # The graph capture stream — check what stream it uses.  torch uses
    # a dedicated side stream during capture, so the library must launch
    # on the *capture* stream, not the one passed to init.
    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g2):
        cap_stream = torch.cuda.current_stream()
        cap_stream_ptr = ct.c_void_p(cap_stream.cuda_stream)
        print(f" Capture stream ptr: {cap_stream.cuda_stream}")
        ret = lib.cgemm_nvfp4_moe_sm100_run(cap_stream_ptr)
    # Note: ret might not be meaningful during capture
    print(f" Graph capture complete, run returned: {ret}")

    # Replay
    print("Replaying NVFP4 graph...")
    g2.replay()
    torch.cuda.synchronize()
    print("NVFP4 graph replay: OK")

    # Timing: average of 100 graph replays via CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        g2.replay()
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / 100
    # 2*M*N*K multiply-adds per expert; elapsed_time is in ms.
    flops = 2 * num_experts * max_M * N * K
    tflops = flops / (ms * 1e-3) / 1e12
    print(f"NVFP4 GEMM graph: {ms:.3f} ms, {tflops:.1f} TFLOPS")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)