Skip to content

Commit dd18b35

Browse files
fix: move mla args into TransformerConfig
1 parent 87ca357 commit dd18b35

5 files changed

Lines changed: 58 additions & 76 deletions

File tree

infini_train/include/nn/modules/transformer/mla_self_attention.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttenti
2424
static constexpr char kParamBiasName[] = "bias";
2525

2626
explicit MLASelfAttention(const TransformerConfig &config);
27-
MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
28-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
29-
bool q_down_proj_use_tp = false, bool kv_down_proj_use_tp = false);
3027

3128
std::vector<std::shared_ptr<infini_train::Tensor>>
3229
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
@@ -48,9 +45,7 @@ class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttenti
4845
bool q_down_proj_use_tp_ = false;
4946
bool kv_down_proj_use_tp_ = false;
5047

51-
void SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
52-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
53-
bool q_down_proj_use_tp, bool kv_down_proj_use_tp);
48+
void SetupAttention(const TransformerConfig &config);
5449
};
5550

5651
} // namespace infini_train::nn

infini_train/include/nn/modules/transformer/transformer_config.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ struct TransformerConfig {
5353
float rope_theta = 500000.0f; // theta in RoPE
5454
bool use_scaled_rope = false; // scaled RoPE
5555

56+
// MLA config
57+
bool multi_latent_attention = false; // Use MLA instead of standard causal self-attention.
58+
std::optional<int64_t> q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path.
59+
int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention.
60+
int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head.
61+
int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head.
62+
int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head.
63+
bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj.
64+
bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj.
65+
5666
// Normalization
5767
float norm_eps = 1e-5f; // epsilon in RMSNorm
5868

infini_train/src/nn/modules/transformer/mla_self_attention.cc

Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,9 @@
1919
#include "infini_train/include/tensor.h"
2020

2121
namespace infini_train::nn {
22-
namespace {
23-
int64_t DefaultQKVHeadDim(const TransformerConfig &config) {
24-
CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head";
25-
return config.n_embd / config.n_head;
26-
}
27-
28-
int64_t DefaultQKRoPEHeadDim(const TransformerConfig &config) {
29-
return DefaultQKVHeadDim(config);
30-
}
3122

32-
int64_t DefaultQKNoPEHeadDim(const TransformerConfig &config) {
33-
return DefaultQKVHeadDim(config);
34-
}
35-
} // namespace
36-
37-
MLASelfAttention::MLASelfAttention(const TransformerConfig &config)
38-
: MLASelfAttention(config,
39-
/*q_lora_rank=*/config.n_embd,
40-
/*kv_lora_rank=*/config.n_embd,
41-
/*qk_nope_head_dim=*/DefaultQKNoPEHeadDim(config),
42-
/*qk_rope_head_dim=*/DefaultQKRoPEHeadDim(config),
43-
/*v_head_dim=*/DefaultQKVHeadDim(config)) {}
44-
45-
MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
46-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
47-
bool q_down_proj_use_tp, bool kv_down_proj_use_tp)
48-
: CloneableModule(kType), config_(config) {
49-
SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
50-
q_down_proj_use_tp, kv_down_proj_use_tp);
23+
MLASelfAttention::MLASelfAttention(const TransformerConfig &config) : CloneableModule(kType), config_(config) {
24+
SetupAttention(config);
5125

5226
if (use_q_lora_) {
5327
if (q_down_proj_use_tp_) {
@@ -123,15 +97,19 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
12397
->View({1, 1, config_.block_size, config_.block_size});
12498
}
12599

126-
void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
127-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
128-
bool q_down_proj_use_tp, bool kv_down_proj_use_tp) {
100+
void MLASelfAttention::SetupAttention(const TransformerConfig &config) {
129101
auto tp_world_size = nn::parallel::global::GetTensorParallelSize();
130102

131103
CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head";
132104
CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size";
133-
CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA";
134-
CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive";
105+
CHECK(!config.q_lora_rank.has_value() || config.q_lora_rank.value() > 0) << "q_lora_rank must be positive when set";
106+
107+
const auto default_head_dim = config.n_embd / config.n_head;
108+
const int64_t kv_lora_rank = config.kv_lora_rank > 0 ? config.kv_lora_rank : config.n_embd;
109+
const int64_t qk_nope_head_dim = config.qk_nope_head_dim > 0 ? config.qk_nope_head_dim : default_head_dim;
110+
const int64_t qk_rope_head_dim = config.qk_rope_head_dim > 0 ? config.qk_rope_head_dim : default_head_dim;
111+
const int64_t v_head_dim = config.v_head_dim > 0 ? config.v_head_dim : default_head_dim;
112+
135113
CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive";
136114
CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive";
137115
CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive";
@@ -141,15 +119,15 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q
141119
n_embd_ = config.n_embd;
142120
local_n_head_ = n_head_ / tp_world_size;
143121

144-
use_q_lora_ = q_lora_rank != -1;
145-
q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0;
122+
use_q_lora_ = config.q_lora_rank.has_value();
123+
q_lora_rank_ = config.q_lora_rank.value_or(0);
146124
kv_lora_rank_ = kv_lora_rank;
147125
qk_nope_head_dim_ = qk_nope_head_dim;
148126
qk_rope_head_dim_ = qk_rope_head_dim;
149127
qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_;
150128
v_head_dim_ = v_head_dim;
151-
q_down_proj_use_tp_ = q_down_proj_use_tp;
152-
kv_down_proj_use_tp_ = kv_down_proj_use_tp;
129+
q_down_proj_use_tp_ = config.q_down_proj_use_tp;
130+
kv_down_proj_use_tp_ = config.kv_down_proj_use_tp;
153131
}
154132

155133
std::vector<std::shared_ptr<infini_train::Tensor>>
@@ -173,7 +151,7 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
173151

174152
// ----------- Q PATH -----------
175153
// Q path, align with Megatron:
176-
// - q_lora_rank == -1 -> linear_q_proj directly;
154+
// - q_lora_rank == nullopt -> linear_q_proj directly;
177155
// - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj.
178156
std::shared_ptr<Tensor> q;
179157
if (use_q_lora_) {
@@ -224,8 +202,8 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
224202
// compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope)
225203
auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_);
226204
auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous();
227-
const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded
228-
&& sequence_parallel_enabled;
205+
const bool k_pos_emb_has_full_sequence
206+
= kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded && sequence_parallel_enabled;
229207
if (k_pos_emb_has_full_sequence) {
230208
// k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj.
231209
// compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv)
@@ -285,7 +263,7 @@ MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor
285263
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_});
286264
// linear_proj: (B, T, H_local * D_v) -> (B, T, C)
287265
y = (*modules_[kLinearProjLayerName])({y})[0];
288-
266+
289267
return {y};
290268
}
291269

infini_train/src/nn/modules/transformer/transformer.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "infini_train/include/nn/modules/normalization.h"
1515
#include "infini_train/include/nn/modules/sparse.h"
1616
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
17+
#include "infini_train/include/nn/modules/transformer/mla_self_attention.h"
1718
#include "infini_train/include/nn/modules/transformer/mlp.h"
1819
#include "infini_train/include/nn/modules/transformer/utils.h"
1920
#include "infini_train/include/nn/parallel/global.h"
@@ -28,8 +29,8 @@ TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config)
2829
modules_[kWTELayerName] = std::make_shared<parallel::VocabParallelEmbedding>(
2930
config_.vocab_size, config_.n_embd, parallel::global::GetSequenceParallelEnabled());
3031

