@@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
5252static HWY_INLINE void QDotK (const size_t start_pos, const size_t last_pos,
5353 const hwy::Divisor& div_seq_len,
5454 const float * HWY_RESTRICT q,
55- const MatPtrT<BF16 >& k, float * HWY_RESTRICT att,
55+ const MatPtrT<KV_t >& k, float * HWY_RESTRICT att,
5656 const size_t worker) {
5757 PROFILER_ZONE2 (worker, " Gen.Attention.QDotK" );
5858 if (HWY_LIKELY (last_pos < static_cast <size_t >(div_seq_len.GetDivisor ()))) {
@@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
100100static HWY_INLINE void WeightedSumV (
101101 const size_t start_pos, const size_t last_pos,
102102 const hwy::Divisor& div_seq_len, const float * HWY_RESTRICT att,
103- const MatPtrT<BF16 >& v, float * HWY_RESTRICT att_out, const size_t worker) {
103+ const MatPtrT<KV_t >& v, float * HWY_RESTRICT att_out, const size_t worker) {
104104 if (HWY_LIKELY (last_pos < static_cast <size_t >(div_seq_len.GetDivisor ()))) {
105105 // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
106106 // we supported non-transposed B.
@@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
125125// in place for RMSNorm.
126126void SingleDotSoftmaxWeightedSum (
127127 const size_t pos, const size_t start_pos, const size_t last_pos,
128- float * HWY_RESTRICT q, const MatPtrT<BF16 >& k, const MatPtrT<BF16 >& v,
128+ float * HWY_RESTRICT q, const MatPtrT<KV_t >& k, const MatPtrT<KV_t >& v,
129129 const size_t layer_idx, const LayerWeightsPtrs& layer,
130130 const AttentionActivations& activations, float * HWY_RESTRICT att,
131131 float * HWY_RESTRICT att_out, const size_t worker) {
@@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
218218 // this query and head.
219219 const size_t head_offset = (head / kHeadGroups ) * qkv_dim * 2 ;
220220 const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
221- MatPtrT<BF16 > k (" k_view" , Extents2D (seq_len, qkv_dim));
221+ MatPtrT<KV_t > k (" k_view" , Extents2D (seq_len, qkv_dim));
222222 k.SetPtr (kv_cache.Row (0 ) + kv_head_offset, kv_cache.Stride ());
223- MatPtrT<BF16 > v (" v_view" , Extents2D (seq_len, qkv_dim));
223+ MatPtrT<KV_t > v (" v_view" , Extents2D (seq_len, qkv_dim));
224224 v.SetPtr (kv_cache.Row (0 ) + kv_head_offset + qkv_dim, kv_cache.Stride ());
225225
226226 SingleDotSoftmaxWeightedSum (pos, start_pos, last_pos, q, k, v, layer_idx,
@@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
259259 // Set up MatMul row pointers for writing to KV, which consists of
260260 // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
261261 // because rows are computed modulo seq_len.
262- MatPtrT<BF16 > kv_rows (" kv" , Extents2D (activations.pre_att_rms_out .Rows (),
262+ MatPtrT<KV_t > kv_rows (" kv" , Extents2D (activations.pre_att_rms_out .Rows (),
263263 layer.qkv_einsum_w2 .Rows ()));
264264 for (size_t interleaved_idx = 0 ; interleaved_idx < num_interleaved;
265265 ++interleaved_idx) {
@@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
287287 const size_t pos = qbatch.Pos (qi) + batch_idx;
288288 const size_t cache_pos = activations.div_seq_len .Remainder (pos);
289289 auto & kv_cache = qbatch.KV (qi).kv_cache ;
290- BF16 * HWY_RESTRICT kv = kv_cache.Row (cache_pos) +
290+ KV_t * HWY_RESTRICT kv = kv_cache.Row (cache_pos) +
291291 layer_idx * cache_layer_size +
292292 head * qkv_dim * 2 ;
293293
0 commit comments