Skip to content

Commit 3ea86cf

Browse files
jan-wassenbergcopybara-github
authored andcommitted
1.1x prefill and decode speedup (attention/activations)
Optimizations - Better load-balancing in attention threading (Previously, clusters were limited by #heads) - Add MulByConstTo to avoid zero-init - Parallel activations Cleanup - Prepare for RowPtr in A or B - Pass through thread_id to ops - Avoid warning in bench_matmul PiperOrigin-RevId: 773579903
1 parent 7630ec0 commit 3ea86cf

12 files changed

Lines changed: 270 additions & 195 deletions

File tree

gemma/activations.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <math.h> // sqrtf
2020
#include <stddef.h>
21+
#include <stdint.h>
2122

2223
#include <atomic>
2324
#include <vector>
@@ -99,6 +100,7 @@ struct AttentionActivations {
99100
1000000.0)),
100101

101102
div_seq_len(static_cast<uint32_t>(seq_len)),
103+
div_heads(static_cast<uint32_t>(layer_config.heads)),
102104
query_scale(ChooseQueryScale(config)) {
103105
// Batch size can be 0 in experimental code so do not assert.
104106
if (batch_size == 0) {
@@ -125,10 +127,6 @@ struct AttentionActivations {
125127
att_sums.OverrideRows(batch_size);
126128
}
127129

128-
bool IsGlobalLayer(size_t layer_idx) const {
129-
return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
130-
}
131-
132130
const ModelConfig& config;
133131

134132
MatStorageT<float> q; // query
@@ -144,6 +142,8 @@ struct AttentionActivations {
144142
MatStorageT<float> inv_timescale_global;
145143

146144
hwy::Divisor div_seq_len;
145+
// Unfortunately, some models (Griffin) have non-power-of-two heads.
146+
hwy::Divisor div_heads;
147147
float query_scale;
148148
};
149149

gemma/attention.cc

Lines changed: 87 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ namespace HWY_NAMESPACE {
5252
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
5353
const hwy::Divisor& div_seq_len,
5454
const float* HWY_RESTRICT q,
55-
const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
55+
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
56+
const size_t worker) {
57+
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
5658
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
5759
// Slightly faster: no wraparound.
5860
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
7173
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
7274
const LayerWeightsPtrs& layer,
7375
const AttentionActivations& activations,
74-
const size_t pos, const float mul = 1.0f) {
76+
const size_t worker, const size_t pos,
77+
const float mul = 1.0f) {
7578
const size_t qkv_dim = layer.layer_config.qkv_dim;
7679
const PostQKType& post_qk = layer.layer_config.post_qk;
7780
// qk is either q or k, so qkv_dim is the length we operate on.
@@ -83,50 +86,49 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
8386
}
8487
// PostQKType::Rope
8588
if (post_qk == PostQKType::HalfRope) {
86-
Rope(qk, qkv_dim / 2, inv_timescale, pos);
87-
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
89+
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
90+
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
8891
} else {
89-
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
92+
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
9093
}
9194
}
9295

9396
// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
9497
// `att_out`. Equivalent in gemma/modules.py:
9598
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
9699
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
97-
static HWY_INLINE void WeightedSumV(const size_t start_pos,
98-
const size_t last_pos,
99-
const hwy::Divisor& div_seq_len,
100-
const float* HWY_RESTRICT att,
101-
const MatPtrT<BF16>& v,
102-
float* HWY_RESTRICT att_out) {
103-
const size_t qkv_dim = v.Cols();
104-
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
105-
100+
static HWY_INLINE void WeightedSumV(
101+
const size_t start_pos, const size_t last_pos,
102+
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
103+
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
106104
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
107-
// Slightly faster: no wraparound.
108-
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
109-
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
105+
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
106+
// we supported non-transposed B.
107+
// TODO: 2..4x unroll
108+
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
109+
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
110+
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
110111
}
111112
} else {
112-
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
113-
const size_t pos_modulo = div_seq_len.Remainder(pos);
114-
const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
115-
MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
113+
{
114+
const size_t pos_mod = div_seq_len.Remainder(start_pos);
115+
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
116+
}
117+
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
118+
const size_t pos_mod = div_seq_len.Remainder(pos);
119+
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
116120
}
117121
}
118122
}
119123

