1919#include " infini_train/include/tensor.h"
2020
2121namespace infini_train ::nn {
22- namespace {
23- int64_t DefaultQKVHeadDim (const TransformerConfig &config) {
24- CHECK_EQ (config.n_embd % config.n_head , 0 ) << " n_embd must be divisible by n_head" ;
25- return config.n_embd / config.n_head ;
26- }
27-
28- int64_t DefaultQKRoPEHeadDim (const TransformerConfig &config) {
29- return DefaultQKVHeadDim (config);
30- }
3122
32- int64_t DefaultQKNoPEHeadDim (const TransformerConfig &config) {
33- return DefaultQKVHeadDim (config);
34- }
35- } // namespace
36-
37- MLASelfAttention::MLASelfAttention (const TransformerConfig &config)
38- : MLASelfAttention(config,
39- /* q_lora_rank=*/ config.n_embd,
40- /* kv_lora_rank=*/ config.n_embd,
41- /* qk_nope_head_dim=*/ DefaultQKNoPEHeadDim(config),
42- /* qk_rope_head_dim=*/ DefaultQKRoPEHeadDim(config),
43- /* v_head_dim=*/ DefaultQKVHeadDim(config)) {}
44-
45- MLASelfAttention::MLASelfAttention (const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
46- int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
47- bool q_down_proj_use_tp, bool kv_down_proj_use_tp)
48- : CloneableModule(kType ), config_(config) {
49- SetupAttention (config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
50- q_down_proj_use_tp, kv_down_proj_use_tp);
23+ MLASelfAttention::MLASelfAttention (const TransformerConfig &config) : CloneableModule(kType ), config_(config) {
24+ SetupAttention (config);
5125
5226 if (use_q_lora_) {
5327 if (q_down_proj_use_tp_) {
@@ -123,15 +97,19 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
12397 ->View ({1 , 1 , config_.block_size , config_.block_size });
12498}
12599
126- void MLASelfAttention::SetupAttention (const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
127- int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
128- bool q_down_proj_use_tp, bool kv_down_proj_use_tp) {
100+ void MLASelfAttention::SetupAttention (const TransformerConfig &config) {
129101 auto tp_world_size = nn::parallel::global::GetTensorParallelSize ();
130102
131103 CHECK_EQ (config.n_embd % config.n_head , 0 ) << " n_embd must be divisible by n_head" ;
132104 CHECK_EQ (config.n_head % tp_world_size, 0 ) << " n_head must be divisible by TP world size" ;
133- CHECK (q_lora_rank == -1 || q_lora_rank > 0 ) << " q_lora_rank must be positive, or -1 to disable q LoRA" ;
134- CHECK_GT (kv_lora_rank, 0 ) << " kv_lora_rank must be positive" ;
105+ CHECK (!config.q_lora_rank .has_value () || config.q_lora_rank .value () > 0 ) << " q_lora_rank must be positive when set" ;
106+
107+ const auto default_head_dim = config.n_embd / config.n_head ;
108+ const int64_t kv_lora_rank = config.kv_lora_rank > 0 ? config.kv_lora_rank : config.n_embd ;
109+ const int64_t qk_nope_head_dim = config.qk_nope_head_dim > 0 ? config.qk_nope_head_dim : default_head_dim;
110+ const int64_t qk_rope_head_dim = config.qk_rope_head_dim > 0 ? config.qk_rope_head_dim : default_head_dim;
111+ const int64_t v_head_dim = config.v_head_dim > 0 ? config.v_head_dim : default_head_dim;
112+
135113 CHECK_GT (qk_nope_head_dim, 0 ) << " qk_nope_head_dim must be positive" ;
136114 CHECK_GT (qk_rope_head_dim, 0 ) << " qk_rope_head_dim must be positive" ;
137115 CHECK_GT (v_head_dim, 0 ) << " v_head_dim must be positive" ;
@@ -141,15 +119,15 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q
141119 n_embd_ = config.n_embd ;
142120 local_n_head_ = n_head_ / tp_world_size;
143121
144- use_q_lora_ = q_lora_rank != - 1 ;
145- q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0 ;
122+ use_q_lora_ = config. q_lora_rank . has_value () ;
123+ q_lora_rank_ = config. q_lora_rank . value_or ( 0 ) ;
146124 kv_lora_rank_ = kv_lora_rank;
147125 qk_nope_head_dim_ = qk_nope_head_dim;
148126 qk_rope_head_dim_ = qk_rope_head_dim;
149127 qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_;
150128 v_head_dim_ = v_head_dim;
151- q_down_proj_use_tp_ = q_down_proj_use_tp;
152- kv_down_proj_use_tp_ = kv_down_proj_use_tp;
129+ q_down_proj_use_tp_ = config. q_down_proj_use_tp ;
130+ kv_down_proj_use_tp_ = config. kv_down_proj_use_tp ;
153131}
154132
155133std::vector<std::shared_ptr<infini_train::Tensor>>
@@ -173,7 +151,7 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
173151
174152 // ----------- Q PATH -----------
175153 // Q path, align with Megatron:
176- // - q_lora_rank == -1 -> linear_q_proj directly;
154+ // - q_lora_rank == nullopt -> linear_q_proj directly;
177155 // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj.
178156 std::shared_ptr<Tensor> q;
179157 if (use_q_lora_) {
@@ -224,8 +202,8 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
224202 // compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope)
225203 auto compressed_kv = compressed_kv_with_pe->Slice (-1 , 0 , kv_lora_rank_);
226204 auto k_pos_emb = compressed_kv_with_pe->Slice (-1 , kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous ();
227- const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded
228- && sequence_parallel_enabled;
205+ const bool k_pos_emb_has_full_sequence
206+ = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded && sequence_parallel_enabled;
229207 if (k_pos_emb_has_full_sequence) {
230208 // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj.
231209 // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv)
@@ -285,7 +263,7 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
285263 y = y->Transpose (1 , 2 )->Contiguous ()->View ({B, T, local_n_head_ * v_head_dim_});
286264 // linear_proj: (B, T, H_local * D_v) -> (B, T, C)
287265 y = (*modules_[kLinearProjLayerName ])({y})[0 ];
288-
266+
289267 return {y};
290268}
291269
0 commit comments