@@ -52,7 +52,9 @@ namespace HWY_NAMESPACE {
5252static HWY_INLINE void QDotK (const size_t start_pos, const size_t last_pos,
5353 const hwy::Divisor& div_seq_len,
5454 const float * HWY_RESTRICT q,
55- const MatPtrT<BF16 >& k, float * HWY_RESTRICT att) {
55+ const MatPtrT<BF16 >& k, float * HWY_RESTRICT att,
56+ const size_t worker) {
57+ PROFILER_ZONE2 (worker, " Gen.Attention.QDotK" );
5658 if (HWY_LIKELY (last_pos < static_cast <size_t >(div_seq_len.GetDivisor ()))) {
5759 // Slightly faster: no wraparound.
5860 for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
7173static void PositionalEncodingQK (float * qk, const size_t layer_idx,
7274 const LayerWeightsPtrs& layer,
7375 const AttentionActivations& activations,
74- const size_t pos, const float mul = 1 .0f ) {
76+ const size_t worker, const size_t pos,
77+ const float mul = 1 .0f ) {
7578 const size_t qkv_dim = layer.layer_config .qkv_dim ;
7679 const PostQKType& post_qk = layer.layer_config .post_qk ;
7780 // qk is either q or k, so qkv_dim is the length we operate on.
@@ -83,50 +86,49 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
8386 }
8487 // PostQKType::Rope
8588 if (post_qk == PostQKType::HalfRope) {
86- Rope (qk, qkv_dim / 2 , inv_timescale, pos);
87- if (mul != 1 .0f ) MulByConst (mul, qk, qkv_dim);
89+ Rope (qk, qkv_dim / 2 , inv_timescale, pos, worker );
90+ if (mul != 1 .0f ) MulByConst (mul, qk, qkv_dim, worker );
8891 } else {
89- RopeAndMulBy (mul, qk, qkv_dim, inv_timescale, pos);
92+ RopeAndMulBy (mul, qk, qkv_dim, inv_timescale, pos, worker );
9093 }
9194}
9295
9396// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
9497// `att_out`. Equivalent in gemma/modules.py:
9598// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
9699// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
97- static HWY_INLINE void WeightedSumV (const size_t start_pos,
98- const size_t last_pos,
99- const hwy::Divisor& div_seq_len,
100- const float * HWY_RESTRICT att,
101- const MatPtrT<BF16 >& v,
102- float * HWY_RESTRICT att_out) {
103- const size_t qkv_dim = v.Cols ();
104- hwy::ZeroBytes (att_out, qkv_dim * sizeof (*att_out));
105-
100+ static HWY_INLINE void WeightedSumV (
101+ const size_t start_pos, const size_t last_pos,
102+ const hwy::Divisor& div_seq_len, const float * HWY_RESTRICT att,
103+ const MatPtrT<BF16 >& v, float * HWY_RESTRICT att_out, const size_t worker) {
106104 if (HWY_LIKELY (last_pos < static_cast <size_t >(div_seq_len.GetDivisor ()))) {
107- // Slightly faster: no wraparound.
108- for (size_t pos = start_pos; pos <= last_pos; ++pos) {
109- MulByConstAndAdd (att[pos], v.Row (pos), att_out, v.Cols ());
105+ // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
106+ // we supported non-transposed B.
107+ // TODO: 2..4x unroll
108+ MulByConstTo (att[start_pos], v.Row (start_pos), att_out, v.Cols (), worker);
109+ for (size_t pos = start_pos + 1 ; pos <= last_pos; ++pos) {
110+ MulByConstAndAdd (att[pos], v.Row (pos), att_out, v.Cols (), worker);
110111 }
111112 } else {
112- for (size_t pos = start_pos; pos <= last_pos; ++pos) {
113- const size_t pos_modulo = div_seq_len.Remainder (pos);
114- const BF16 * HWY_RESTRICT v_ptr = v.Row (pos_modulo);
115- MulByConstAndAdd (att[pos_modulo], v_ptr, att_out, v.Cols ());
113+ {
114+ const size_t pos_mod = div_seq_len.Remainder (start_pos);
115+ MulByConstTo (att[pos_mod], v.Row (pos_mod), att_out, v.Cols (), worker);
116+ }
117+ for (size_t pos = start_pos + 1 ; pos <= last_pos; ++pos) {
118+ const size_t pos_mod = div_seq_len.Remainder (pos);
119+ MulByConstAndAdd (att[pos_mod], v.Row (pos_mod), att_out, v.Cols (), worker);
116120 }
117121 }
118122}
119123
120124// Calculates the attention outputs for a single q, which may be updated
121125// in place for RMSNorm.
122- void SingleDotSoftmaxWeightedSum (const size_t pos, const size_t start_pos,
123- const size_t last_pos, float * HWY_RESTRICT q,
124- const MatPtrT<BF16 >& k, const MatPtrT<BF16 >& v,
125- const size_t layer_idx,
126- const LayerWeightsPtrs& layer,
127- const AttentionActivations& activations,
128- float * HWY_RESTRICT att,
129- float * HWY_RESTRICT att_out) {
126+ void SingleDotSoftmaxWeightedSum (
127+ const size_t pos, const size_t start_pos, const size_t last_pos,
128+ float * HWY_RESTRICT q, const MatPtrT<BF16 >& k, const MatPtrT<BF16 >& v,
129+ const size_t layer_idx, const LayerWeightsPtrs& layer,
130+ const AttentionActivations& activations, float * HWY_RESTRICT att,
131+ float * HWY_RESTRICT att_out, const size_t worker) {
130132 const float att_cap = activations.config .att_cap ;
131133 const float query_scale = activations.query_scale ;
132134 const size_t seq_len =
@@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
136138 if (layer.query_norm_scale .HasPtr ()) {
137139 CallUpcasted (&layer.query_norm_scale , [&](const auto * weights_t ) {
138140 RMSNormInplace (weights_t ->PackedScale1 (), 0 , q,
139- layer.layer_config .qkv_dim );
141+ layer.layer_config .qkv_dim , worker );
140142 });
141143 }
142144
143- PositionalEncodingQK (q, layer_idx, layer, activations, pos, query_scale);
145+ PositionalEncodingQK (q, layer_idx, layer, activations, worker, pos,
146+ query_scale);
144147
145- QDotK (start_pos, last_pos, activations.div_seq_len , q, k, att);
148+ QDotK (start_pos, last_pos, activations.div_seq_len , q, k, att, worker );
146149
147150 // SoftMax with optional SoftCap yields "probabilities" in att.
148151 const size_t att_len = HWY_MIN (last_pos + 1 , seq_len);
149- MaybeLogitsSoftCap (att_cap, att, att_len);
150- Softmax (att, att_len);
152+ MaybeLogitsSoftCap (att_cap, att, att_len, worker );
153+ Softmax (att, att_len, /* temperature= */ 1 . 0f , worker );
151154
152- WeightedSumV (start_pos, last_pos, activations.div_seq_len , att, v, att_out);
155+ WeightedSumV (start_pos, last_pos, activations.div_seq_len , att, v, att_out,
156+ worker);
153157}
154158
155159// The attention window usually starts at 0 unless `pos` is larger than
@@ -179,75 +183,52 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
179183 const size_t cache_layer_size = layer_config.CacheLayerSize ();
180184 const size_t seq_len =
181185 static_cast <size_t >(activations.div_seq_len .GetDivisor ());
186+ // All layers should have the same number of heads.
187+ HWY_DASSERT (activations.div_heads .GetDivisor () == layer_config.heads );
182188
183189 // For each head/token/query, compute Q.K, softmax, and weighted V.
184-
185- // Statically partition token/query across packages.
186- const size_t num_tq = num_tokens * div_qbatch.GetDivisor ();
187- const IndexRangePartition tq_ranges =
188- StaticPartition (IndexRange (0 , num_tq), pools.NumPackages (), 1 );
189- ParallelizeOneRange (
190- tq_ranges, pools.AllPackages (),
191- [&](const IndexRange& tq_range, const size_t pkg_idx) {
192- const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage ();
193- pools.AllClusters (pkg_idx).Run (
194- tq_range.begin (), tq_range.end (),
195- [&](const size_t tq_idx, const size_t cluster_idx) {
196- const HWY_MAYBE_UNUSED size_t cluster_base =
197- pkg_base + cluster_idx * pools.MaxWorkersPerCluster ();
198- const size_t qi = div_qbatch.Remainder (tq_idx);
199- const size_t batch_idx = div_qbatch.Divide (tq_idx);
200- auto & kv_cache = qbatch.KV (qi).kv_cache ;
201-
202- // Find the token position in the query and calculate
203- // the range of cache positions to attend to.
204- const size_t pos = qbatch.Pos (qi) + batch_idx;
205- const size_t start_pos =
206- StartPos (pos, activations.config , layer_idx);
207- size_t last_pos = pos;
208- const size_t prefix_end = qbatch.PrefixEnd (qi);
209- if (prefix_end > 0 && prefix_end - 1 > last_pos) {
210- // last_pos in QDotK and WeightedSumV is inclusive.
211- last_pos = prefix_end - 1 ;
212- }
213-
214- pools.Cluster (pkg_idx, cluster_idx)
215- .Run (
216- 0 , layer_config.heads ,
217- [&](const size_t head, size_t thread) HWY_ATTR {
190+ const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
191+ const size_t tq_idx = activations.div_heads .Divide (task);
192+ const size_t head = activations.div_heads .Remainder (task);
218193#if PROFILER_ENABLED
219- const hwy::Zone zone (cluster_base + thread,
220- zone_id_par);
194+ const hwy::Zone zone (worker, zone_id_par);
221195#endif
222196
223- const size_t head_offset =
224- (head / kHeadGroups ) * qkv_dim * 2 ;
225-
226- float * HWY_RESTRICT q =
227- activations.q .Row (tq_idx) + head * qkv_dim;
228-
229- float * HWY_RESTRICT att =
230- activations.att .Row (tq_idx) + head * seq_len;
231- float * HWY_RESTRICT att_out =
232- activations.att_out .Row (tq_idx) + head * qkv_dim;
233-
234- // Make strided read-only views into the kv cache for
235- // this query and head.
236- const size_t kv_head_offset =
237- layer_idx * cache_layer_size + head_offset;
238- MatPtrT<BF16 > k (" k_view" , Extents2D (seq_len, qkv_dim));
239- k.SetPtr (kv_cache.Row (0 ) + kv_head_offset,
240- kv_cache.Stride ());
241- MatPtrT<BF16 > v (" v_view" , Extents2D (seq_len, qkv_dim));
242- v.SetPtr (kv_cache.Row (0 ) + kv_head_offset + qkv_dim,
243- kv_cache.Stride ());
244-
245- SingleDotSoftmaxWeightedSum (pos, start_pos, last_pos, q,
246- k, v, layer_idx, layer,
247- activations, att, att_out);
248- });
249- });
250- });
197+ const size_t qi = div_qbatch.Remainder (tq_idx);
198+ const size_t batch_idx = div_qbatch.Divide (tq_idx);
199+ auto & kv_cache = qbatch.KV (qi).kv_cache ;
200+
201+ // Find the token position in the query and calculate
202+ // the range of cache positions to attend to.
203+ const size_t pos = qbatch.Pos (qi) + batch_idx;
204+ const size_t start_pos = StartPos (pos, activations.config , layer_idx);
205+ size_t last_pos = pos;
206+ const size_t prefix_end = qbatch.PrefixEnd (qi);
207+ if (prefix_end > 0 && prefix_end - 1 > last_pos) {
208+ // last_pos in QDotK and WeightedSumV is inclusive.
209+ last_pos = prefix_end - 1 ;
210+ }
211+
212+ float * HWY_RESTRICT q = activations.q .Row (tq_idx) + head * qkv_dim;
213+ float * HWY_RESTRICT att = activations.att .Row (tq_idx) + head * seq_len;
214+ float * HWY_RESTRICT att_out =
215+ activations.att_out .Row (tq_idx) + head * qkv_dim;
216+
217+ // Make strided read-only views into the kv cache for
218+ // this query and head.
219+ const size_t head_offset = (head / kHeadGroups ) * qkv_dim * 2 ;
220+ const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
221+ MatPtrT<BF16 > k (" k_view" , Extents2D (seq_len, qkv_dim));
222+ k.SetPtr (kv_cache.Row (0 ) + kv_head_offset, kv_cache.Stride ());
223+ MatPtrT<BF16 > v (" v_view" , Extents2D (seq_len, qkv_dim));
224+ v.SetPtr (kv_cache.Row (0 ) + kv_head_offset + qkv_dim, kv_cache.Stride ());
225+
226+ SingleDotSoftmaxWeightedSum (pos, start_pos, last_pos, q, k, v, layer_idx,
227+ layer, activations, att, att_out, worker);
228+ };
229+
230+ ParallelFor (num_tokens * div_qbatch.GetDivisor () * layer_config.heads , pools,
231+ /* pkg_idx=*/ 0 , func);
251232}
252233
253234// Different functions use different naming conventions for the number of
@@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
286267 const size_t batch_idx = div_qbatch.Divide (interleaved_idx);
287268 const size_t cache_pos =
288269 activations.div_seq_len .Remainder (qbatch.Pos (qi) + batch_idx);
289- env.row_ptrs [0 ][interleaved_idx] = reinterpret_cast <uint8_t *>(
270+ env.row_ptrs [2 ][interleaved_idx] = reinterpret_cast <uint8_t *>(
290271 qbatch.KV (qi).kv_cache .Row (cache_pos) + layer_idx * cache_layer_size);
291272 }
292- kv_rows.AttachRowPtrs (env.row_ptrs [0 ].get ());
273+ kv_rows.AttachRowPtrs (env.row_ptrs [2 ].get ());
293274 CallMatMul (activations.pre_att_rms_out , layer.qkv_einsum_w2 ,
294275 /* add=*/ nullptr , env, kv_rows);
295276
@@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
298279 // tasks are very lightweight.
299280 env.ctx .pools .Pool (0 ).Run (
300281 0 , kv_heads * num_interleaved,
301- [&](uint64_t task, size_t /* thread*/ ) HWY_ATTR {
282+ [&](uint64_t task, size_t thread) HWY_ATTR {
302283 const size_t head = task % kv_heads;
303284 const size_t interleaved_idx = task / kv_heads;
304285 const size_t qi = div_qbatch.Remainder (interleaved_idx);
@@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
318299 // Apply further processing to K.
319300 if (layer.key_norm_scale .HasPtr ()) {
320301 CallUpcasted (&layer.key_norm_scale , [&](const auto * weights_t ) {
321- RMSNormInplace (weights_t ->PackedScale1 (), 0 , kv_f32, qkv_dim);
302+ RMSNormInplace (weights_t ->PackedScale1 (), 0 , kv_f32, qkv_dim,
303+ thread);
322304 });
323305 }
324306
325- PositionalEncodingQK (kv_f32, layer_idx, layer, activations, pos);
307+ PositionalEncodingQK (kv_f32, layer_idx, layer, activations, thread,
308+ pos);
326309 CompressPerThread tls;
327310 Compress (kv_f32, 2 * qkv_dim, tls, MakeSpan (kv, 2 * qkv_dim), 0 );
328311 });
0 commit comments