Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions gemma/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <vector>
Expand Down Expand Up @@ -99,6 +100,7 @@ struct AttentionActivations {
1000000.0)),

div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
Expand All @@ -125,10 +127,6 @@ struct AttentionActivations {
att_sums.OverrideRows(batch_size);
}

bool IsGlobalLayer(size_t layer_idx) const {
return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
}

const ModelConfig& config;

MatStorageT<float> q; // query
Expand All @@ -144,6 +142,8 @@ struct AttentionActivations {
MatStorageT<float> inv_timescale_global;

hwy::Divisor div_seq_len;
// Unfortunately, some models (Griffin) have non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};

Expand Down
191 changes: 87 additions & 104 deletions gemma/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ namespace HWY_NAMESPACE {
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
Expand All @@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const size_t pos, const float mul = 1.0f) {
const size_t worker, const size_t pos,
const float mul = 1.0f) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
Expand All @@ -83,50 +86,49 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
}
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
}
}

// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = v.Cols();
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));

static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
// TODO: 2..4x unroll
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
{
const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
}
}
}

// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
const size_t last_pos, float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
Expand All @@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q,
layer.layer_config.qkv_dim);
layer.layer_config.qkv_dim, worker);
});
}

PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
query_scale);

QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);

// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len);
Softmax(att, att_len);
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
Softmax(att, att_len, /*temperature=*/1.0f, worker);

WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
worker);
}

// The attention window usually starts at 0 unless `pos` is larger than
Expand Down Expand Up @@ -179,75 +183,52 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// All layers should have the same number of heads.
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);

// For each head/token/query, compute Q.K, softmax, and weighted V.

// Statically partition token/query across packages.
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
const IndexRangePartition tq_ranges =
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
ParallelizeOneRange(
tq_ranges, pools.AllPackages(),
[&](const IndexRange& tq_range, const size_t pkg_idx) {
const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
pools.AllClusters(pkg_idx).Run(
tq_range.begin(), tq_range.end(),
[&](const size_t tq_idx, const size_t cluster_idx) {
const HWY_MAYBE_UNUSED size_t cluster_base =
pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache;

// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos =
StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}

pools.Cluster(pkg_idx, cluster_idx)
.Run(
0, layer_config.heads,
[&](const size_t head, size_t thread) HWY_ATTR {
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task);
#if PROFILER_ENABLED
const hwy::Zone zone(cluster_base + thread,
zone_id_par);
const hwy::Zone zone(worker, zone_id_par);
#endif

const size_t head_offset =
(head / kHeadGroups) * qkv_dim * 2;

float* HWY_RESTRICT q =
activations.q.Row(tq_idx) + head * qkv_dim;

float* HWY_RESTRICT att =
activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;

// Make strided read-only views into the kv cache for
// this query and head.
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset,
kv_cache.Stride());
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
kv_cache.Stride());

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
k, v, layer_idx, layer,
activations, att, att_out);
});
});
});
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache;

// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}

float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(tq_idx) + head * qkv_dim;

// Make strided read-only views into the kv cache for
// this query and head.
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, worker);
};

ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
/*pkg_idx=*/0, func);
}

// Different functions use different naming conventions for the number of
Expand Down Expand Up @@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);

Expand All @@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// tasks are very lightweight.
env.ctx.pools.Pool(0).Run(
0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
[&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
Expand All @@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
thread);
});
}

PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
pos);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
Expand Down
2 changes: 1 addition & 1 deletion gemma/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace gcpp {
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
float* HWY_RESTRICT att_out, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
Expand Down
Loading
Loading