@@ -15,7 +15,6 @@ static constexpr int KVAR_N_SHARED_BYTES = KVAR_N_SHARED_FLOATS * sizeof(float);
1515static constexpr int KVAR_N_LOWSHMEM_FLOATS = 6 * KVAR_N_DIM + 2 ;
1616static constexpr int KVAR_N_LOWSHMEM_BYTES = KVAR_N_LOWSHMEM_FLOATS * sizeof (float );
1717static constexpr int KVAR_N_STAGE_CHUNK = 4 ;
18- static constexpr int KVAR_N_MATERIALIZE_CHUNK = 2 ;
1918static constexpr int KVAR_N_MATERIALIZE_FAST_CHUNK = 16 ;
2019static constexpr int KVAR_N_OP_PARAM_BITS = 0 ;
2120static constexpr int KVAR_N_OP_PARAM_ITERS = 1 ;
@@ -1206,28 +1205,6 @@ static __global__ void kvarn_store_workspace_commit_kernel(
12061205 workspace[((int64_t ) token * n_heads + head) * KVAR_N_DIM + threadIdx .x ];
12071206}
12081207
// Extracts the quantized code at position `index` from a bit-packed record.
//
//   record : start of the packed payload bytes
//   index  : element index within the payload (codes are stored back to back)
//   bits   : width of each code in bits
//
// Codes are packed starting at the least-significant bit of each byte. Widths
// of 8, 4 and 2 never straddle a byte and are read with one load; any other
// width is assembled one bit at a time so a code may span two bytes.
static __device__ uint8_t kvarn_unpack_record(const uint8_t * record, int index, int bits) {
    switch (bits) {
        case 8:
            // One code per byte.
            return record[index];
        case 4: {
            // Two codes per byte: even index -> low nibble, odd -> high nibble.
            const int shift = (index & 1) * 4;
            return (record[index >> 1] >> shift) & 0x0fu;
        }
        case 2: {
            // Four codes per byte.
            const int shift = (index & 3) * 2;
            return (record[index >> 2] >> shift) & 0x03u;
        }
        default:
            break;
    }

    // Generic width: gather the code bit by bit from the packed stream.
    const int first_bit = index * bits;
    uint8_t   code      = 0;
    for (int k = 0; k < bits; ++k) {
        const int      pos     = first_bit + k;
        const unsigned bit_val = (record[pos / 8] >> (pos % 8)) & 1u;
        code |= bit_val << k;
    }
    return code;
}