31-
// LLaMA3 use RoPE, so they don't need position embedding
32-
if (config_.activation_type == MLPType::kGELU) {
32+
// Only standard attention adds the absolute position embedding (WPE); RoPE-based models do not use it.
33+
if (config_.attention_type == AttentionType::kStandard) {
3334
modules_[kWPELayerName] = std::make_shared<Embedding>(config_.block_size, config_.n_embd);
3435
}
3536
}
@@ -85,7 +86,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea
8586
LOG(FATAL) << "Unsupported norm type";
8687
}
8788

88-
modules_[kAttnLayerName] = std::make_shared<CausalSelfAttention>(config);
89+
if (config.multi_latent_attention) {
90+
modules_[kAttnLayerName] = std::make_shared<MLASelfAttention>(config);
91+
} else {
92+
modules_[kAttnLayerName] = std::make_shared<CausalSelfAttention>(config);
93+
}
8994
modules_[kMlpLayerName] = std::make_shared<MLP>(config);
9095
}
9196

@@ -135,8 +140,10 @@ std::vector<std::shared_ptr<Tensor>> TransformerChunk::Forward(const std::vector
135140

136141
// Init freqs_cis on device only once
137142
if (buffers_[kFreqsCisName] == nullptr) {
138-
int64_t head_dim = config_.n_embd / config_.n_head;
139-
buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta,
143+
int64_t rope_head_dim = config_.multi_latent_attention && config_.qk_rope_head_dim > 0
144+
? config_.qk_rope_head_dim
145+
: config_.n_embd / config_.n_head;
146+
buffers_[kFreqsCisName] = PrecomputeFreqsCis(rope_head_dim, config_.block_size * 2, config_.rope_theta,
140147
config_.use_scaled_rope, device);
141148
}
142149

tests/transformer/test_transformer_architecture.cc

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cmath>
22
#include <memory>
3+
#include <optional>
34
#include <vector>
45

56
#include "gtest/gtest.h"
@@ -121,14 +122,14 @@ TEST_P(TransformerModuleTest, MLAAttention) {
121122
config.block_size = 16;
122123
config.attention_type = nn::AttentionType::kStandard;
123124
config.add_bias_linear = true;
124-
125-
auto attn = std::make_shared<nn::MLASelfAttention>(
126-
config,
127-
/*q_lora_rank=*/32,
128-
/*kv_lora_rank=*/32,
129-
/*qk_nope_head_dim=*/8,
130-
/*qk_rope_head_dim=*/8,
131-
/*v_head_dim=*/16);
125+
config.multi_latent_attention = true;
126+
config.q_lora_rank = 32;
127+
config.kv_lora_rank = 32;
128+
config.qk_nope_head_dim = 8;
129+
config.qk_rope_head_dim = 8;
130+
config.v_head_dim = 16;
131+
132+
auto attn = std::make_shared<nn::MLASelfAttention>(config);
132133
attn->To(GetDevice());
133134
EXPECT_FALSE(attn->Parameters().empty());
134135
EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType);
@@ -138,15 +139,10 @@ TEST_P(TransformerModuleTest, MLAAttention) {
138139
auto output = (*attn)({input});
139140
EXPECT_EQ(output[0]->Dims(), input->Dims());
140141

141-
auto tp_down_attn = std::make_shared<nn::MLASelfAttention>(
142-
config,
143-
/*q_lora_rank=*/32,
144-
/*kv_lora_rank=*/32,
145-
/*qk_nope_head_dim=*/8,
146-
/*qk_rope_head_dim=*/8,
147-
/*v_head_dim=*/16,
148-
/*q_down_proj_use_tp=*/true,
149-
/*kv_down_proj_use_tp=*/true);
142+
auto tp_down_config = config;
143+
tp_down_config.q_down_proj_use_tp = true;
144+
tp_down_config.kv_down_proj_use_tp = true;
145+
auto tp_down_attn = std::make_shared<nn::MLASelfAttention>(tp_down_config);
150146
tp_down_attn->To(GetDevice());
151147
EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(),
152148
nn::parallel::ColumnParallelLinear::kType);
@@ -155,13 +151,9 @@ TEST_P(TransformerModuleTest, MLAAttention) {
155151
output = (*tp_down_attn)({input});
156152
EXPECT_EQ(output[0]->Dims(), input->Dims());
157153

158-
auto direct_q_attn = std::make_shared<nn::MLASelfAttention>(
159-
config,
160-
/*q_lora_rank=*/-1,
161-
/*kv_lora_rank=*/32,
162-
/*qk_nope_head_dim=*/8,
163-
/*qk_rope_head_dim=*/8,
164-
/*v_head_dim=*/16);
154+
auto direct_q_config = config;
155+
direct_q_config.q_lora_rank = std::nullopt;
156+
auto direct_q_attn = std::make_shared<nn::MLASelfAttention>(direct_q_config);
165157
direct_q_attn->To(GetDevice());
166158
EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(),
167159
nn::parallel::ColumnParallelLinear::kType);

0 commit comments

Comments
 (0)