Skip to content

Commit b64bb91

Browse files
TimDettmers and claude
committed
Add ldmatrix + XOR swizzle for A-fragment loading in production kernel
Replace 8 element-by-element shared memory reads per A fragment with a single ldmatrix.sync.aligned.m8n8.x4.shared.b16 instruction. Add XOR bank-conflict swizzle: col_group ^ (row % 8) at 8-half granularity. Without swizzle, all 8 threads in an ldmatrix group hit the same bank (8-way conflict) because TILE_K=64 gives a stride that's a multiple of the bank repeat distance. The XOR swizzle distributes threads across 8 different banks (zero conflicts). All 139 tests still pass. The fp16 path produces identical output to the element-by-element version (verified by test_prod_fp16_matches_splitk). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 24406d2 commit b64bb91

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

csrc/ops.cu

Lines changed: 25 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1846,14 +1846,18 @@ __global__ void kbit_gemm_prod(
18461846
if (threadIdx.x < ABS_INT4S)
18471847
cp_async_cg_16(&abs_dst[threadIdx.x], &abs_src[threadIdx.x]);
18481848

1849-
// A tile (synchronous, with bounds check)
1849+
// A tile (synchronous, with bounds check + XOR swizzle for bank-conflict-free ldmatrix)
1850+
// Swizzle: col_group (8-half granularity) XOR'd with (row % 8)
18501851
scalar_t* a_dst = sh_a(stage);
18511852
for (int i = threadIdx.x; i < A_STAGE_ELEMS; i += blockDim.x) {
18521853
int row = i / TILE_K;
18531854
int col = i % TILE_K;
1855+
int col_group = col / 8;
1856+
int swizzled_group = col_group ^ (row % 8);
1857+
int swizzled_col = swizzled_group * 8 + (col % 8);
18541858
int gr = m_base + row;
18551859
int gc = k_base + col;
1856-
a_dst[row * TILE_K + col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : Ops::from_float(0.0f);
1860+
a_dst[row * TILE_K + swizzled_col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : Ops::from_float(0.0f);
18571861
}
18581862
};
18591863

@@ -1868,26 +1872,27 @@ __global__ void kbit_gemm_prod(
18681872
const int k_block = ks / 2;
18691873
const int half_idx = ks % 2;
18701874

1871-
// Load A fragment
1875+
// Load A fragment via ldmatrix with XOR swizzle
18721876
uint32_t frag_a[4];
18731877
{
1874-
const int kc0 = ks * 16 + tid * 2;
1875-
const int kc1 = ks * 16 + tid * 2 + 8;
1876-
const int r0 = gid;
1877-
const int r1 = gid + 8;
1878-
scalar_t zero = Ops::from_float(0.0f);
1879-
frag_a[0] = pack_two<scalar_t>(
1880-
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0] : zero,
1881-
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0 + 1] : zero);
1882-
frag_a[1] = pack_two<scalar_t>(
1883-
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0] : zero,
1884-
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0 + 1] : zero);
1885-
frag_a[2] = pack_two<scalar_t>(
1886-
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1] : zero,
1887-
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1 + 1] : zero);
1888-
frag_a[3] = pack_two<scalar_t>(
1889-
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1] : zero,
1890-
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1 + 1] : zero);
1878+
// Thread t is in matrix (lane_id / 8), row (lane_id % 8) within that matrix.
1879+
// Matrix layout: 0=top/k_lo, 1=bottom/k_lo, 2=top/k_hi, 3=bottom/k_hi
1880+
const int matrix_id = lane_id / 8;
1881+
const int row_in_matrix = lane_id % 8;
1882+
const int a_row = row_in_matrix + (matrix_id % 2) * 8;
1883+
const int col_start = ks * 16 + (matrix_id / 2) * 8;
1884+
1885+
// Apply same XOR swizzle as write path
1886+
const int col_group = col_start / 8;
1887+
const int swizzled_group = col_group ^ (a_row % 8);
1888+
const int swizzled_col_start = swizzled_group * 8;
1889+
1890+
const scalar_t* addr = &a_ptr[a_row * TILE_K + swizzled_col_start];
1891+
uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
1892+
1893+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
1894+
: "=r"(frag_a[0]), "=r"(frag_a[1]), "=r"(frag_a[2]), "=r"(frag_a[3])
1895+
: "r"(smem_addr));
18911896
}
18921897

18931898
#pragma unroll

0 commit comments

Comments (0)