Skip to content

Commit ec82303

Browse files
jan-wassenbergcopybara-github
authored andcommitted
Back to f32 kv_cache, but via typedef
PiperOrigin-RevId: 785356182
1 parent 56c9196 commit ec82303

3 files changed

Lines changed: 11 additions & 9 deletions

File tree

gemma/attention.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
5252
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
5353
const hwy::Divisor& div_seq_len,
5454
const float* HWY_RESTRICT q,
55-
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
55+
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
5656
const size_t worker) {
5757
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
5858
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
@@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
100100
static HWY_INLINE void WeightedSumV(
101101
const size_t start_pos, const size_t last_pos,
102102
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
103-
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
103+
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
104104
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
105105
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
106106
// we supported non-transposed B.
@@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
125125
// in place for RMSNorm.
126126
void SingleDotSoftmaxWeightedSum(
127127
const size_t pos, const size_t start_pos, const size_t last_pos,
128-
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
128+
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
129129
const size_t layer_idx, const LayerWeightsPtrs& layer,
130130
const AttentionActivations& activations, float* HWY_RESTRICT att,
131131
float* HWY_RESTRICT att_out, const size_t worker) {
@@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
218218
// this query and head.
219219
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
220220
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
221-
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
221+
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
222222
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
223-
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
223+
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
224224
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
225225

226226
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
@@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
259259
// Set up MatMul row pointers for writing to KV, which consists of
260260
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
261261
// because rows are computed modulo seq_len.
262-
MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
262+
MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
263263
layer.qkv_einsum_w2.Rows()));
264264
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
265265
++interleaved_idx) {
@@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
287287
const size_t pos = qbatch.Pos(qi) + batch_idx;
288288
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
289289
auto& kv_cache = qbatch.KV(qi).kv_cache;
290-
BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
290+
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
291291
layer_idx * cache_layer_size +
292292
head * qkv_dim * 2;
293293

gemma/attention.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace gcpp {
3030
namespace NAMESPACE { \
3131
void SingleDotSoftmaxWeightedSum( \
3232
const size_t pos, const size_t start_pos, const size_t last_pos, \
33-
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
33+
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
3434
size_t layer_idx, const LayerWeightsPtrs& layer, \
3535
const AttentionActivations& activations, float* HWY_RESTRICT att, \
3636
float* HWY_RESTRICT att_out, size_t worker); \

gemma/kv_cache.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
namespace gcpp {
2727

28+
using KV_t = float;
29+
2830
struct KVCache {
2931
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
3032

@@ -42,7 +44,7 @@ struct KVCache {
4244
MatStorageT<float> conv1d_cache;
4345
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
4446

45-
MatStorageT<BF16> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
47+
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
4648

4749
private:
4850
// For use by other ctor and Copy()

0 commit comments

Comments
 (0)