120124
// Calculates the attention outputs for a single q, which may be updated
121125
// in place for RMSNorm.
122-
void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
123-
const size_t last_pos, float* HWY_RESTRICT q,
124-
const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
125-
const size_t layer_idx,
126-
const LayerWeightsPtrs& layer,
127-
const AttentionActivations& activations,
128-
float* HWY_RESTRICT att,
129-
float* HWY_RESTRICT att_out) {
126+
void SingleDotSoftmaxWeightedSum(
127+
const size_t pos, const size_t start_pos, const size_t last_pos,
128+
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
129+
const size_t layer_idx, const LayerWeightsPtrs& layer,
130+
const AttentionActivations& activations, float* HWY_RESTRICT att,
131+
float* HWY_RESTRICT att_out, const size_t worker) {
130132
const float att_cap = activations.config.att_cap;
131133
const float query_scale = activations.query_scale;
132134
const size_t seq_len =
@@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
136138
if (layer.query_norm_scale.HasPtr()) {
137139
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
138140
RMSNormInplace(weights_t->PackedScale1(), 0, q,
139-
layer.layer_config.qkv_dim);
141+
layer.layer_config.qkv_dim, worker);
140142
});
141143
}
142144

143-
PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
145+
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
146+
query_scale);
144147

145-
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
148+
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
146149

147150
// SoftMax with optional SoftCap yields "probabilities" in att.
148151
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
149-
MaybeLogitsSoftCap(att_cap, att, att_len);
150-
Softmax(att, att_len);
152+
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
153+
Softmax(att, att_len, /*temperature=*/1.0f, worker);
151154

152-
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
155+
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
156+
worker);
153157
}
154158

155159
// The attention window usually starts at 0 unless `pos` is larger than
@@ -179,75 +183,52 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
179183
const size_t cache_layer_size = layer_config.CacheLayerSize();
180184
const size_t seq_len =
181185
static_cast<size_t>(activations.div_seq_len.GetDivisor());
186+
// All layers should have the same number of heads.
187+
HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
182188

