@@ -43,30 +43,65 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config)
4343 /* v_head_dim=*/ DefaultQKVHeadDim(config)) {}
4444
4545MLASelfAttention::MLASelfAttention (const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
46- int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim)
46+ int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
47+ bool q_down_proj_use_tp, bool kv_down_proj_use_tp)
4748 : CloneableModule(kType ), config_(config) {
48- SetupAttention (config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim);
49-
50- modules_[kQAProjLayerName ] = std::make_shared<nn::Linear>(
51- /* in_features=*/ n_embd_,
52- /* out_features=*/ q_lora_rank_,
53- /* bias=*/ config_.add_bias_linear );
54- modules_[kQANormLayerName ] = std::make_shared<nn::RMSNorm>(q_lora_rank_, config_.norm_eps );
55- modules_[kQBProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
56- /* in_features=*/ q_lora_rank_,
57- /* out_features=*/ n_head_ * qk_head_dim_,
58- /* bias=*/ config_.add_bias_linear ,
59- /* gather_output=*/ false ,
60- /* input_is_parallel=*/ false ,
61- /* skip_bias_add=*/ false ,
62- /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
49+ SetupAttention (config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
50+ q_down_proj_use_tp, kv_down_proj_use_tp);
51+
52+ if (use_q_lora_) {
53+ if (q_down_proj_use_tp_) {
54+ modules_[kLinearQDownProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
55+ /* in_features=*/ n_embd_,
56+ /* out_features=*/ q_lora_rank_,
57+ /* bias=*/ config_.add_bias_linear ,
58+ /* gather_output=*/ false ,
59+ /* input_is_parallel=*/ false ,
60+ /* skip_bias_add=*/ false ,
61+ /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
62+ } else {
63+ modules_[kLinearQDownProjLayerName ] = std::make_shared<nn::Linear>(
64+ /* in_features=*/ n_embd_,
65+ /* out_features=*/ q_lora_rank_,
66+ /* bias=*/ config_.add_bias_linear );
67+ }
68+ modules_[kQLayerNormLayerName ] = std::make_shared<nn::RMSNorm>(q_lora_rank_, config_.norm_eps );
69+ modules_[kLinearQUpProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
70+ /* in_features=*/ q_lora_rank_,
71+ /* out_features=*/ n_head_ * qk_head_dim_,
72+ /* bias=*/ config_.add_bias_linear ,
73+ /* gather_output=*/ false ,
74+ /* input_is_parallel=*/ false ,
75+ /* skip_bias_add=*/ false ,
76+ /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
77+ } else {
78+ modules_[kLinearQProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
79+ /* in_features=*/ n_embd_,
80+ /* out_features=*/ n_head_ * qk_head_dim_,
81+ /* bias=*/ config_.add_bias_linear ,
82+ /* gather_output=*/ false ,
83+ /* input_is_parallel=*/ false ,
84+ /* skip_bias_add=*/ false ,
85+ /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
86+ }
6387
64- modules_[kKVAProjLayerName ] = std::make_shared<nn::Linear>(
65- /* in_features=*/ n_embd_,
66- /* out_features=*/ kv_lora_rank_ + qk_rope_head_dim_,
67- /* bias=*/ config_.add_bias_linear );
68- modules_[kKVANormLayerName ] = std::make_shared<nn::RMSNorm>(kv_lora_rank_, config_.norm_eps );
69- modules_[kKVBProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
88+ if (kv_down_proj_use_tp_) {
89+ modules_[kLinearKVDownProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
90+ /* in_features=*/ n_embd_,
91+ /* out_features=*/ kv_lora_rank_ + qk_rope_head_dim_,
92+ /* bias=*/ config_.add_bias_linear ,
93+ /* gather_output=*/ false ,
94+ /* input_is_parallel=*/ false ,
95+ /* skip_bias_add=*/ false ,
96+ /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
97+ } else {
98+ modules_[kLinearKVDownProjLayerName ] = std::make_shared<nn::Linear>(
99+ /* in_features=*/ n_embd_,
100+ /* out_features=*/ kv_lora_rank_ + qk_rope_head_dim_,
101+ /* bias=*/ config_.add_bias_linear );
102+ }
103+ modules_[kKVLayerNormLayerName ] = std::make_shared<nn::RMSNorm>(kv_lora_rank_, config_.norm_eps );
104+ modules_[kLinearKVUpProjLayerName ] = std::make_shared<nn::parallel::ColumnParallelLinear>(
70105 /* in_features=*/ kv_lora_rank_,
71106 /* out_features=*/ n_head_ * (qk_nope_head_dim_ + v_head_dim_),
72107 /* bias=*/ config_.add_bias_linear ,
@@ -75,7 +110,7 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
75110 /* skip_bias_add=*/ false ,
76111 /* sequence_parallel=*/ nn::parallel::global::GetSequenceParallelEnabled ());
77112
78- modules_[kCProjLayerName ] = std::make_shared<nn::parallel::RowParallelLinear>(
113+ modules_[kLinearProjLayerName ] = std::make_shared<nn::parallel::RowParallelLinear>(
79114 /* in_features=*/ n_head_ * v_head_dim_,
80115 /* out_features=*/ n_embd_,
81116 /* bias=*/ config_.add_bias_linear ,
@@ -89,12 +124,13 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
89124}
90125
91126void MLASelfAttention::SetupAttention (const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
92- int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) {
127+ int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
128+ bool q_down_proj_use_tp, bool kv_down_proj_use_tp) {
93129 auto tp_world_size = nn::parallel::global::GetTensorParallelSize ();
94130
95131 CHECK_EQ (config.n_embd % config.n_head , 0 ) << " n_embd must be divisible by n_head" ;
96132 CHECK_EQ (config.n_head % tp_world_size, 0 ) << " n_head must be divisible by TP world size" ;
97- CHECK_GT (q_lora_rank, 0 ) << " q_lora_rank must be positive" ;
133+ CHECK (q_lora_rank == - 1 || q_lora_rank > 0 ) << " q_lora_rank must be positive, or -1 to disable q LoRA " ;
98134 CHECK_GT (kv_lora_rank, 0 ) << " kv_lora_rank must be positive" ;
99135 CHECK_GT (qk_nope_head_dim, 0 ) << " qk_nope_head_dim must be positive" ;
100136 CHECK_GT (qk_rope_head_dim, 0 ) << " qk_rope_head_dim must be positive" ;
@@ -105,80 +141,151 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q
105141 n_embd_ = config.n_embd ;
106142 local_n_head_ = n_head_ / tp_world_size;
107143
108- q_lora_rank_ = q_lora_rank;
144+ use_q_lora_ = q_lora_rank != -1 ;
145+ q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0 ;
109146 kv_lora_rank_ = kv_lora_rank;
110147 qk_nope_head_dim_ = qk_nope_head_dim;
111148 qk_rope_head_dim_ = qk_rope_head_dim;
112149 qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_;
113150 v_head_dim_ = v_head_dim;
151+ q_down_proj_use_tp_ = q_down_proj_use_tp;
152+ kv_down_proj_use_tp_ = kv_down_proj_use_tp;
114153}
115154
// Forward pass of MLA self-attention. x[0] holds the hidden states; x[1] (optional) is freqs_cis and
// x[3] (optional) is an external attention mask. Builds per-head q/k/v via the Q path (direct or
// down-proj -> norm -> up-proj) and the compressed KV path, applies rotary embedding when
// attention_type is kRoPE, runs masked softmax attention and returns {linear_proj(y)}.
116155std::vector<std::shared_ptr<infini_train::Tensor>>
117156MLASelfAttention::Forward (const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
118157 CHECK_GE (x.size (), 1 ) << " MLASelfAttention expects at least hidden states" ;
119158
159+ // x[0]: hidden states, (B, T_local, C); C is checked against n_embd_ below.
120160 const auto B = x[0 ]->Dims ()[0 ];
121161 const auto C = x[0 ]->Dims ()[2 ];
122162 CHECK_EQ (C, n_embd_) << " hidden size must match n_embd" ;
123163
164+ // freqs_cis: (T, D_rope / 2, 2)
124165 const auto freqs_cis = x.size () > 1 ? x[1 ] : nullptr ;
166+ // external_mask: (1, 1, T, T)
125167 const auto external_mask = x.size () > 3 ? x[3 ] : nullptr ;
126168 if (config_.attention_type == AttentionType::kRoPE ) {
127169 CHECK (freqs_cis != nullptr ) << " freqs_cis is null." ;
128170 }
129171
130- // (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope))
131- auto q = (*modules_[kQAProjLayerName ])({x[0 ]})[0 ];
132- q = (*modules_[kQANormLayerName ])({q})[0 ];
133- q = (*modules_[kQBProjLayerName ])({q})[0 ];
172+ const bool sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled ();
173+
174+ // ----------- Q PATH -----------
175+ // Q path, aligned with Megatron:
176+ // - q_lora_rank == -1 -> linear_q_proj directly;
177+ // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj.
178+ std::shared_ptr<Tensor> q;
179+ if (use_q_lora_) {
180+ // linear_q_down_proj:
181+ // non-TP path: (B, T_local, C) -> (B, T_local, R_q)
182+ // TP path before gather: (B, T, C) -> (B, T, R_q / TP)
183+ // - Note: ColumnParallelLinear performs a GatherFromSPRegion at the beginning, hence full T here.
184+ auto q_compressed = (*modules_[kLinearQDownProjLayerName ])({x[0 ]})[0 ];
185+ if (q_down_proj_use_tp_ && q_compressed->Dims ().back () != q_lora_rank_) {
186+ // Gather the sharded latent dimension: (B, T, R_q / TP) -> (B, T, R_q).
187+ q_compressed = nn::parallel::GatherFromTPRegionFunc (q_compressed)[0 ];
188+ if (sequence_parallel_enabled) {
189+ // Keep the q_up input sequence-sharded: (B, T_full, R_q) -> (B, T_local, R_q).
190+ q_compressed = nn::parallel::ScatterToSPRegionFunc (q_compressed)[0 ];
191+ }
192+ }
193+ // q_layernorm preserves shape: (B, T_local, R_q)
194+ q_compressed = (*modules_[kQLayerNormLayerName ])({q_compressed})[0 ];
195+ // linear_q_up_proj: (B, T_local, R_q) -> (B, T, H_local * (D_nope + D_rope)).
196+ q = (*modules_[kLinearQUpProjLayerName ])({q_compressed})[0 ];
197+ } else {
198+ // linear_q_proj direct path: (B, T, C) -> (B, T, H_local * (D_nope + D_rope)).
199+ q = (*modules_[kLinearQProjLayerName ])({x[0 ]})[0 ];
200+ }
201+
202+ // T should be the full seqlen after the q projection path gathers sequence-parallel input.
134203 const auto T = q->Dims ()[1 ];
204+ // q: (B, T, H_local * D_qk) -> (B, T, H_local, D_qk)
205+ // qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_
135206 q = q->View ({B, T, local_n_head_, qk_head_dim_});
136207
208+ // q_nope: (B, T, H_local, D_nope), q_pos_emb: (B, T, H_local, D_rope)
137209 auto q_nope = q->Slice (-1 , 0 , qk_nope_head_dim_);
138- auto q_pe = q->Slice (-1 , qk_nope_head_dim_, qk_head_dim_);
210+ auto q_pos_emb = q->Slice (-1 , qk_nope_head_dim_, qk_head_dim_);
211+
212+ // ----------- KV PATH -----------
213+ // linear_kv_down_proj:
214+ // non-TP path: (B, T_local, C) -> (B, T_local, R_kv + D_rope)
215+ // TP path before gather: (B, T, C) -> (B, T, (R_kv + D_rope) / TP)
216+ auto compressed_kv_with_pe = (*modules_[kLinearKVDownProjLayerName ])({x[0 ]})[0 ];
217+ const auto kv_down_proj_out_dim = kv_lora_rank_ + qk_rope_head_dim_;
218+ const bool kv_down_proj_output_is_sharded = compressed_kv_with_pe->Dims ().back () != kv_down_proj_out_dim;
219+ if (kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded) {
220+ // Gather latent+RoPE dim: (B, T, (R_kv + D_rope) / TP) -> (B, T, R_kv + D_rope)
221+ compressed_kv_with_pe = nn::parallel::GatherFromTPRegionFunc (compressed_kv_with_pe)[0 ];
222+ }
139223
140- // (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key.
141- auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName ])({x[0 ]})[0 ];
224+ // compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope); both carry full T on the gathered TP path.
142225 auto compressed_kv = compressed_kv_with_pe->Slice (-1 , 0 , kv_lora_rank_);
143- auto k_pe = compressed_kv_with_pe->Slice (-1 , kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)
144- ->Contiguous ();
145- if (nn::parallel::global::GetSequenceParallelEnabled ()) {
146- k_pe = nn::parallel::GatherFromSPRegionFunc (k_pe)[0 ];
226+ auto k_pos_emb = compressed_kv_with_pe->Slice (-1 , kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous ();
227+ const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded
228+ && sequence_parallel_enabled;
229+ if (k_pos_emb_has_full_sequence) {
230+ // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj.
231+ // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv)
232+ compressed_kv = nn::parallel::ScatterToSPRegionFunc (compressed_kv)[0 ];
233+ } else if (sequence_parallel_enabled) {
234+ // Replicated down-proj path produces local k_pos_emb; gather it for attention.
235+ // k_pos_emb: (B, T_local, D_rope) -> (B, T, D_rope)
236+ k_pos_emb = nn::parallel::GatherFromSPRegionFunc (k_pos_emb)[0 ];
147237 }
148- k_pe = k_pe->View ({B, T, 1 , qk_rope_head_dim_});
238+ // k_pos_emb: (B, T, D_rope) -> (B, T, 1, D_rope), shared across local heads.
239+ k_pos_emb = k_pos_emb->View ({B, T, 1 , qk_rope_head_dim_});
149240
150- // (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v))
151- auto kv = (*modules_[kKVANormLayerName ])({compressed_kv})[0 ];
152- kv = (*modules_[kKVBProjLayerName ])({kv})[0 ];
241+ // (B, T_local, R_kv) -> kv_layernorm -> linear_kv_up_proj -> (B, T, H_local * (D_nope + D_v))
242+ // kv_layernorm preserves compressed_kv shape: (B, T_local, R_kv)
243+ auto kv = (*modules_[kKVLayerNormLayerName ])({compressed_kv})[0 ];
244+ // linear_kv_up_proj: (B, T_local, R_kv) -> (B, T, H_local * (D_nope + D_v))
245+ kv = (*modules_[kLinearKVUpProjLayerName ])({kv})[0 ];
246+ // kv: (B, T, H_local * (D_nope + D_v)) -> (B, T, H_local, D_nope + D_v)
153247 kv = kv->View ({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_});
248+ // k_nope: (B, T, H_local, D_nope), v: (B, T, H_local, D_v)
154249 auto k_nope = kv->Slice (-1 , 0 , qk_nope_head_dim_);
155250 auto v = kv->Slice (-1 , qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_);
156251
157252 if (config_.attention_type == AttentionType::kRoPE ) {
158- std::tie (q_pe, k_pe) = ApplyRotaryEmbedding (q_pe, k_pe, freqs_cis);
253+ // q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope)
254+ std::tie (q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding (q_pos_emb, k_pos_emb, freqs_cis);
159255 }
160256
161- k_pe = k_pe->RepeatInterleave (local_n_head_, 2 );
162- q = nn::function::Concat (std::vector<std::shared_ptr<Tensor>>{q_nope, q_pe}, -1 );
163- auto k = nn::function::Concat (std::vector<std::shared_ptr<Tensor>>{k_nope, k_pe}, -1 );
257+ // k_pos_emb: (B, T, 1, D_rope) -> (B, T, H_local, D_rope)
258+ k_pos_emb = k_pos_emb->RepeatInterleave (local_n_head_, 2 );
259+ // q: (B, T, H_local, D_qk), k: (B, T, H_local, D_qk)
260+ q = nn::function::Concat (std::vector<std::shared_ptr<Tensor>>{q_nope, q_pos_emb}, -1 );
261+ auto k = nn::function::Concat (std::vector<std::shared_ptr<Tensor>>{k_nope, k_pos_emb}, -1 );
164262
165- // (B, T, H_local, D) -> (B, H_local, T, D)
263+ // ----------- CORE ATTN -----------
264+ // q/k: (B, T, H_local, D_qk) -> (B, H_local, T, D_qk)
265+ // v: (B, T, H_local, D_v) -> (B, H_local, T, D_v)
166266 q = q->Transpose (1 , 2 );
167267 k = k->Transpose (1 , 2 );
168268 v = v->Transpose (1 , 2 );
169269
270+ // att: (B, H_local, T, T)
170271 auto att = q->Matmul (k->Transpose (-2 , -1 )) * (1.0 / std::sqrt (static_cast <float >(qk_head_dim_)));
171272 if (external_mask) {
172273 att = att->MaskedFill (external_mask, std::numeric_limits<float >::lowest ());
173274 } else {
275+ // mask: (1, 1, T, T) slice of the kParamBiasName buffer; positions where mask == 0 are filled with -inf.
174276 auto mask = buffers_[kParamBiasName ]->Slice ({0 , 0 , 0 , 0 }, {1 , 1 , T, T}, {1 , 1 , 1 , 1 });
175277 att = att->MaskedFill (mask == 0 , -std::numeric_limits<float >::infinity ());
176278 }
279+ // softmax over the last dim; att stays (B, H_local, T, T)
177280 att = nn::function::Softmax (att, -1 );
178281
282+ // y: (B, H_local, T, D_v)
179283 auto y = att->Matmul (v);
284+ // y: (B, H_local, T, D_v) -> (B, T, H_local, D_v) -> (B, T, H_local * D_v)
180285 y = y->Transpose (1 , 2 )->Contiguous ()->View ({B, T, local_n_head_ * v_head_dim_});
181- y = (*modules_[kCProjLayerName ])({y})[0 ];
286+ // linear_proj: (B, T, H_local * D_v) -> (B, T, C)
287+ y = (*modules_[kLinearProjLayerName ])({y})[0 ];
288+
182289 return {y};
183290}
184291
0 commit comments