1230-
12311208static __global__ void kvarn_live_groups_kernel (
12321209 const int64_t * indices,
12331210 int n_indices,
@@ -1266,80 +1243,6 @@ static __global__ void kvarn_live_groups_kernel(
12661243 }
12671244}
12681245
// Generic (runtime `bits`) materialize kernel: writes one fp16 element of the
// output per thread, taken either from the fp16 staging buffer or dequantized
// from a packed record, optionally followed by an in-block butterfly transform.
//
// Launch layout (as implied by the index math below):
//   blockIdx.x  = head
//   blockIdx.y  = chunk of KVAR_N_MATERIALIZE_CHUNK consecutive tokens
//   blockIdx.z  = output stream (input stream is stream_start + blockIdx.z)
//   threadIdx.x = lane * KVAR_N_DIM + dim, so blockDim.x must equal
//                 KVAR_N_DIM * KVAR_N_MATERIALIZE_CHUNK
// Shared memory: static, KVAR_N_MATERIALIZE_CHUNK * KVAR_N_DIM floats.
//
// Parameters:
//   records           packed records, record_bytes each, indexed by
//                     (stream * groups_per_stream + group, head)
//   stage             fp16 staging rows, indexed by (stage_pos, head, dim)
//   live_groups       per output stream: index of the live group
//   dst               fp16 output, indexed by (out_stream, token, head, dim)
//   n_heads, n_kv     number of heads / number of tokens to emit
//   bits              code width used by the packed payload
//   value             selects row/col orientation of the record tile
//   emit_rotated      when true, skip the butterfly and store the raw value
//
// Record layout read here: payload of KVAR_N_TILE_VALUES * bits / 8 bytes,
// followed by three fp16 arrays (scale, zero-point, and a per-column factor),
// the first two holding KVAR_N_DIM entries each.
static __global__ void kvarn_materialize_kernel(
        const uint8_t * records,
        const half * stage,
        const int * live_groups,
        half * dst,
        int n_heads,
        int n_kv,
        int stream_start,
        int groups_per_stream,
        int record_bytes,
        int bits,
        bool value,
        bool emit_rotated) {
    // The output scale below is the literal 1/sqrt(128); make the width it
    // depends on explicit instead of leaving it implicit in magic numbers.
    static_assert(KVAR_N_DIM == 128, "kvarn_materialize_kernel: butterfly scale assumes KVAR_N_DIM == 128");

    const int head       = blockIdx.x;
    const int lane       = threadIdx.x / KVAR_N_DIM;
    const int dim        = threadIdx.x - lane * KVAR_N_DIM;
    const int token      = blockIdx.y * KVAR_N_MATERIALIZE_CHUNK + lane;
    const int out_stream = blockIdx.z;
    const int stream     = stream_start + out_stream;

    __shared__ float rotated[KVAR_N_MATERIALIZE_CHUNK * KVAR_N_DIM];
    float * rotated_lane = rotated + lane * KVAR_N_DIM;

    const int live_group = live_groups[out_stream];

    // Element offset into dst, widened to 64 bits before multiplying: the
    // product n_kv * n_heads * KVAR_N_DIM can exceed INT_MAX for long contexts.
    const int64_t dst_index =
        (((int64_t) out_stream * n_kv + token) * n_heads + head) * KVAR_N_DIM + dim;

    // Lanes past the end of the sequence keep x == 0 so they still feed
    // defined values into shared memory and reach every barrier below.
    float x = 0.0f;
    if (token < n_kv) {
        const int group      = token / KVAR_N_DIM;
        const int pos        = token % KVAR_N_DIM;
        const int stage_base = stream * KVAR_N_DIM * KVAR_N_STAGE_GROUPS;

        // Group 0, and the groups within one step of the live group, are read
        // from the staging buffer; older groups come from packed records.
        const bool in_first_group = group == 0;
        const bool in_live_window = group > 0 && group <= live_group && group + 1 >= live_group;

        if (in_first_group || in_live_window) {
            // Group 0 occupies the first KVAR_N_DIM staging rows; later groups
            // alternate between two KVAR_N_DIM-row slots by parity of (group - 1).
            const int stage_pos = stage_base +
                (in_first_group ? pos : KVAR_N_DIM + ((group - 1) & 1) * KVAR_N_DIM + pos);
            const int64_t stage_index = ((int64_t) stage_pos * n_heads + head) * KVAR_N_DIM + dim;
            x = __half2float(stage[stage_index]);
        } else if (group < live_group && group < groups_per_stream) {
            const int record_group = stream * groups_per_stream + group;
            const uint8_t * record = records + ((int64_t) record_group * n_heads + head) * record_bytes;

            // `value` swaps which of (token position, dim) indexes tile rows.
            const int row = value ? pos : dim;
            const int col = value ? dim : pos;

            const int payload_bytes = KVAR_N_TILE_VALUES * bits / 8;
            const half * scale_axis = (const half *) (record + payload_bytes);
            const half * zp_axis    = scale_axis + KVAR_N_DIM;
            const half * other_axis = zp_axis + KVAR_N_DIM;

            const uint8_t q = kvarn_unpack_record(record, row * KVAR_N_DIM + col, bits);
            x = (q * __half2float(scale_axis[row]) + __half2float(zp_axis[row])) * __half2float(other_axis[col]);
        }
    }

    if (emit_rotated) {
        // Rotated-domain attention: emit the dequantized K_rot/V_rot directly
        // (pre-inverse-WHT). The query is rotated and the attention output is
        // inverse-rotated in the graph, so the per-token butterfly is skipped.
        // emit_rotated is a kernel argument, so the whole block returns together
        // and no thread is left waiting at a barrier.
        if (token < n_kv) {
            dst[dst_index] = __float2half_rn(x);
        }
        return;
    }

    rotated_lane[dim] = x;
    __syncthreads();

    // In-place butterfly over the KVAR_N_DIM values of this lane. Each stage
    // has KVAR_N_DIM / 2 (a, b) pairs, so only the lower half of the threads
    // in a lane do work; every thread still hits the barrier.
    for (int stride = 1; stride < KVAR_N_DIM; stride *= 2) {
        if (dim < KVAR_N_DIM / 2) {
            const int   j = (dim / stride) * (2 * stride) + (dim % stride);
            const float a = rotated_lane[j];
            const float b = rotated_lane[j + stride];
            rotated_lane[j]          = a + b;
            rotated_lane[j + stride] = a - b;
        }
        __syncthreads();
    }

    if (token < n_kv) {
        // 0.0883883... == 1 / sqrt(128): normalizes the butterfly output.
        dst[dst_index] = __float2half_rn(rotated_lane[dim] * 0.08838834764831845f);
    }
}

