Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions gemma/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
Expand Down Expand Up @@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
Expand All @@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) {
Expand Down Expand Up @@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// this query and head.
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
Expand Down Expand Up @@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
Expand Down Expand Up @@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = qbatch.KV(qi).kv_cache;
BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;

Expand Down
2 changes: 1 addition & 1 deletion gemma/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, size_t worker); \
Expand Down
4 changes: 3 additions & 1 deletion gemma/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

namespace gcpp {

using KV_t = float;

struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);

Expand All @@ -42,7 +44,7 @@ struct KVCache {
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]

MatStorageT<BF16> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]

private:
// For use by other ctor and Copy()
Expand Down
Loading