Skip to content

Commit 189fdf4

Browse files
committed
Remove dead generic KVarN materialize CUDA kernel
The generic materialize kernel was reachable only through the switch's default case. ggml_cuda_kvarn_valid_bits() restricts bits to {2,3,4,5,6,8}, and each of those bit-widths has a fast/v4_pair case, so the default was unreachable; GGML_KVARN_MAT_GENERIC, which used to force the generic path, had already been removed. Drop the generic kernel, the use_fast_materialize fallback launch, and the now-orphaned kvarn_unpack_record helper and KVAR_N_MATERIALIZE_CHUNK constant. The switch default now calls GGML_ABORT, which is stricter than the old silent generic fallback: if a future bit-width is ever added to ggml_cuda_kvarn_valid_bits() without a matching fast kernel, the op aborts instead of silently falling back. This leaves fast/v4_pair as the single materialize kernel path.
1 parent 765d047 commit 189fdf4

1 file changed

Lines changed: 57 additions & 174 deletions

File tree

ggml/src/ggml-cuda/kvarn.cu

Lines changed: 57 additions & 174 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@ static constexpr int KVAR_N_SHARED_BYTES = KVAR_N_SHARED_FLOATS * sizeof(float);
1515
static constexpr int KVAR_N_LOWSHMEM_FLOATS = 6 * KVAR_N_DIM + 2;
1616
static constexpr int KVAR_N_LOWSHMEM_BYTES = KVAR_N_LOWSHMEM_FLOATS * sizeof(float);
1717
static constexpr int KVAR_N_STAGE_CHUNK = 4;
18-
static constexpr int KVAR_N_MATERIALIZE_CHUNK = 2;
1918
static constexpr int KVAR_N_MATERIALIZE_FAST_CHUNK = 16;
2019
static constexpr int KVAR_N_OP_PARAM_BITS = 0;
2120
static constexpr int KVAR_N_OP_PARAM_ITERS = 1;
@@ -1206,28 +1205,6 @@ static __global__ void kvarn_store_workspace_commit_kernel(
12061205
workspace[((int64_t) token * n_heads + head) * KVAR_N_DIM + threadIdx.x];
12071206
}
12081207

1209-
static __device__ uint8_t kvarn_unpack_record(const uint8_t * record, int index, int bits) {
1210-
if (bits == 8) {
1211-
return record[index];
1212-
}
1213-
if (bits == 4) {
1214-
const uint8_t packed = record[index >> 1];
1215-
return (packed >> ((index & 1) * 4)) & 0x0fu;
1216-
}
1217-
if (bits == 2) {
1218-
const uint8_t packed = record[index >> 2];
1219-
return (packed >> ((index & 3) * 2)) & 0x03u;
1220-
}
1221-
1222-
uint8_t value = 0;
1223-
const int bit_offset = index * bits;
1224-
for (int bit = 0; bit < bits; ++bit) {
1225-
const int src_bit = bit_offset + bit;
1226-
value |= ((record[src_bit / 8] >> (src_bit % 8)) & 1u) << bit;
1227-
}
1228-
return value;
1229-
}
1230-
12311208
static __global__ void kvarn_live_groups_kernel(
12321209
const int64_t * indices,
12331210
int n_indices,
@@ -1266,80 +1243,6 @@ static __global__ void kvarn_live_groups_kernel(
12661243
}
12671244
}
12681245

