Skip to content

Commit 0bcc009

Browse files
committed
refactor: remove attention_type_ from TransformerLayer
1 parent 88b292b commit 0bcc009

File tree

6 files changed

+23
-28
lines changed

6 files changed

+23
-28
lines changed

infini_train/include/core/transformer/spec_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ template <typename T> T GetRequiredParam(const ModuleSpec &spec, const std::stri
8484
CHECK(spec.params_.contains(key)) << "Missing required parameter: " << key;
8585

8686
const T *value = std::any_cast<T>(&spec.params_.at(key));
87-
CHECK(value) << "Parameter type mismatch for key '" << key << "': expected " << typeid(T).name() << ", got "
88-
<< spec.params_.at(key).type().name();
87+
CHECK(value) << std::format("Parameter type mismatch for key '{}': expected {}, got {}", key, typeid(T).name(),
88+
spec.params_.at(key).type().name());
8989
return *value;
9090
}
9191

infini_train/include/core/transformer/transformer_config.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ struct TransformerConfig {
3939
NormType norm_type = NormType::kLayerNorm; // Normalization type
4040

4141
bool use_bias = true; // Linear layers bias (GPT2: true, LLaMA3: false)
42-
bool use_gqa = false; // Grouped Query Attention
4342
bool tie_weights = true; // Tie embedding and lm_head weights
4443

4544
// FFN config
@@ -59,6 +58,8 @@ struct TransformerConfig {
5958
bool flash = false; // flash attention
6059
int64_t max_gen_batch_size = 4; // max batch size during inference
6160

61+
bool UseGQA() const { return n_kv_head < n_head; }
62+
6263
static TransformerConfig GPT2() {
6364
return {.model_type = kGPT2Name,
6465
.block_size = 1024,
@@ -72,7 +73,6 @@ struct TransformerConfig {
7273
.activation_type = MLPType::kGELU,
7374
.norm_type = NormType::kLayerNorm,
7475
.use_bias = true,
75-
.use_gqa = false,
7676
.tie_weights = true,
7777
.ffn_expansion_ratio = 4.0f,
7878
.ffn_dim_multiplier = std::nullopt,
@@ -91,7 +91,6 @@ struct TransformerConfig {
9191
.activation_type = MLPType::kSwiGLU,
9292
.norm_type = NormType::kRMSNorm,
9393
.use_bias = false,
94-
.use_gqa = true,
9594
.tie_weights = false,
9695
.ffn_expansion_ratio = 4.0f,
9796
.ffn_dim_multiplier = 1.5f,

infini_train/include/nn/modules/transformer.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ class TransformerLayer : public infini_train::nn::CloneableModule<TransformerLay
1919

2020
std::vector<std::shared_ptr<infini_train::Tensor>>
2121
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
22-
23-
private:
24-
AttentionType attention_type_ = AttentionType::kStandard;
2522
};
2623

2724
} // namespace infini_train::nn

infini_train/src/core/transformer/transformer_builders.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ ModuleSpec BuildAttentionSpec(const TransformerConfig &config) {
4141

4242
// Calculate QKV output dimension based on attention type and GQA
4343
int64_t qkv_out;
44-
if (config.use_gqa && config.n_kv_head < config.n_head) {
44+
if (config.UseGQA()) {
4545
// GQA style (LLaMA3 with GQA enabled)
4646
int64_t head_dim = config.n_embd / config.n_head;
4747
// qkv_out = config.n_embd + 2 * config.n_kv_head * head_dim;

infini_train/src/nn/modules/causal_self_attention.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void CausalSelfAttention::SetupAttention(const TransformerConfig &config) {
5252
local_n_head_ = n_head_ / tp_world_size;
5353

5454
// For GQA, set n_kv_head and n_rep
55-
if (config.use_gqa && config.n_kv_head < config.n_head) {
55+
if (config.UseGQA()) {
5656
CHECK_EQ(config.n_head % config.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
5757
CHECK_EQ(config.n_kv_head % tp_world_size, 0) << "n_kv_head must be divisible by TP world size for GQA";
5858

infini_train/src/nn/modules/transformer.cc

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,37 @@
1919

2020
namespace infini_train::nn {
2121
TransformerLayer::TransformerLayer(const nn::TransformerConfig &config, const ModuleSpec &spec)
22-
: CloneableModule(kType), attention_type_(config.attention_type) {
22+
: CloneableModule(kType) {
2323
modules_[kLn1LayerName] = BuildModule(config, spec.submodules_.at(kLn1LayerName));
2424
modules_[kAttnLayerName] = BuildModule(config, spec.submodules_.at(kAttnLayerName));
2525
modules_[kLn2LayerName] = BuildModule(config, spec.submodules_.at(kLn2LayerName));
2626
modules_[kMlpLayerName] = BuildModule(config, spec.submodules_.at(kMlpLayerName));
2727
}
2828

29-
std::vector<std::shared_ptr<infini_train::Tensor>>
30-
TransformerLayer::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
29+
std::vector<std::shared_ptr<Tensor>> TransformerLayer::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
3130
// (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd)
3231
auto ln1_out = (*modules_[kLn1LayerName])({x[0]})[0];
3332

34-
std::shared_ptr<infini_train::Tensor> x1;
35-
// Build attention input
36-
if (attention_type_ == AttentionType::kRoPE) {
37-
// LLaMA3: {ln1_out, freqs_cis, start_pos, mask}
38-
const auto freqs_cis = x.size() > 1 ? x[1] : nullptr;
39-
const auto start_pos = x.size() > 2 ? x[2] : nullptr;
40-
const auto mask = x.size() > 3 ? x[3] : nullptr;
41-
auto attn_out = (*modules_[kAttnLayerName])({ln1_out, freqs_cis, start_pos, mask})[0];
42-
x1 = x[0] + attn_out;
43-
} else {
44-
// GPT2: {ln1_out}
45-
auto attn_out = (*modules_[kAttnLayerName])({ln1_out})[0];
46-
x1 = x[0] + attn_out;
33+
std::vector<std::shared_ptr<Tensor>> attn_input = {ln1_out};
34+
if (x.size() > 1) {
35+
attn_input.push_back(x[1]); // freqs_cis
36+
}
37+
if (x.size() > 2) {
38+
attn_input.push_back(x[2]); // start_pos
39+
}
40+
if (x.size() > 3) {
41+
attn_input.push_back(x[3]); // mask
4742
}
4843

49-
// (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd)
50-
// -> Add -> (bs, seq_len, n_embd)
44+
auto attn_out = (*modules_[kAttnLayerName])(attn_input)[0];
45+
auto x1 = x[0] + attn_out;
46+
47+
// (bs, seq_len, n_embd) -> Layernorm -> (bs, seq_len, n_embd) -> MLP -> (bs, seq_len, n_embd) -> Add -> (bs,
48+
// seq_len, n_embd)
5149
auto x2 = x1 + (*modules_[kMlpLayerName])((*modules_[kLn2LayerName])({x1}))[0];
5250

5351
// (bs, seq_len, n_embd)
5452
return {x2};
5553
}
54+
5655
} // namespace infini_train::nn

0 commit comments

Comments (0)