Skip to content

Commit bea8b1c

Browse files
theraysmithcopybara-github
authored andcommitted
Replaced attention in ViT with flash - 8x speedup of image tokenizer on AMD
PiperOrigin-RevId: 880877209
1 parent 029cfd0 commit bea8b1c

9 files changed

Lines changed: 304 additions & 149 deletions

File tree

BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ cc_library(
555555
":ops",
556556
":tensor_stats",
557557
":threading_context",
558+
"@highway//:abort_header_only",
558559
],
559560
)
560561

@@ -678,6 +679,7 @@ cc_library(
678679
":attention",
679680
":basics",
680681
":configs",
682+
":flash_structs",
681683
":gemma_args",
682684
":kv_cache",
683685
":mat",

gemma/activations.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,16 @@ struct AttentionActivations {
7676
: batch_size * layer_config.heads,
7777
allocator)),
7878
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
79-
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
80-
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
79+
vit_K_T(MatFactory(
80+
"K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
81+
layer_config.heads *
82+
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
83+
allocator, MatPadding::kPacked)),
84+
vit_V_T(MatFactory(
85+
"V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
86+
layer_config.heads *
87+
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
88+
allocator, MatPadding::kPacked)),
8189
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
8290
config.model_dim, allocator)),
8391
// att is only valid for AttentionImpl::kOld.
@@ -126,7 +134,6 @@ struct AttentionActivations {
126134
q.AllocateAndAttachRowPtrs(row_ptrs);
127135
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
128136
q_T.AllocateAndAttachRowPtrs(row_ptrs);
129-
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
130137
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
131138
}
132139

@@ -136,8 +143,7 @@ struct AttentionActivations {
136143
// q_T rows are always qkv_dim!
137144

138145
vit_Q.OverrideRows(batch_size);
139-
// vit_K stays seq_len!
140-
vit_C.OverrideRows(batch_size);
146+
// vit_K_T and vit_V_T stay seq_len!
141147

142148
pre_att_rms_out.OverrideRows(batch_size);
143149
att.OverrideRows(batch_size);
@@ -167,8 +173,8 @@ struct AttentionActivations {
167173
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
168174

169175
MatStorageT<float> vit_Q;
170-
MatStorageT<float> vit_K;
171-
MatStorageT<float> vit_C;
176+
MatStorageT<KV_t> vit_K_T;
177+
MatStorageT<KV_t> vit_V_T;
172178

173179
MatStorageT<float> pre_att_rms_out;
174180
MatStorageT<float> att; // attention vector
@@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
214220
q_bf = activations.q_bf;
215221
q_T = activations.q_T;
216222
vit_Q = activations.vit_Q;
217-
vit_K = activations.vit_K;
218-
vit_C = activations.vit_C;
223+
vit_K_T = activations.vit_K_T;
224+
vit_V_T = activations.vit_V_T;
219225
pre_att_rms_out = activations.pre_att_rms_out;
220226
att = activations.att;
221227
att_out = activations.att_out;
@@ -233,8 +239,7 @@ struct AttentionActivationsPtrs {
233239
// q_T rows are always qkv_dim!
234240

235241
vit_Q.OverrideRows(batch_size);
236-
// vit_K stays seq_len!
237-
vit_C.OverrideRows(batch_size);
242+
// vit_K_T and vit_V_T stay seq_len!
238243

239244
pre_att_rms_out.OverrideRows(batch_size);
240245
att.OverrideRows(batch_size);
@@ -267,8 +272,8 @@ struct AttentionActivationsPtrs {
267272
MatPtrT<BF16> q_T;
268273

269274
MatPtrT<float> vit_Q;
270-
MatPtrT<float> vit_K;
271-
MatPtrT<float> vit_C;
275+
MatPtrT<KV_t> vit_K_T;
276+
MatPtrT<KV_t> vit_V_T;
272277

273278
// Output of RMSNorm before attention, size batch_size x model_dim.
274279
MatPtrT<float> pre_att_rms_out;

gemma/flash_attention.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
22602260
} // namespace HWY_NAMESPACE
22612261
} // namespace gcpp
22622262
HWY_AFTER_NAMESPACE();
2263+
2264+
#if HWY_ONCE
2265+
namespace gcpp {
2266+
HWY_EXPORT(DispatchTileFlashAttention148);
2267+
2268+
void DispatchDispatchTileFlashAttention148(
2269+
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
2270+
const MatPtrT<KV_t>& v, const size_t layer_idx,
2271+
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
2272+
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
2273+
AttentionImpl attention_impl) {
2274+
HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)(
2275+
params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
2276+
attention_impl);
2277+
}
2278+
2279+
} // namespace gcpp
2280+
#endif

gemma/flash_attention.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,6 @@ namespace gcpp {
4242
const MatPtr& query_norm_scale, size_t layer_idx, \
4343
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
4444
\
45-
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
46-
const BF16* HWY_RESTRICT q, \
47-
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
48-
size_t layer_idx, \
49-
const AttentionActivationsPtrs& activations, \
50-
float* HWY_RESTRICT att_out, \
51-
ThreadingContext& ctx, size_t worker); \
52-
\
5345
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
5446
size_t total_tasks, size_t target_parallelism); \
5547
\
@@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
8375

8476
#undef GEMMA_DECL_FLASH_ATTENTION
8577

78+
void DispatchDispatchTileFlashAttention148(
79+
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
80+
const MatPtrT<KV_t>& v, const size_t layer_idx,
81+
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
82+
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
83+
AttentionImpl attention_impl);
84+
8685
} // namespace gcpp
8786

8887
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_

gemma/tiled_attention_test.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() {
544544
test_env.SetupWeights();
545545
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
546546
FillMatPtrT(test_env.activations->attention.q);
547-
FillMatPtrT(test_env.activations->attention.vit_Q);
548-
FillMatPtrT(test_env.activations->attention.vit_K);
549547
FillMatPtrT(test_env.activations->attention.att);
550548
FillMatPtrT(test_env.activations->attention.att_out);
551549
FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
590588
test_env.SetupWeights();
591589
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
592590
FillMatPtrT(test_env.activations->attention.q);
593-
FillMatPtrT(test_env.activations->attention.vit_Q);
594-
FillMatPtrT(test_env.activations->attention.vit_K);
595591
FillMatPtrT(test_env.activations->attention.att);
596592
FillMatPtrT(test_env.activations->attention.att_out);
597593
FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() {
763759
test_env.SetupWeights();
764760
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
765761
FillMatPtrT(test_env.activations->attention.q);
766-
FillMatPtrT(test_env.activations->attention.vit_Q);
767-
FillMatPtrT(test_env.activations->attention.vit_K);
768762
FillMatPtrT(test_env.activations->attention.att);
769763
FillMatPtrT(test_env.activations->attention.att_out);
770764
FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() {
807801
test_env.SetupWeights();
808802
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
809803
FillMatPtrT(test_env.activations->attention.q);
810-
FillMatPtrT(test_env.activations->attention.vit_Q);
811-
FillMatPtrT(test_env.activations->attention.vit_K);
812804
FillMatPtrT(test_env.activations->attention.att);
813805
FillMatPtrT(test_env.activations->attention.att_out);
814806
FillMatPtrT(test_env.activations->attention.softmax_max);

0 commit comments

Comments
 (0)