Skip to content

Commit 201e561

Browse files
TimDettmers and claude
committed
fix: Remove merge artifacts that broke build (duplicate defs, stale instantiations)
After merging feature/qutlass-nvfp4-gemm into QLORA-2 (c25d7af), three files had duplicate/stale code that prevented compilation: - _ops.py: duplicate torch.library.define for dequantize_nvfp4, cutlass_fused_quantize_nvfp4, scale_to_blocked, gemm_nvfp4 - ops.cu: stale template instantiation block with wrong signatures (missing cudaStream_t, wrong absmax template params), duplicate testMMA, dead kbitGroupedScalarGemv launcher (kernel removed in ac7d6ff) - pythonInterface.cpp: missing BUILD_CUDA guard on training kernel bindings (half/__nv_bfloat16 unavailable to host compiler), duplicate cquantize/cdequantize_blockwise wrappers Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1763839 commit 201e561

File tree

3 files changed

+3
-329
lines changed

3 files changed

+3
-329
lines changed

bitsandbytes/_ops.py

Lines changed: 0 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -1783,20 +1783,6 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
17831783
return packed, block_scales, ts_out
17841784

17851785

1786-
# NVFP4 dequantization
1787-
torch.library.define(
1788-
"bitsandbytes::dequantize_nvfp4",
1789-
"(Tensor packed, Tensor block_scales, float tensor_scale, int numel, ScalarType dtype) -> Tensor",
1790-
)
1791-
1792-
1793-
@register_fake("bitsandbytes::dequantize_nvfp4")
1794-
def _(
1795-
packed: torch.Tensor, block_scales: torch.Tensor, tensor_scale: float, numel: int, dtype: torch.dtype
1796-
) -> torch.Tensor:
1797-
return torch.empty(numel, dtype=dtype, device=packed.device)
1798-
1799-
18001786
# NVFP4 Hadamard rotation (in-place)
18011787
torch.library.define(
18021788
"bitsandbytes::hadamard_rotate_nvfp4",
@@ -1825,66 +1811,3 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
18251811
block_scales = torch.empty(n // 16, dtype=torch.uint8, device=A.device)
18261812
ts_out = torch.empty(1, dtype=torch.float32, device=A.device)
18271813
return packed, block_scales, ts_out
1828-
1829-
1830-
# CUTLASS-based fused quantize for NVFP4 (SM_120+)
1831-
# Uses QuTLASS GEMM-as-quantize approach with always-on randomized Hadamard
1832-
# rotation. The rotation is free (baked into the GEMM B operand) and improves
1833-
# quantization quality by spreading outliers across blocks.
1834-
torch.library.define(
1835-
"bitsandbytes::cutlass_fused_quantize_nvfp4",
1836-
"(Tensor A, float tensor_scale) -> (Tensor, Tensor, Tensor)",
1837-
)
1838-
1839-
1840-
@register_fake("bitsandbytes::cutlass_fused_quantize_nvfp4")
1841-
def _(
1842-
A: torch.Tensor,
1843-
tensor_scale: float,
1844-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1845-
n = A.numel()
1846-
torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
1847-
packed = torch.empty(n // 2, dtype=torch.uint8, device=A.device)
1848-
block_scales = torch.empty(n // 16, dtype=torch.uint8, device=A.device)
1849-
ts_out = torch.empty(1, dtype=torch.float32, device=A.device)
1850-
return packed, block_scales, ts_out
1851-
1852-
1853-
# Scale reordering for CUTLASS block-scaled GEMM
1854-
torch.library.define(
1855-
"bitsandbytes::scale_to_blocked",
1856-
"(Tensor scales, int H, int W) -> Tensor",
1857-
)
1858-
1859-
1860-
@register_fake("bitsandbytes::scale_to_blocked")
1861-
def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
1862-
n_row_blocks = (H + 127) // 128
1863-
n_col_blocks = (W + 3) // 4
1864-
out_size = n_row_blocks * n_col_blocks * 128 * 4
1865-
return torch.empty(out_size, dtype=torch.uint8, device=scales.device)
1866-
1867-
1868-
# NVFP4 GEMM (A @ B^T with block-scaled FP4 inputs)
1869-
torch.library.define(
1870-
"bitsandbytes::gemm_nvfp4",
1871-
"(Tensor A_packed, Tensor B_packed, Tensor A_scales, Tensor B_scales, "
1872-
"float A_tensor_scale, float B_tensor_scale, int M, int N, int K) -> Tensor",
1873-
)
1874-
1875-
1876-
@register_fake("bitsandbytes::gemm_nvfp4")
1877-
def _(
1878-
A_packed: torch.Tensor,
1879-
B_packed: torch.Tensor,
1880-
A_scales: torch.Tensor,
1881-
B_scales: torch.Tensor,
1882-
A_tensor_scale: float,
1883-
B_tensor_scale: float,
1884-
M: int,
1885-
N: int,
1886-
K: int,
1887-
) -> torch.Tensor:
1888-
torch._check_is_size(M)
1889-
torch._check_is_size(N)
1890-
torch._check_is_size(K)

csrc/ops.cu

Lines changed: 2 additions & 216 deletions
Original file line number | Diff line number | Diff line change
@@ -5761,222 +5761,8 @@ INSTANTIATE_VQ_SCALAR_GEMV_F32(3, 8)
57615761
INSTANTIATE_VQ_SCALAR_GEMV_F32(3, 10)
57625762
INSTANTIATE_VQ_SCALAR_GEMV_F32(4, 8)
57635763

5764-
// ============================================================================
5765-
// Training Kernels (from QLORA-2 branch)
5766-
// ============================================================================
5767-
5768-
}
5769-
}
5770-
5771-
// ---- Grouped scalar GEMV launcher ----
5772-
template <int K, typename scalar_t>
5773-
void kbitGroupedScalarGemv(
5774-
const scalar_t* A_concat, const unsigned int* B_packed_all, const unsigned char* B_absmax_all,
5775-
const float* codebook, scalar_t* C_concat, const int* expert_offsets, int K_dim, int N, int num_experts
5776-
) {
5777-
constexpr int COLS_PER_BLOCK = 4;
5778-
constexpr int BLOCK_SIZE = 128;
5779-
int n_groups = (N + COLS_PER_BLOCK - 1) / COLS_PER_BLOCK;
5780-
dim3 grid(n_groups, num_experts);
5781-
5782-
kbit_grouped_scalar_gemv<K, 4, scalar_t><<<grid, BLOCK_SIZE>>>(
5783-
A_concat, B_packed_all, B_absmax_all, codebook, C_concat, expert_offsets, K_dim, N, num_experts
5784-
);
5785-
CUDA_CHECK_RETURN(cudaPeekAtLastError());
5786-
}
5787-
5788-
// ---- Debug: Simple MMA test kernel ----
5789-
// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
5790-
__global__ void test_mma_kernel(const half* __restrict__ A, const half* __restrict__ B, float* __restrict__ C) {
5791-
int lane_id = threadIdx.x % 32;
5792-
int gid = lane_id / 4;
5793-
int tid = lane_id % 4;
5794-
5795-
// Load A fragment: A is [16,16] row-major
5796-
// m16n8k16 register order (from Turing m16n8k8 decomposition):
5797-
// a[0]: row_lo (gid), k_lo (tid*2..tid*2+1)
5798-
// a[1]: row_hi (gid+8), k_lo (tid*2..tid*2+1)
5799-
// a[2]: row_lo (gid), k_hi (tid*2+8..tid*2+9)
5800-
// a[3]: row_hi (gid+8), k_hi (tid*2+8..tid*2+9)
5801-
uint32_t frag_a[4];
5802-
{
5803-
half2 h_rlo_klo = __halves2half2(A[gid * 16 + tid * 2], A[gid * 16 + tid * 2 + 1]);
5804-
half2 h_rhi_klo = __halves2half2(A[(gid + 8) * 16 + tid * 2], A[(gid + 8) * 16 + tid * 2 + 1]);
5805-
half2 h_rlo_khi = __halves2half2(A[gid * 16 + tid * 2 + 8], A[gid * 16 + tid * 2 + 9]);
5806-
half2 h_rhi_khi = __halves2half2(A[(gid + 8) * 16 + tid * 2 + 8], A[(gid + 8) * 16 + tid * 2 + 9]);
5807-
frag_a[0] = *reinterpret_cast<uint32_t*>(&h_rlo_klo);
5808-
frag_a[1] = *reinterpret_cast<uint32_t*>(&h_rhi_klo);
5809-
frag_a[2] = *reinterpret_cast<uint32_t*>(&h_rlo_khi);
5810-
frag_a[3] = *reinterpret_cast<uint32_t*>(&h_rhi_khi);
5811-
}
5812-
5813-
// Load B fragment: B is [16,8] row-major. MMA B is col-major, so B_col[k,n] = B_row[k,n].
5814-
uint32_t frag_b[2];
5815-
{
5816-
half2 b0 = __halves2half2(B[(tid * 2) * 8 + gid], B[(tid * 2 + 1) * 8 + gid]);
5817-
half2 b1 = __halves2half2(B[(tid * 2 + 8) * 8 + gid], B[(tid * 2 + 9) * 8 + gid]);
5818-
frag_b[0] = *reinterpret_cast<uint32_t*>(&b0);
5819-
frag_b[1] = *reinterpret_cast<uint32_t*>(&b1);
5820-
}
5821-
5822-
float c[4] = {0, 0, 0, 0};
5823-
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
5824-
"{%0, %1, %2, %3}, "
5825-
"{%4, %5, %6, %7}, "
5826-
"{%8, %9}, "
5827-
"{%10, %11, %12, %13};\n"
5828-
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
5829-
: "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]), "r"(frag_b[0]), "r"(frag_b[1]),
5830-
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
5831-
5832-
// Write C[16,8] row-major
5833-
C[gid * 8 + tid * 2] = c[0];
5834-
C[gid * 8 + tid * 2 + 1] = c[1];
5835-
C[(gid + 8) * 8 + tid * 2] = c[2];
5836-
C[(gid + 8) * 8 + tid * 2 + 1] = c[3];
5837-
}
5838-
5839-
void testMMA(const half* A, const half* B, float* C) {
5840-
test_mma_kernel<<<1, 32>>>(A, B, C);
5841-
CUDA_CHECK_RETURN(cudaPeekAtLastError());
5842-
}
5843-
5844-
// ---- Template instantiations ----
5845-
5846-
#define INSTANTIATE_KBIT_QUANT(T, K) \
5847-
template void quantizeBlockwise_kbit<T, K>(const float*, const T*, float*, unsigned int*, int);
5848-
5849-
INSTANTIATE_KBIT_QUANT(half, 2)
5850-
INSTANTIATE_KBIT_QUANT(half, 3)
5851-
INSTANTIATE_KBIT_QUANT(half, 4)
5852-
INSTANTIATE_KBIT_QUANT(half, 5)
5853-
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 2)
5854-
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 3)
5855-
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 4)
5856-
INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 5)
5857-
INSTANTIATE_KBIT_QUANT(float, 2)
5858-
INSTANTIATE_KBIT_QUANT(float, 3)
5859-
INSTANTIATE_KBIT_QUANT(float, 4)
5860-
INSTANTIATE_KBIT_QUANT(float, 5)
5861-
5862-
// Dequant instantiations: all output types × absmax types × K values
5863-
#define INSTANTIATE_KBIT_DEQUANT(T, K, ABSMAX_T) \
5864-
template void dequantizeBlockwise_kbit<T, K, ABSMAX_T>( \
5865-
const unsigned int*, const float*, const ABSMAX_T*, T*, int, cudaStream_t \
5866-
);
5867-
5868-
// uint8 E4M4 absmax (default)
5869-
INSTANTIATE_KBIT_DEQUANT(half, 2, unsigned char)
5870-
INSTANTIATE_KBIT_DEQUANT(half, 3, unsigned char)
5871-
INSTANTIATE_KBIT_DEQUANT(half, 4, unsigned char)
5872-
INSTANTIATE_KBIT_DEQUANT(half, 5, unsigned char)
5873-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2, unsigned char)
5874-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3, unsigned char)
5875-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4, unsigned char)
5876-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5, unsigned char)
5877-
INSTANTIATE_KBIT_DEQUANT(float, 2, unsigned char)
5878-
INSTANTIATE_KBIT_DEQUANT(float, 3, unsigned char)
5879-
INSTANTIATE_KBIT_DEQUANT(float, 4, unsigned char)
5880-
INSTANTIATE_KBIT_DEQUANT(float, 5, unsigned char)
5881-
5882-
// fp16 absmax (option)
5883-
INSTANTIATE_KBIT_DEQUANT(half, 2, half)
5884-
INSTANTIATE_KBIT_DEQUANT(half, 3, half)
5885-
INSTANTIATE_KBIT_DEQUANT(half, 4, half)
5886-
INSTANTIATE_KBIT_DEQUANT(half, 5, half)
5887-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2, half)
5888-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3, half)
5889-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4, half)
5890-
INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5, half)
5891-
INSTANTIATE_KBIT_DEQUANT(float, 2, half)
5892-
INSTANTIATE_KBIT_DEQUANT(float, 3, half)
5893-
INSTANTIATE_KBIT_DEQUANT(float, 4, half)
5894-
INSTANTIATE_KBIT_DEQUANT(float, 5, half)
5895-
5896-
// Repack instantiations: one per K value
5897-
#define INSTANTIATE_KBIT_REPACK(K) \
5898-
template void repackKbit<K>(const unsigned int*, const float*, unsigned int*, unsigned char*, int, int);
5899-
5900-
INSTANTIATE_KBIT_REPACK(2)
5901-
INSTANTIATE_KBIT_REPACK(3)
5902-
INSTANTIATE_KBIT_REPACK(4)
5903-
INSTANTIATE_KBIT_REPACK(5)
5904-
5905-
// GEMM instantiations: one per K value (fp16 only)
5906-
#define INSTANTIATE_KBIT_GEMM(K) \
5907-
template void kbitGemmMinimal<K>( \
5908-
const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int \
5909-
); \
5910-
template void kbitGemmPipelined<K>( \
5911-
const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int \
5912-
); \
5913-
template void kbitGemmSplitK<K>( \
5914-
const half*, const unsigned int*, const unsigned char*, const float*, half*, float*, int*, int, int, int, int \
5915-
);
5916-
5917-
INSTANTIATE_KBIT_GEMM(2)
5918-
INSTANTIATE_KBIT_GEMM(3)
5919-
INSTANTIATE_KBIT_GEMM(4)
5920-
INSTANTIATE_KBIT_GEMM(5)
5921-
5922-
// Production kernel instantiations (fp16 and bf16)
5923-
#define INSTANTIATE_KBIT_GEMM_PROD(K) \
5924-
template void kbitGemmProd<K, half>( \
5925-
const half*, const unsigned int*, const unsigned char*, const float*, half*, float*, int*, int, int, int, int \
5926-
); \
5927-
template void kbitGemmProd<K, __nv_bfloat16>( \
5928-
const __nv_bfloat16*, const unsigned int*, const unsigned char*, const float*, __nv_bfloat16*, float*, int*, \
5929-
int, int, int, int \
5930-
);
5931-
5932-
INSTANTIATE_KBIT_GEMM_PROD(2)
5933-
INSTANTIATE_KBIT_GEMM_PROD(3)
5934-
INSTANTIATE_KBIT_GEMM_PROD(4)
5935-
INSTANTIATE_KBIT_GEMM_PROD(5)
5936-
5937-
// Grouped expert GEMM instantiations (fp16 and bf16)
5938-
#define INSTANTIATE_KBIT_GROUPED_GEMM_PROD(K) \
5939-
template void kbitGroupedGemmProd<K, half>( \
5940-
const half*, const unsigned int*, const unsigned char*, const float*, half*, const int*, int, int, int \
5941-
); \
5942-
template void kbitGroupedGemmProd<K, __nv_bfloat16>( \
5943-
const __nv_bfloat16*, const unsigned int*, const unsigned char*, const float*, __nv_bfloat16*, const int*, \
5944-
int, int, int \
5945-
);
5946-
5947-
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(2)
5948-
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(3)
5949-
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(4)
5950-
INSTANTIATE_KBIT_GROUPED_GEMM_PROD(5)
5951-
5952-
// Scalar GEMV instantiations (fp16 and bf16) — flat layout, float32 absmax, C=1
5953-
#define INSTANTIATE_KBIT_SCALAR_GEMV(K) \
5954-
template void kbitScalarGemv<K, half>( \
5955-
const half*, const unsigned int*, const float*, const float*, half*, int, int, int \
5956-
); \
5957-
template void kbitScalarGemv<K, __nv_bfloat16>( \
5958-
const __nv_bfloat16*, const unsigned int*, const float*, const float*, __nv_bfloat16*, int, int, int \
5959-
);
5960-
5961-
INSTANTIATE_KBIT_SCALAR_GEMV(2)
5962-
INSTANTIATE_KBIT_SCALAR_GEMV(3)
5963-
INSTANTIATE_KBIT_SCALAR_GEMV(4)
5964-
INSTANTIATE_KBIT_SCALAR_GEMV(5)
5965-
5966-
// Grouped scalar GEMV instantiations (fp16 and bf16)
5967-
#define INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(K) \
5968-
template void kbitGroupedScalarGemv<K, half>( \
5969-
const half*, const unsigned int*, const unsigned char*, const float*, half*, const int*, int, int, int \
5970-
); \
5971-
template void kbitGroupedScalarGemv<K, __nv_bfloat16>( \
5972-
const __nv_bfloat16*, const unsigned int*, const unsigned char*, const float*, __nv_bfloat16*, const int*, \
5973-
int, int, int \
5974-
);
5975-
5976-
INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(2)
5977-
INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(3)
5978-
INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(4)
5979-
INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(5)
5764+
// NOTE: kbitGroupedScalarGemv was removed (grouped MMA covers all MoE shapes).
5765+
// See commit ac7d6ff.
59805766

