@@ -76,8 +76,16 @@ struct AttentionActivations {
7676 : batch_size * layer_config.heads,
7777 allocator)),
7878 vit_Q(MatFactory(" Q2" , batch_size, layer_config.qkv_dim, allocator)),
79- vit_K(MatFactory(" K2" , seq_len, layer_config.qkv_dim, allocator)),
80- vit_C(MatFactory(" C2" , batch_size, seq_len, allocator)),
79+ vit_K_T(MatFactory(
80+ " K2_T" , hwy::RoundUpTo(seq_len, kMaxBF16PerVector ),
81+ layer_config.heads *
82+ hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector ),
83+ allocator, MatPadding::kPacked)),
84+ vit_V_T(MatFactory(
85+ " V2_T" , hwy::RoundUpTo(seq_len, kMaxBF16PerVector ),
86+ layer_config.heads *
87+ hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector ),
88+ allocator, MatPadding::kPacked)),
8189 pre_att_rms_out(MatFactory(" pre_att_rms_out" , batch_size,
8290 config.model_dim, allocator)),
8391 // att is only valid for AttentionImpl::kOld.
@@ -126,7 +134,6 @@ struct AttentionActivations {
126134 q.AllocateAndAttachRowPtrs (row_ptrs);
127135 q_bf.AllocateAndAttachRowPtrs (row_ptrs);
128136 q_T.AllocateAndAttachRowPtrs (row_ptrs);
129- vit_C.AllocateAndAttachRowPtrs (row_ptrs);
130137 att_sums.AllocateAndAttachRowPtrs (row_ptrs);
131138 }
132139
@@ -136,8 +143,7 @@ struct AttentionActivations {
136143 // q_T rows are always qkv_dim!
137144
138145 vit_Q.OverrideRows (batch_size);
139- // vit_K stays seq_len!
140- vit_C.OverrideRows (batch_size);
146+ // vit_K_T and vit_V_T stay seq_len!
141147
142148 pre_att_rms_out.OverrideRows (batch_size);
143149 att.OverrideRows (batch_size);
@@ -167,8 +173,8 @@ struct AttentionActivations {
167173 MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
168174
169175 MatStorageT<float > vit_Q;
170- MatStorageT<float > vit_K ;
171- MatStorageT<float > vit_C ;
176+ MatStorageT<KV_t> vit_K_T ;
177+ MatStorageT<KV_t> vit_V_T ;
172178
173179 MatStorageT<float > pre_att_rms_out;
174180 MatStorageT<float > att; // attention vector
@@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
214220 q_bf = activations.q_bf ;
215221 q_T = activations.q_T ;
216222 vit_Q = activations.vit_Q ;
217- vit_K = activations.vit_K ;
218- vit_C = activations.vit_C ;
223+ vit_K_T = activations.vit_K_T ;
224+ vit_V_T = activations.vit_V_T ;
219225 pre_att_rms_out = activations.pre_att_rms_out ;
220226 att = activations.att ;
221227 att_out = activations.att_out ;
@@ -233,8 +239,7 @@ struct AttentionActivationsPtrs {
233239 // q_T rows are always qkv_dim!
234240
235241 vit_Q.OverrideRows (batch_size);
236- // vit_K stays seq_len!
237- vit_C.OverrideRows (batch_size);
242+ // vit_K_T and vit_V_T stay seq_len!
238243
239244 pre_att_rms_out.OverrideRows (batch_size);
240245 att.OverrideRows (batch_size);
@@ -267,8 +272,8 @@ struct AttentionActivationsPtrs {
267272 MatPtrT<BF16> q_T;
268273
269274 MatPtrT<float > vit_Q;
270- MatPtrT<float > vit_K ;
271- MatPtrT<float > vit_C ;
275+ MatPtrT<KV_t> vit_K_T ;
276+ MatPtrT<KV_t> vit_V_T ;
272277
273278 // Output of RMSNorm before attention, size batch_size x model_dim.
274279 MatPtrT<float > pre_att_rms_out;
0 commit comments