Skip to content

Commit 267dbe0

Browse files
pculliton and copybara-github
authored and committed
Fixes to activations and tensor params
PiperOrigin-RevId: 824820179
1 parent 5a05857 commit 267dbe0

5 files changed

Lines changed: 14 additions & 10 deletions

File tree

gemma/activations.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,12 @@ struct Activations {
203203
ffw_out.OverrideRows(batch_size);
204204

205205
attention_storage.SetBatchSize(batch_size);
206+
attention.q = attention_storage.q;
207+
attention.q_T = attention_storage.q_T;
208+
attention.pre_att_rms_out = attention_storage.pre_att_rms_out;
209+
attention.att = attention_storage.att;
210+
attention.att_out = attention_storage.att_out;
211+
attention.att_sums = attention_storage.att_sums;
206212
}
207213

208214
const LayerConfig& layer_config;

gemma/attention.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
130130
void SingleDotSoftmaxWeightedSum(
131131
const size_t pos, const size_t start_pos, const size_t last_pos,
132132
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
133-
const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
133+
const MatPtr& query_norm_scale, const size_t layer_idx,
134134
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
135135
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
136136
const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
169169
}
170170

171171
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
172-
const MatPtrT<float>& query_norm_scale,
172+
const MatPtr& query_norm_scale,
173173
AttentionActivationsPtrs& activations,
174174
QBatch& qbatch, ThreadingContext& ctx) {
175175
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);

gemma/attention.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,12 +38,12 @@ namespace gcpp {
3838
void SingleDotSoftmaxWeightedSum( \
3939
const size_t pos, const size_t start_pos, const size_t last_pos, \
4040
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
41-
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
41+
const MatPtr& query_norm_scale, size_t layer_idx, \
4242
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
4343
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
4444
\
4545
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
46-
const MatPtrT<float>& query_norm_scale, \
46+
const MatPtr& query_norm_scale, \
4747
AttentionActivationsPtrs& activations, \
4848
QBatch& qbatch, ThreadingContext& ctx); \
4949
\

gemma/flash_attention.cc

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
9191
// Updates q in place for RMSNorm and positional encoding.
9292
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
9393
MatPtrT<float>& q,
94-
const MatPtrT<float>& query_norm_scale,
94+
const MatPtr& query_norm_scale,
9595
const size_t layer_idx,
9696
const AttentionActivationsPtrs& activations,
9797
ThreadingContext& ctx) {
@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
592592
// grouped together so that mode 1 or 2 can be used, and choosing which of the
593593
// 3 modes to use for best efficiency.
594594
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
595-
const size_t layer_idx,
596-
const MatPtrT<float>& query_norm_scale,
595+
const size_t layer_idx, const MatPtr& query_norm_scale,
597596
AttentionActivationsPtrs& activations, QBatch& qbatch,
598597
ThreadingContext& ctx) {
599598
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);

gemma/flash_attention.h

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ namespace gcpp {
3030
namespace NAMESPACE { \
3131
void RMSNormAndPositionalEncoding( \
3232
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
33-
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
33+
const MatPtr& query_norm_scale, size_t layer_idx, \
3434
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
3535
\
3636
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
@@ -45,8 +45,7 @@ namespace gcpp {
4545
size_t total_tasks, size_t target_parallelism); \
4646
\
4747
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
48-
size_t layer_idx, \
49-
const MatPtrT<float>& query_norm_scale, \
48+
size_t layer_idx, const MatPtr& query_norm_scale, \
5049
AttentionActivationsPtrs& activations, QBatch& qbatch, \
5150
ThreadingContext& ctx); \
5251
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \

0 commit comments

Comments
 (0)