@@ -936,15 +936,25 @@ def _get_rotation_matrix(device: torch.device) -> torch.Tensor:
936936 return _rotation_matrices [device ]
937937
938938
939- @register_kernel ("bitsandbytes::cutlass_fused_quantize_nvfp4" , "cuda" )
940- def _ (
939+ def _fused_quantize_nvfp4_impl (
941940 A : torch .Tensor ,
942941 tensor_scale : float ,
942+ packed_out : Optional [torch .Tensor ] = None ,
943+ scales_out : Optional [torch .Tensor ] = None ,
944+ global_scale_buf : Optional [torch .Tensor ] = None ,
943945) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
944- """CUTLASS-based fused quantize with randomized Hadamard rotation.
945-
946- The CUTLASS kernel requires M to be a multiple of 128. We pad here
947- and trim the output to maintain a transparent API.
946+ """Core CUTLASS fused quantize implementation.
947+
948+ When output buffers are provided, no allocations occur — safe for CUDA
949+ graph capture. When None, buffers are allocated (convenient but not
950+ graph-safe).
951+
952+ Args:
953+ A: BF16 input, numel must be divisible by 16.
954+ tensor_scale: Global tensor scale.
955+ packed_out: Pre-allocated uint8 output (padded_M * 8 bytes). None to allocate.
956+ scales_out: Pre-allocated uint8 scales (padded_M bytes). None to allocate.
957+ global_scale_buf: Pre-allocated float32 scalar buffer. None to allocate.
948958 """
949959 A = A .contiguous ()
950960 n = A .numel ()
@@ -954,10 +964,8 @@ def _(
954964 lambda : f"CUTLASS fused quantize requires bfloat16, got { A .dtype } " ,
955965 )
956966
957- # Reshape to 2D: (M, K) where K is the last dimension
958- # The fused quantize GEMM treats each group of 16 elements as one "row"
959- K = 16 # NVFP4 group size = GEMM K dimension
960- N = 16 # B matrix is 16x16
967+ K = 16
968+ N = 16
961969 orig_M = n // K
962970 padded_M = ((orig_M + 127 ) // 128 ) * 128
963971
@@ -970,26 +978,26 @@ def _(
970978 else :
971979 A_flat = A
972980
973- # Compute global_scale = 1/tensor_scale (QuTLASS convention)
974- global_scale = torch .tensor (
975- [1.0 / tensor_scale if tensor_scale > 0 else 0.0 ],
976- dtype = torch .float32 ,
977- device = A .device ,
978- )
979-
980- # Allocate output buffers (padded size)
981- packed_padded = torch .zeros (padded_M * K // 2 , dtype = torch .uint8 , device = A .device )
981+ # Use pre-allocated buffers or allocate new ones
982+ if global_scale_buf is not None :
983+ global_scale_buf .fill_ (1.0 / tensor_scale if tensor_scale > 0 else 0.0 )
984+ global_scale = global_scale_buf
985+ else :
986+ global_scale = torch .tensor (
987+ [1.0 / tensor_scale if tensor_scale > 0 else 0.0 ],
988+ dtype = torch .float32 ,
989+ device = A .device ,
990+ )
982991
983- # Scale output: one E4M3 scale per 16-element block = padded_M scales
984- # QuTLASS outputs as (padded_M, 1) but we flatten
985- scales_padded = torch .zeros (padded_M , dtype = torch .uint8 , device = A .device )
992+ packed_padded = (
993+ packed_out if packed_out is not None else torch .zeros (padded_M * K // 2 , dtype = torch .uint8 , device = A .device )
994+ )
995+ scales_padded = scales_out if scales_out is not None else torch .zeros (padded_M , dtype = torch .uint8 , device = A .device )
986996
987- # Get the cached randomized Hadamard rotation matrix for this device
988997 B = _get_rotation_matrix (A .device )
989998
990999 with _cuda_device_of (A ):
991- fn = lib .cfused_quantize_nvfp4_absmax
992- fn (
1000+ lib .cfused_quantize_nvfp4_absmax (
9931001 get_ptr (A_flat ),
9941002 get_ptr (B ),
9951003 get_ptr (packed_padded ),
@@ -1001,14 +1009,22 @@ def _(
10011009 _get_tensor_stream (A ),
10021010 )
10031011
1004- # Trim to original size
10051012 packed = packed_padded [: orig_M * K // 2 ] if padded_M != orig_M else packed_padded
10061013 block_scales = scales_padded [:orig_M ] if padded_M != orig_M else scales_padded
10071014
10081015 ts_out = torch .tensor ([tensor_scale ], dtype = torch .float32 , device = A .device )
10091016 return packed , block_scales , ts_out
10101017
10111018
@register_kernel("bitsandbytes::cutlass_fused_quantize_nvfp4", "cuda")
def _(
    A: torch.Tensor,
    tensor_scale: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUTLASS fused quantize with randomized Hadamard rotation.

    Thin op-registration shim: delegates to ``_fused_quantize_nvfp4_impl``
    in allocating mode (no pre-allocated output buffers supplied), which is
    the convenient — but not CUDA-graph-safe — entry point.
    """
    packed, block_scales, ts_out = _fused_quantize_nvfp4_impl(A, tensor_scale)
    return packed, block_scales, ts_out
1026+
1027+
10121028# Scale reordering for CUTLASS block-scaled GEMM
10131029@register_kernel ("bitsandbytes::scale_to_blocked" , "cuda" )
10141030def _ (scales : torch .Tensor , H : int , W : int ) -> torch .Tensor :
@@ -1038,8 +1054,7 @@ def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
10381054# quantization time by scale_to_blocked). Tensor scales are folded into
10391055# the CUTLASS epilogue alpha. Output is BF16, converted to FP32 for
10401056# API compatibility.
1041- @register_kernel ("bitsandbytes::gemm_nvfp4" , "cuda" )
1042- def _ (
1057+ def _gemm_nvfp4_impl (
10431058 A_packed : torch .Tensor ,
10441059 B_packed : torch .Tensor ,
10451060 A_scales : torch .Tensor ,
@@ -1049,12 +1064,27 @@ def _(
10491064 M : int ,
10501065 N : int ,
10511066 K : int ,
1067+ D_out : Optional [torch .Tensor ] = None ,
1068+ alpha_buf : Optional [torch .Tensor ] = None ,
10521069) -> torch .Tensor :
1070+ """Core NVFP4 GEMM implementation.
1071+
1072+ When D_out and alpha_buf are provided, no allocations occur — safe for
1073+ CUDA graph capture. When None, buffers are allocated.
1074+
1075+ Args:
1076+ D_out: Pre-allocated BF16 output (M, N). None to allocate.
1077+ alpha_buf: Pre-allocated float32 scalar buffer. None to allocate.
1078+ """
10531079 with _cuda_device_of (A_packed ):
1054- # A_scales and B_scales are already in CUTLASS block-scaled layout
1055- # (pre-computed at quantization time by scale_to_blocked)
1056- alpha = torch .tensor ([A_tensor_scale * B_tensor_scale ], dtype = torch .float32 , device = A_packed .device )
1057- D_out = torch .empty (M , N , dtype = torch .bfloat16 , device = A_packed .device )
1080+ if alpha_buf is not None :
1081+ alpha_buf .fill_ (A_tensor_scale * B_tensor_scale )
1082+ alpha = alpha_buf
1083+ else :
1084+ alpha = torch .tensor ([A_tensor_scale * B_tensor_scale ], dtype = torch .float32 , device = A_packed .device )
1085+
1086+ if D_out is None :
1087+ D_out = torch .empty (M , N , dtype = torch .bfloat16 , device = A_packed .device )
10581088
10591089 lib .cgemm_nvfp4_cutlass (
10601090 get_ptr (A_packed ),
@@ -1070,3 +1100,29 @@ def _(
10701100 )
10711101
10721102 return D_out .float ()
1103+
1104+
@register_kernel("bitsandbytes::gemm_nvfp4", "cuda")
def _(
    A_packed: torch.Tensor,
    B_packed: torch.Tensor,
    A_scales: torch.Tensor,
    B_scales: torch.Tensor,
    A_tensor_scale: float,
    B_tensor_scale: float,
    M: int,
    N: int,
    K: int,
) -> torch.Tensor:
    """NVFP4 GEMM: A @ B^T with block-scaled FP4 inputs.

    Op-registration shim over ``_gemm_nvfp4_impl`` with no pre-allocated
    output/alpha buffers, i.e. the allocating (non-graph-capture) path.
    """
    return _gemm_nvfp4_impl(
        A_packed=A_packed,
        B_packed=B_packed,
        A_scales=A_scales,
        B_scales=B_scales,
        A_tensor_scale=A_tensor_scale,
        B_tensor_scale=B_tensor_scale,
        M=M,
        N=N,
        K=K,
    )
0 commit comments