183189
// For each head/token/query, compute Q.K, softmax, and weighted V.
184-
185-
// Statically partition token/query across packages.
186-
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
187-
const IndexRangePartition tq_ranges =
188-
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
189-
ParallelizeOneRange(
190-
tq_ranges, pools.AllPackages(),
191-
[&](const IndexRange& tq_range, const size_t pkg_idx) {
192-
const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
193-
pools.AllClusters(pkg_idx).Run(
194-
tq_range.begin(), tq_range.end(),
195-
[&](const size_t tq_idx, const size_t cluster_idx) {
196-
const HWY_MAYBE_UNUSED size_t cluster_base =
197-
pkg_base + cluster_idx * pools.MaxWorkersPerCluster();
198-
const size_t qi = div_qbatch.Remainder(tq_idx);
199-
const size_t batch_idx = div_qbatch.Divide(tq_idx);
200-
auto& kv_cache = qbatch.KV(qi).kv_cache;
201-
202-
// Find the token position in the query and calculate
203-
// the range of cache positions to attend to.
204-
const size_t pos = qbatch.Pos(qi) + batch_idx;
205-
const size_t start_pos =
206-
StartPos(pos, activations.config, layer_idx);
207-
size_t last_pos = pos;
208-
const size_t prefix_end = qbatch.PrefixEnd(qi);
209-
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
210-
// last_pos in QDotK and WeightedSumV is inclusive.
211-
last_pos = prefix_end - 1;
212-
}
213-
214-
pools.Cluster(pkg_idx, cluster_idx)
215-
.Run(
216-
0, layer_config.heads,
217-
[&](const size_t head, size_t thread) HWY_ATTR {
190+
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
191+
const size_t tq_idx = activations.div_heads.Divide(task);
192+
const size_t head = activations.div_heads.Remainder(task);
218193
#if PROFILER_ENABLED
219-
const hwy::Zone zone(cluster_base + thread,
220-
zone_id_par);
194+
const hwy::Zone zone(worker, zone_id_par);
221195
#endif
222196

223-
const size_t head_offset =
224-
(head / kHeadGroups) * qkv_dim * 2;
225-
226-
float* HWY_RESTRICT q =
227-
activations.q.Row(tq_idx) + head * qkv_dim;
228-
229-
float* HWY_RESTRICT att =
230-
activations.att.Row(tq_idx) + head * seq_len;
231-
float* HWY_RESTRICT att_out =
232-
activations.att_out.Row(tq_idx) + head * qkv_dim;
233-
234-
// Make strided read-only views into the kv cache for
235-
// this query and head.
236-
const size_t kv_head_offset =
237-
layer_idx * cache_layer_size + head_offset;
238-
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
239-
k.SetPtr(kv_cache.Row(0) + kv_head_offset,
240-
kv_cache.Stride());
241-
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
242-
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
243-
kv_cache.Stride());
244-
245-
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q,
246-
k, v, layer_idx, layer,
247-
activations, att, att_out);
248-
});
249-
});
250-
});
197+
const size_t qi = div_qbatch.Remainder(tq_idx);
198+
const size_t batch_idx = div_qbatch.Divide(tq_idx);
199+
auto& kv_cache = qbatch.KV(qi).kv_cache;
200+
201+
// Find the token position in the query and calculate
202+
// the range of cache positions to attend to.
203+
const size_t pos = qbatch.Pos(qi) + batch_idx;
204+
const size_t start_pos = StartPos(pos, activations.config, layer_idx);
205+
size_t last_pos = pos;
206+
const size_t prefix_end = qbatch.PrefixEnd(qi);
207+
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
208+
// last_pos in QDotK and WeightedSumV is inclusive.
209+
last_pos = prefix_end - 1;
210+
}
211+
212+
float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
213+
float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
214+
float* HWY_RESTRICT att_out =
215+
activations.att_out.Row(tq_idx) + head * qkv_dim;
216+
217+
// Make strided read-only views into the kv cache for
218+
// this query and head.
219+
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
220+
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
221+
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
222+
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
223+
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
224+
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
225+
226+
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
227+
layer, activations, att, att_out, worker);
228+
};
229+
230+
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
231+
/*pkg_idx=*/0, func);
251232
}
252233

253234
// Different functions use different naming conventions for the number of
@@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
286267
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
287268
const size_t cache_pos =
288269
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
289-
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
270+
env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
290271
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
291272
}
292-
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
273+
kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
293274
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
294275
/*add=*/nullptr, env, kv_rows);
295276

@@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
298279
// tasks are very lightweight.
299280
env.ctx.pools.Pool(0).Run(
300281
0, kv_heads * num_interleaved,
301-
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
282+
[&](uint64_t task, size_t thread) HWY_ATTR {
302283
const size_t head = task % kv_heads;
303284
const size_t interleaved_idx = task / kv_heads;
304285
const size_t qi = div_qbatch.Remainder(interleaved_idx);
@@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
318299
// Apply further processing to K.
319300
if (layer.key_norm_scale.HasPtr()) {
320301
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
321-
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim);
302+
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
303+
thread);
322304
});
323305
}
324306

325-
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
307+
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
308+
pos);
326309
CompressPerThread tls;
327310
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
328311
});

gemma/attention.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace gcpp {
3333
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
3434
size_t layer_idx, const LayerWeightsPtrs& layer, \
3535
const AttentionActivations& activations, float* HWY_RESTRICT att, \
36-
float* HWY_RESTRICT att_out); \
36+
float* HWY_RESTRICT att_out, size_t worker); \
3737
\
3838
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
3939
const LayerWeightsPtrs& layer, \

0 commit comments

Comments
 (0)