1269-
static __global__ void kvarn_materialize_kernel(
1270-
const uint8_t * records,
1271-
const half * stage,
1272-
const int * live_groups,
1273-
half * dst,
1274-
int n_heads,
1275-
int n_kv,
1276-
int stream_start,
1277-
int groups_per_stream,
1278-
int record_bytes,
1279-
int bits,
1280-
bool value,
1281-
bool emit_rotated) {
1282-
const int head = blockIdx.x;
1283-
const int lane = threadIdx.x / KVAR_N_DIM;
1284-
const int dim = threadIdx.x - lane * KVAR_N_DIM;
1285-
const int token = blockIdx.y * KVAR_N_MATERIALIZE_CHUNK + lane;
1286-
const int out_stream = blockIdx.z;
1287-
const int stream = stream_start + out_stream;
1288-
__shared__ float rotated[KVAR_N_MATERIALIZE_CHUNK * KVAR_N_DIM];
1289-
float * rotated_lane = rotated + lane * KVAR_N_DIM;
1290-
const int live_group = live_groups[out_stream];
1291-
1292-
float x = 0.0f;
1293-
if (token < n_kv) {
1294-
const int group = token / KVAR_N_DIM;
1295-
const int pos = token % KVAR_N_DIM;
1296-
const int stage_base = stream * KVAR_N_DIM * KVAR_N_STAGE_GROUPS;
1297-
if (group == 0 || (group > 0 && group <= live_group && group + 1 >= live_group)) {
1298-
const int stage_pos = stage_base + (group == 0 ? pos : KVAR_N_DIM + ((group - 1) & 1) * KVAR_N_DIM + pos);
1299-
x = __half2float(stage[(stage_pos * n_heads + head) * KVAR_N_DIM + dim]);
1300-
} else if (group < live_group && group < groups_per_stream) {
1301-
const int record_group = stream * groups_per_stream + group;
1302-
const uint8_t * record = records + ((int64_t) record_group * n_heads + head) * record_bytes;
1303-
const int row = value ? pos : dim;
1304-
const int col = value ? dim : pos;
1305-
const int payload_bytes = KVAR_N_TILE_VALUES * bits / 8;
1306-
const half * scale_axis = (const half *) (record + payload_bytes);
1307-
const half * zp_axis = scale_axis + KVAR_N_DIM;
1308-
const half * other_axis = zp_axis + KVAR_N_DIM;
1309-
const uint8_t q = kvarn_unpack_record(record, row * KVAR_N_DIM + col, bits);
1310-
x = (q * __half2float(scale_axis[row]) + __half2float(zp_axis[row])) * __half2float(other_axis[col]);
1311-
}
1312-
}
1313-
1314-
if (emit_rotated) {
1315-
// Rotated-domain attention: emit the dequantized K_rot/V_rot directly
1316-
// (pre-inverse-WHT). The query is rotated and the attention output is
1317-
// inverse-rotated in the graph, so the per-token butterfly is skipped.
1318-
if (token < n_kv) {
1319-
dst[((out_stream * n_kv + token) * n_heads + head) * KVAR_N_DIM + dim] =
1320-
__float2half_rn(x);
1321-
}
1322-
return;
1323-
}
1324-
1325-
rotated_lane[dim] = x;
1326-
__syncthreads();
1327-
for (int stride = 1; stride < KVAR_N_DIM; stride *= 2) {
1328-
if (dim < 64) {
1329-
const int j = (dim / stride) * (2 * stride) + (dim % stride);
1330-
const float a = rotated_lane[j];
1331-
const float b = rotated_lane[j + stride];
1332-
rotated_lane[j] = a + b;
1333-
rotated_lane[j + stride] = a - b;
1334-
}
1335-
__syncthreads();
1336-
}
1337-
if (token < n_kv) {
1338-
dst[((out_stream * n_kv + token) * n_heads + head) * KVAR_N_DIM + dim] =
1339-
__float2half_rn(rotated_lane[dim] * 0.08838834764831845f);
1340-
}
1341-
}
1342-
13431246
template<int BITS>
13441247
static __device__ __forceinline__ uint8_t kvarn_unpack_record_fast(const uint8_t * record, int index) {
13451248
if constexpr (BITS == 8) {
@@ -2057,83 +1960,63 @@ void ggml_cuda_op_kvarn_materialize(ggml_backend_cuda_context & ctx, ggml_tensor
20571960
kvarn_prof_end(prof_live, stream);
20581961

20591962
auto prof_mat = kvarn_prof_begin(ctx, stream, kvarn_prof_kind::MATERIALIZE, value, bits, (int) dst->ne[2], ggml_nbytes(dst));
2060-
bool use_fast_materialize = true;
2061-
if (use_fast_materialize) {
2062-
switch (bits) {
2063-
case 2:
2064-
if (value) {
2065-
kvarn_launch_materialize_fast<2, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2066-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2067-
} else {
2068-
kvarn_launch_materialize_fast<2, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2069-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2070-
}
2071-
break;
2072-
case 3:
2073-
if (value) {
2074-
kvarn_launch_materialize_fast<3, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2075-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2076-
} else {
2077-
kvarn_launch_materialize_fast<3, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2078-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2079-
}
2080-
break;
2081-
case 4:
2082-
if (value) {
2083-
kvarn_launch_materialize_v4_pair((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2084-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2085-
} else {
2086-
kvarn_launch_materialize_fast<4, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2087-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2088-
}
2089-
break;
2090-
case 5:
2091-
if (value) {
2092-
kvarn_launch_materialize_fast<5, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2093-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2094-
} else {
2095-
kvarn_launch_materialize_fast<5, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2096-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2097-
}
2098-
break;
2099-
case 6:
2100-
if (value) {
2101-
kvarn_launch_materialize_fast<6, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2102-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2103-
} else {
2104-
kvarn_launch_materialize_fast<6, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2105-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2106-
}
2107-
break;
2108-
case 8:
2109-
if (value) {
2110-
kvarn_launch_materialize_fast<8, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2111-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2112-
} else {
2113-
kvarn_launch_materialize_fast<8, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2114-
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2115-
}
2116-
break;
2117-
default:
2118-
use_fast_materialize = false;
2119-
break;
2120-
}
2121-
}
2122-
if (!use_fast_materialize) {
2123-
dim3 blocks((uint32_t) dst->ne[1], (uint32_t) ((dst->ne[2] + KVAR_N_MATERIALIZE_CHUNK - 1) / KVAR_N_MATERIALIZE_CHUNK), (uint32_t) dst->ne[3]);
2124-
kvarn_materialize_kernel<<<blocks, KVAR_N_DIM * KVAR_N_MATERIALIZE_CHUNK, 0, stream>>>(
2125-
(const uint8_t *) records->data,
2126-
(const half *) stage->data,
2127-
live_groups.get(),
2128-
(half *) dst->data,
2129-
(int) dst->ne[1],
2130-
(int) dst->ne[2],
2131-
stream_start,
2132-
groups_per_stream,
2133-
(int) records->ne[0],
2134-
bits,
2135-
value,
2136-
emit_rotated);
1963+
switch (bits) {
1964+
case 2:
1965+
if (value) {
1966+
kvarn_launch_materialize_fast<2, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1967+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1968+
} else {
1969+
kvarn_launch_materialize_fast<2, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1970+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1971+
}
1972+
break;
1973+
case 3:
1974+
if (value) {
1975+
kvarn_launch_materialize_fast<3, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1976+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1977+
} else {
1978+
kvarn_launch_materialize_fast<3, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1979+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1980+
}
1981+
break;
1982+
case 4:
1983+
if (value) {
1984+
kvarn_launch_materialize_v4_pair((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1985+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1986+
} else {
1987+
kvarn_launch_materialize_fast<4, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1988+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1989+
}
1990+
break;
1991+
case 5:
1992+
if (value) {
1993+
kvarn_launch_materialize_fast<5, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1994+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1995+
} else {
1996+
kvarn_launch_materialize_fast<5, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
1997+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
1998+
}
1999+
break;
2000+
case 6:
2001+
if (value) {
2002+
kvarn_launch_materialize_fast<6, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2003+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2004+
} else {
2005+
kvarn_launch_materialize_fast<6, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2006+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2007+
}
2008+
break;
2009+
case 8:
2010+
if (value) {
2011+
kvarn_launch_materialize_fast<8, true>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2012+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2013+
} else {
2014+
kvarn_launch_materialize_fast<8, false>((const uint8_t *) records->data, (const half *) stage->data, live_groups.get(), (half *) dst->data,
2015+
(int) dst->ne[1], (int) dst->ne[2], n_stream, stream_start, groups_per_stream, (int) records->ne[0], emit_rotated, stream);
2016+
}
2017+
break;
2018+
default:
2019+
GGML_ABORT("kvarn: no fast materialize kernel for bits %d", bits);
21372020
}
21382021
kvarn_prof_end(prof_mat, stream);
21392022
}

0 commit comments

Comments
 (0)