59815767
// ============================================================================
59825768
// Training Kernels: SwiGLU, RMSNorm, RoPE

csrc/pythonInterface.cpp

Lines changed: 1 addition & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -2569,6 +2569,7 @@ void chadamard_rotate_full_bf16(
25692569
#endif
25702570
}
25712571

2572+
#if BUILD_CUDA || BUILD_HIP
25722573
// ============================================================================
25732574
// Training Kernel Bindings (from QLORA-2 branch)
25742575
// ============================================================================
@@ -2675,42 +2676,6 @@ void ccross_entropy_backward_bf16(
26752676

26762677
extern "C" {
26772678
#if BUILD_CUDA || BUILD_HIP
2678-
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
2679-
2680-
void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
2681-
dequantize(code, A, out, n, stream);
2682-
}
2683-
2684-
void cdequantize_blockwise_fp16_fp4(
2685-
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
2686-
) {
2687-
dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
2688-
}
2689-
2690-
void cdequantize_blockwise_fp16(
2691-
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
2692-
) {
2693-
dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
2694-
}
2695-
2696-
void cdequantize_blockwise_fp16_nf4(
2697-
float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream
2698-
) {
2699-
dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
2700-
}
2701-
2702-
void cquantize_blockwise_fp16(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
2703-
quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n);
2704-
}
2705-
2706-
void cquantize_blockwise_fp16_fp4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
2707-
quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n);
2708-
}
2709-
2710-
void cquantize_blockwise_fp16_nf4(float* code, half* A, float* absmax, unsigned char* out, int blocksize, const int n) {
2711-
quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n);
2712-
}
2713-
27142679
// Training kernel extern C wrappers
27152680
void cswiglu_forward_fp16_c(const half* gate, const half* up, half* out, int n) {
27162681
cswiglu_forward_fp16(gate, up, out, n);

0 commit comments

Comments
 (0)