2020
2121#include " compression/types.h" // GEMMA_DISABLED_TARGETS
2222#include " util/zones.h"
23+ #include " hwy/base.h"
2324#ifndef HWY_DISABLED_TARGETS
2425#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
2526#endif // HWY_DISABLED_TARGETS
@@ -58,8 +59,8 @@ size_t FloatsPerVector() {
5859
5960// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
6061// done already, reshape it to take NF into account.
//
// NOTE(review): this hunk replaces the first parameter. The removed version
// compared kv.Cols() against cache.Cols(); the added version receives
// `default_cols` and reshapes only while cache.Cols() is still equal to it,
// by calling cache.ReshapePackedRowsToCols(2 * FloatsPerVector()).
// Presumably the reshape changes cache.Cols(), which would turn every later
// call into a no-op -- TODO confirm against ReshapePackedRowsToCols.
61- void MaybeReshapeCache (const MatPtrT<KV_t>& kv , MatPtrT<KV_t>& cache) {
62- if (kv. Cols () > cache.Cols ()) {
62+ void MaybeReshapeCache (const size_t default_cols , MatPtrT<KV_t>& cache) {
63+ if (default_cols == cache.Cols ()) {
6364 cache.ReshapePackedRowsToCols (2 * FloatsPerVector ());
6465 }
6566}
@@ -71,13 +72,50 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
7172 // is a tiny fraction of the overall computation, and it is linear in the
7273 // token length.
7374 const size_t kFloatsPerTile = 2 * FloatsPerVector ();
75+ const size_t kRoundedQkvDim = hwy::RoundUpTo (qkv_dim, kMaxBF16PerVector );
7476 for (size_t i = 0 ; i < qkv_dim; i += 2 ) {
7577 k[i * kFloatsPerTile ] = kv[i];
7678 k[i * kFloatsPerTile + 1 ] = kv[i + 1 ];
7779 }
80+ for (size_t i = qkv_dim; i < kRoundedQkvDim ; i += 2 ) {
81+ k[i * kFloatsPerTile ] = hwy::ConvertScalarTo<KV_t>(0 .0f );
82+ k[i * kFloatsPerTile + 1 ] = hwy::ConvertScalarTo<KV_t>(0 .0f );
83+ }
7884 for (size_t i = 0 ; i < qkv_dim; i += kFloatsPerTile ) {
85+ if (i + kFloatsPerTile <= qkv_dim) {
86+ for (size_t j = 0 ; j < kFloatsPerTile ; j++) {
87+ v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
88+ }
89+ } else {
90+ for (size_t j = 0 ; j < qkv_dim - i; j++) {
91+ v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
92+ }
93+ for (size_t j = qkv_dim - i; j < kFloatsPerTile ; j++) {
94+ v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0 .0f );
95+ }
96+ }
97+ }
98+ for (size_t i = hwy::RoundUpTo (qkv_dim, kFloatsPerTile ); i < kRoundedQkvDim ;
99+ i += kFloatsPerTile ) {
79100 for (size_t j = 0 ; j < kFloatsPerTile ; j++) {
80- v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
101+ v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0 .0f );
102+ }
103+ }
104+ }
105+
106+ // Zeros out a part of k and v that corresponds to out-of-bounds cache
107+ // positions.
//
// Writes zeros using the same index arithmetic as the k and v loops of
// TransposeKVCacheRow, but over the whole range [0, kRoundedQkvDim) instead
// of copying from a kv row first.
//
// k, v:    destination pointers; the visible caller derives them from the
//          k_cache / v_cache rows plus KOffset / VOffset.
// qkv_dim: number of values per head; rounded up to a multiple of
//          kMaxBF16PerVector before zeroing.
//
// NOTE(review): in the visible caller this runs for padding tokens, i.e.
// when token_idx >= num_tokens.
108+ void TransposeOOBKVCacheRow (KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v,
109+ size_t qkv_dim) {
110+ const size_t kFloatsPerTile = 2 * FloatsPerVector ();
111+ const size_t kRoundedQkvDim = hwy::RoundUpTo (qkv_dim, kMaxBF16PerVector );
// k: each pair (i, i + 1) of qkv_dim indices is written to two adjacent
// elements starting at k[i * kFloatsPerTile].
112+ for (size_t i = 0 ; i < kRoundedQkvDim ; i += 2 ) {
113+ k[i * kFloatsPerTile ] = hwy::ConvertScalarTo<KV_t>(0 .0f );
114+ k[i * kFloatsPerTile + 1 ] = hwy::ConvertScalarTo<KV_t>(0 .0f );
115+ }
// v: each group of kFloatsPerTile indices is written to kFloatsPerTile
// consecutive elements starting at v[i * kFloatsPerTile].
116+ for (size_t i = 0 ; i < kRoundedQkvDim ; i += kFloatsPerTile ) {
117+ for (size_t j = 0 ; j < kFloatsPerTile ; j++) {
118+ v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo<KV_t>(0 .0f );
81119 }
82120 }
83121}
@@ -314,23 +352,51 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
314352 CallMatMul (activations.pre_att_rms_out , layer.qkv_einsum_w2 ,
315353 /* add=*/ nullptr , env, kv_rows);
316354 for (size_t qi = 0 ; qi < qbatch.Size (); ++qi) {
317- MaybeReshapeCache (qbatch.KV (qi).kv_cache , qbatch.KV (qi).k_cache );
318- MaybeReshapeCache (qbatch.KV (qi).kv_cache , qbatch.KV (qi).v_cache );
355+ MaybeReshapeCache (qbatch.KV (qi).cache ->KOrVDefaultCols (),
356+ qbatch.KV (qi).k_cache );
357+ MaybeReshapeCache (qbatch.KV (qi).cache ->KOrVDefaultCols (),
358+ qbatch.KV (qi).v_cache );
319359 }
320360 const size_t kFloatsPerVector = FloatsPerVector ();
361+ const size_t kRoundedTokens =
362+ hwy::RoundUpTo (num_tokens, 2 * kFloatsPerVector );
363+ const size_t kRoundedNumInterleaved =
364+ kRoundedTokens * div_qbatch.GetDivisor ();
321365
322366 // Apply positional encodings for K.
323367 // Note that 2D parallelism is not worth the fork/join overhead because the
324368 // tasks are very lightweight.
325369 ParallelFor (
326- Parallelism::kFlat , kv_heads * num_interleaved , env.ctx ,
370+ Parallelism::kFlat , kv_heads * kRoundedNumInterleaved , env.ctx ,
327371 /* cluster_idx=*/ 0 , Callers::kAttComputeQKV ,
328372 [&](size_t task, size_t worker) HWY_ATTR {
329373 const size_t head = task % kv_heads;
330374 const size_t interleaved_idx = task / kv_heads;
331375 const size_t qi = div_qbatch.Remainder (interleaved_idx);
332376 const size_t token_idx = div_qbatch.Divide (interleaved_idx);
333377 const size_t cache_pos = qbatch.Pos (qi) + token_idx;
378+ if (token_idx >= kRoundedTokens ) {
379+ return ;
380+ }
381+ // The innermost dimension of v is 2NF values from qkv_dim because they
382+ // will be loaded into a BF16 vector to be scaled and added to the
383+ // cached attention output in 2 NF-sized registers.
384+ auto & k_cache = qbatch.KV (qi).k_cache ;
385+ KV_t* HWY_RESTRICT k =
386+ k_cache.Row (cache_pos / (2 * kFloatsPerVector )) +
387+ qbatch.KV (qi).cache ->KOffset (layer_idx, head, kFloatsPerVector ,
388+ cache_pos);
389+ auto & v_cache = qbatch.KV (qi).v_cache ;
390+ KV_t* HWY_RESTRICT v =
391+ v_cache.Row (cache_pos / (2 * kFloatsPerVector )) +
392+ qbatch.KV (qi).cache ->VOffset (layer_idx, head, kFloatsPerVector ,
393+ cache_pos);
394+ if (token_idx >= num_tokens) {
395+ // Create a zero-filled K/V pair for padding for out-of-sequence
396+ // tokens.
397+ TransposeOOBKVCacheRow (k, v, qkv_dim);
398+ return ;
399+ }
334400 // --seq_len must be large enough to avoid wraparound.
335401 HWY_DASSERT (cache_pos < activations.SeqLen ());
336402 auto & kv_cache = qbatch.KV (qi).kv_cache ;
@@ -341,22 +407,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
341407 // The innermost dimension of k is 2 values from qkv_dim because they
342408 // are going to be used in a BF16 dot product involving pairs of
343409 // values over NF k positions.
344- // The innermost dimension of v is 2NF values from qkv_dim because they
345- // will be loaded into a BF16 vector to be scaled and added to the
346- // cached attention output in 2 NF-sized registers.
347- // TODO(rays): factor out these calculations into functions.
348- auto & k_cache = qbatch.KV (qi).k_cache ;
349- KV_t* HWY_RESTRICT k =
350- k_cache.Row (cache_pos / (2 * kFloatsPerVector )) +
351- (layer_idx * cache_layer_size + head * qkv_dim * 2 ) *
352- kFloatsPerVector +
353- (cache_pos % (2 * kFloatsPerVector )) * 2 ;
354- auto & v_cache = qbatch.KV (qi).v_cache ;
355- KV_t* HWY_RESTRICT v =
356- v_cache.Row (cache_pos / (2 * kFloatsPerVector )) +
357- (layer_idx * cache_layer_size + head * qkv_dim * 2 ) *
358- kFloatsPerVector +
359- (cache_pos % (2 * kFloatsPerVector )) * 2 * kFloatsPerVector ;
360410
361411 HWY_ALIGN float kv_f32[2 * kMaxQKVDim ];
362412 const hn::ScalableTag<float > df;
0 commit comments