1342-
13431246template <int BITS >
13441247static __device__ __forceinline__ uint8_t kvarn_unpack_record_fast (const uint8_t * record, int index) {
13451248 if constexpr (BITS == 8 ) {
@@ -2057,83 +1960,63 @@ void ggml_cuda_op_kvarn_materialize(ggml_backend_cuda_context & ctx, ggml_tensor
20571960 kvarn_prof_end (prof_live, stream);
20581961
20591962 auto prof_mat = kvarn_prof_begin (ctx, stream, kvarn_prof_kind::MATERIALIZE , value, bits, (int ) dst->ne [2 ], ggml_nbytes (dst));
2060- bool use_fast_materialize = true ;
2061- if (use_fast_materialize) {
2062- switch (bits) {
2063- case 2 :
2064- if (value) {
2065- kvarn_launch_materialize_fast<2 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2066- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2067- } else {
2068- kvarn_launch_materialize_fast<2 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2069- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2070- }
2071- break ;
2072- case 3 :
2073- if (value) {
2074- kvarn_launch_materialize_fast<3 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2075- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2076- } else {
2077- kvarn_launch_materialize_fast<3 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2078- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2079- }
2080- break ;
2081- case 4 :
2082- if (value) {
2083- kvarn_launch_materialize_v4_pair ((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2084- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2085- } else {
2086- kvarn_launch_materialize_fast<4 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2087- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2088- }
2089- break ;
2090- case 5 :
2091- if (value) {
2092- kvarn_launch_materialize_fast<5 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2093- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2094- } else {
2095- kvarn_launch_materialize_fast<5 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2096- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2097- }
2098- break ;
2099- case 6 :
2100- if (value) {
2101- kvarn_launch_materialize_fast<6 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2102- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2103- } else {
2104- kvarn_launch_materialize_fast<6 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2105- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2106- }
2107- break ;
2108- case 8 :
2109- if (value) {
2110- kvarn_launch_materialize_fast<8 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2111- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2112- } else {
2113- kvarn_launch_materialize_fast<8 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2114- (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2115- }
2116- break ;
2117- default :
2118- use_fast_materialize = false ;
2119- break ;
2120- }
2121- }
2122- if (!use_fast_materialize) {
2123- dim3 blocks ((uint32_t ) dst->ne [1 ], (uint32_t ) ((dst->ne [2 ] + KVAR_N_MATERIALIZE_CHUNK - 1 ) / KVAR_N_MATERIALIZE_CHUNK ), (uint32_t ) dst->ne [3 ]);
2124- kvarn_materialize_kernel<<<blocks, KVAR_N_DIM * KVAR_N_MATERIALIZE_CHUNK , 0 , stream>>> (
2125- (const uint8_t *) records->data ,
2126- (const half *) stage->data ,
2127- live_groups.get (),
2128- (half *) dst->data ,
2129- (int ) dst->ne [1 ],
2130- (int ) dst->ne [2 ],
2131- stream_start,
2132- groups_per_stream,
2133- (int ) records->ne [0 ],
2134- bits,
2135- value,
2136- emit_rotated);
1963+ switch (bits) {
1964+ case 2 :
1965+ if (value) {
1966+ kvarn_launch_materialize_fast<2 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1967+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1968+ } else {
1969+ kvarn_launch_materialize_fast<2 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1970+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1971+ }
1972+ break ;
1973+ case 3 :
1974+ if (value) {
1975+ kvarn_launch_materialize_fast<3 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1976+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1977+ } else {
1978+ kvarn_launch_materialize_fast<3 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1979+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1980+ }
1981+ break ;
1982+ case 4 :
1983+ if (value) {
1984+ kvarn_launch_materialize_v4_pair ((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1985+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1986+ } else {
1987+ kvarn_launch_materialize_fast<4 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1988+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1989+ }
1990+ break ;
1991+ case 5 :
1992+ if (value) {
1993+ kvarn_launch_materialize_fast<5 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1994+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1995+ } else {
1996+ kvarn_launch_materialize_fast<5 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
1997+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
1998+ }
1999+ break ;
2000+ case 6 :
2001+ if (value) {
2002+ kvarn_launch_materialize_fast<6 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2003+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2004+ } else {
2005+ kvarn_launch_materialize_fast<6 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2006+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2007+ }
2008+ break ;
2009+ case 8 :
2010+ if (value) {
2011+ kvarn_launch_materialize_fast<8 , true >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2012+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2013+ } else {
2014+ kvarn_launch_materialize_fast<8 , false >((const uint8_t *) records->data , (const half *) stage->data , live_groups.get (), (half *) dst->data ,
2015+ (int ) dst->ne [1 ], (int ) dst->ne [2 ], n_stream, stream_start, groups_per_stream, (int ) records->ne [0 ], emit_rotated, stream);
2016+ }
2017+ break ;
2018+ default :
2019+ GGML_ABORT (" kvarn: no fast materialize kernel for bits %d" , bits);
21372020 }
21382021 kvarn_prof_end (prof_mat, stream);
21392022}
0 commit comments