Skip to content

Commit 87ca357

Browse files
feat: support q_lora/non-q_lora and tp/non-tp variations
1 parent 9de7f8f commit 87ca357

5 files changed

Lines changed: 242 additions & 57 deletions

File tree

infini_train/include/nn/modules/transformer/mla_self_attention.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,21 @@ class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttenti
1212
public:
1313
static constexpr char kType[] = "MLASelfAttention";
1414

15-
static constexpr char kQAProjLayerName[] = "q_a_proj";
16-
static constexpr char kQANormLayerName[] = "q_a_layernorm";
17-
static constexpr char kQBProjLayerName[] = "q_b_proj";
18-
static constexpr char kKVAProjLayerName[] = "kv_a_proj_with_mqa";
19-
static constexpr char kKVANormLayerName[] = "kv_a_layernorm";
20-
static constexpr char kKVBProjLayerName[] = "kv_b_proj";
21-
static constexpr char kCProjLayerName[] = "c_proj";
15+
static constexpr char kLinearQProjLayerName[] = "linear_q_proj";
16+
static constexpr char kLinearQDownProjLayerName[] = "linear_q_down_proj";
17+
static constexpr char kQLayerNormLayerName[] = "q_layernorm";
18+
static constexpr char kLinearQUpProjLayerName[] = "linear_q_up_proj";
19+
static constexpr char kLinearKVDownProjLayerName[] = "linear_kv_down_proj";
20+
static constexpr char kKVLayerNormLayerName[] = "kv_layernorm";
21+
static constexpr char kLinearKVUpProjLayerName[] = "linear_kv_up_proj";
22+
static constexpr char kLinearProjLayerName[] = "linear_proj";
2223

2324
static constexpr char kParamBiasName[] = "bias";
2425

2526
explicit MLASelfAttention(const TransformerConfig &config);
2627
MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
27-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim);
28+
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
29+
bool q_down_proj_use_tp = false, bool kv_down_proj_use_tp = false);
2830

2931
std::vector<std::shared_ptr<infini_train::Tensor>>
3032
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
@@ -42,9 +44,13 @@ class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttenti
4244
int64_t qk_head_dim_ = 0;
4345
int64_t v_head_dim_ = 0;
4446

45-
void SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
46-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim);
47+
bool use_q_lora_ = true;
48+
bool q_down_proj_use_tp_ = false;
49+
bool kv_down_proj_use_tp_ = false;
4750

51+
void SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
52+
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
53+
bool q_down_proj_use_tp, bool kv_down_proj_use_tp);
4854
};
4955

5056
} // namespace infini_train::nn

infini_train/include/nn/parallel/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ std::vector<int> GetPipelineParallelGroupRanks(int global_rank);
2323

2424
// TP/SP Communication Helper Functions
2525
std::vector<std::shared_ptr<Tensor>> GatherFromTPRegionFunc(const std::shared_ptr<Tensor> &input);
26+
std::vector<std::shared_ptr<Tensor>> ScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
2627
std::vector<std::shared_ptr<Tensor>> ReduceScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
2728
std::vector<std::shared_ptr<Tensor>> GatherFromSPRegionFunc(const std::shared_ptr<Tensor> &input);
2829
std::vector<std::shared_ptr<Tensor>> ScatterToTPRegionFunc(const std::shared_ptr<Tensor> &input);

infini_train/src/nn/modules/transformer/mla_self_attention.cc

Lines changed: 154 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,65 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config)
4343
/*v_head_dim=*/DefaultQKVHeadDim(config)) {}
4444

4545
MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
46-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim)
46+
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
47+
bool q_down_proj_use_tp, bool kv_down_proj_use_tp)
4748
: CloneableModule(kType), config_(config) {
48-
SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim);
49-
50-
modules_[kQAProjLayerName] = std::make_shared<nn::Linear>(
51-
/*in_features=*/n_embd_,
52-
/*out_features=*/q_lora_rank_,
53-
/*bias=*/config_.add_bias_linear);
54-
modules_[kQANormLayerName] = std::make_shared<nn::RMSNorm>(q_lora_rank_, config_.norm_eps);
55-
modules_[kQBProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
56-
/*in_features=*/q_lora_rank_,
57-
/*out_features=*/n_head_ * qk_head_dim_,
58-
/*bias=*/config_.add_bias_linear,
59-
/*gather_output=*/false,
60-
/*input_is_parallel=*/false,
61-
/*skip_bias_add=*/false,
62-
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
49+
SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
50+
q_down_proj_use_tp, kv_down_proj_use_tp);
51+
52+
if (use_q_lora_) {
53+
if (q_down_proj_use_tp_) {
54+
modules_[kLinearQDownProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
55+
/*in_features=*/n_embd_,
56+
/*out_features=*/q_lora_rank_,
57+
/*bias=*/config_.add_bias_linear,
58+
/*gather_output=*/false,
59+
/*input_is_parallel=*/false,
60+
/*skip_bias_add=*/false,
61+
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
62+
} else {
63+
modules_[kLinearQDownProjLayerName] = std::make_shared<nn::Linear>(
64+
/*in_features=*/n_embd_,
65+
/*out_features=*/q_lora_rank_,
66+
/*bias=*/config_.add_bias_linear);
67+
}
68+
modules_[kQLayerNormLayerName] = std::make_shared<nn::RMSNorm>(q_lora_rank_, config_.norm_eps);
69+
modules_[kLinearQUpProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
70+
/*in_features=*/q_lora_rank_,
71+
/*out_features=*/n_head_ * qk_head_dim_,
72+
/*bias=*/config_.add_bias_linear,
73+
/*gather_output=*/false,
74+
/*input_is_parallel=*/false,
75+
/*skip_bias_add=*/false,
76+
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
77+
} else {
78+
modules_[kLinearQProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
79+
/*in_features=*/n_embd_,
80+
/*out_features=*/n_head_ * qk_head_dim_,
81+
/*bias=*/config_.add_bias_linear,
82+
/*gather_output=*/false,
83+
/*input_is_parallel=*/false,
84+
/*skip_bias_add=*/false,
85+
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
86+
}
6387

64-
modules_[kKVAProjLayerName] = std::make_shared<nn::Linear>(
65-
/*in_features=*/n_embd_,
66-
/*out_features=*/kv_lora_rank_ + qk_rope_head_dim_,
67-
/*bias=*/config_.add_bias_linear);
68-
modules_[kKVANormLayerName] = std::make_shared<nn::RMSNorm>(kv_lora_rank_, config_.norm_eps);
69-
modules_[kKVBProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
88+
if (kv_down_proj_use_tp_) {
89+
modules_[kLinearKVDownProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
90+
/*in_features=*/n_embd_,
91+
/*out_features=*/kv_lora_rank_ + qk_rope_head_dim_,
92+
/*bias=*/config_.add_bias_linear,
93+
/*gather_output=*/false,
94+
/*input_is_parallel=*/false,
95+
/*skip_bias_add=*/false,
96+
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
97+
} else {
98+
modules_[kLinearKVDownProjLayerName] = std::make_shared<nn::Linear>(
99+
/*in_features=*/n_embd_,
100+
/*out_features=*/kv_lora_rank_ + qk_rope_head_dim_,
101+
/*bias=*/config_.add_bias_linear);
102+
}
103+
modules_[kKVLayerNormLayerName] = std::make_shared<nn::RMSNorm>(kv_lora_rank_, config_.norm_eps);
104+
modules_[kLinearKVUpProjLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
70105
/*in_features=*/kv_lora_rank_,
71106
/*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_),
72107
/*bias=*/config_.add_bias_linear,
@@ -75,7 +110,7 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
75110
/*skip_bias_add=*/false,
76111
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());
77112

78-
modules_[kCProjLayerName] = std::make_shared<nn::parallel::RowParallelLinear>(
113+
modules_[kLinearProjLayerName] = std::make_shared<nn::parallel::RowParallelLinear>(
79114
/*in_features=*/n_head_ * v_head_dim_,
80115
/*out_features=*/n_embd_,
81116
/*bias=*/config_.add_bias_linear,
@@ -89,12 +124,13 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo
89124
}
90125

91126
void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank,
92-
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) {
127+
int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim,
128+
bool q_down_proj_use_tp, bool kv_down_proj_use_tp) {
93129
auto tp_world_size = nn::parallel::global::GetTensorParallelSize();
94130

95131
CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head";
96132
CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size";
97-
CHECK_GT(q_lora_rank, 0) << "q_lora_rank must be positive";
133+
CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA";
98134
CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive";
99135
CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive";
100136
CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive";
@@ -105,80 +141,151 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q
105141
n_embd_ = config.n_embd;
106142
local_n_head_ = n_head_ / tp_world_size;
107143

108-
q_lora_rank_ = q_lora_rank;
144+
use_q_lora_ = q_lora_rank != -1;
145+
q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0;
109146
kv_lora_rank_ = kv_lora_rank;
110147
qk_nope_head_dim_ = qk_nope_head_dim;
111148
qk_rope_head_dim_ = qk_rope_head_dim;
112149
qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_;
113150
v_head_dim_ = v_head_dim;
151+
q_down_proj_use_tp_ = q_down_proj_use_tp;
152+
kv_down_proj_use_tp_ = kv_down_proj_use_tp;
114153
}
115154

116155
std::vector<std::shared_ptr<infini_train::Tensor>>
117156
MLASelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
118157
CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states";
119158

159+
// x[0]: (B, T_local, C)
120160
const auto B = x[0]->Dims()[0];
121161
const auto C = x[0]->Dims()[2];
122162
CHECK_EQ(C, n_embd_) << "hidden size must match n_embd";
123163

164+
// freqs_cis: (T, D_rope / 2, 2)
124165
const auto freqs_cis = x.size() > 1 ? x[1] : nullptr;
166+
// external_mask: (1, 1, T, T)
125167
const auto external_mask = x.size() > 3 ? x[3] : nullptr;
126168
if (config_.attention_type == AttentionType::kRoPE) {
127169
CHECK(freqs_cis != nullptr) << "freqs_cis is null.";
128170
}
129171

130-
// (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope))
131-
auto q = (*modules_[kQAProjLayerName])({x[0]})[0];
132-
q = (*modules_[kQANormLayerName])({q})[0];
133-
q = (*modules_[kQBProjLayerName])({q})[0];
172+
const bool sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled();
173+
174+
// ----------- Q PATH -----------
175+
// Q path, align with Megatron:
176+
// - q_lora_rank == -1 -> linear_q_proj directly;
177+
// - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj.
178+
std::shared_ptr<Tensor> q;
179+
if (use_q_lora_) {
180+
// linear_q_down_proj:
181+
// non-TP path: (B, T_local, C) -> (B, T_local, R_q)
182+
// TP path before gather: (B, T, C) -> (B, T, R_q / TP)
183+
// - Note that ColumnParallelLinear would perform a GatherFromSPRegion in the beginning
184+
auto q_compressed = (*modules_[kLinearQDownProjLayerName])({x[0]})[0];
185+
if (q_down_proj_use_tp_ && q_compressed->Dims().back() != q_lora_rank_) {
186+
// Gather the sharded latent dimension: (B, T, R_q / TP) -> (B, T, R_q).
187+
q_compressed = nn::parallel::GatherFromTPRegionFunc(q_compressed)[0];
188+
if (sequence_parallel_enabled) {
189+
// Keep the q_up input sequence-sharded: (B, T_full, R_q) -> (B, T_local, R_q).
190+
q_compressed = nn::parallel::ScatterToSPRegionFunc(q_compressed)[0];
191+
}
192+
}
193+
// q_layernorm preserves shape: (B, T_local, R_q)
194+
q_compressed = (*modules_[kQLayerNormLayerName])({q_compressed})[0];
195+
// linear_q_up_proj: (B, T_local, R_q) -> (B, T, H_local * (D_nope + D_rope)).
196+
q = (*modules_[kLinearQUpProjLayerName])({q_compressed})[0];
197+
} else {
198+
// linear_q_proj direct path: (B, T, C) -> (B, T, H_local * (D_nope + D_rope)).
199+
q = (*modules_[kLinearQProjLayerName])({x[0]})[0];
200+
}
201+
202+
// T should be the full seqlen after the q projection path gathers sequence-parallel input.
134203
const auto T = q->Dims()[1];
204+
// q: (B, T, H_local * D_qk) -> (B, T, H_local, D_qk)
205+
// qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_
135206
q = q->View({B, T, local_n_head_, qk_head_dim_});
136207

208+
// q_nope: (B, T, H_local, D_nope), q_pos_emb: (B, T, H_local, D_rope)
137209
auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_);
138-
auto q_pe = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_);
210+
auto q_pos_emb = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_);
211+
212+
// ----------- KV PATH -----------
213+
// linear_kv_down_proj:
214+
// non-TP path: (B, T_local, C) -> (B, T_local, R_kv + D_rope)
215+
// TP path before gather: (B, T, C) -> (B, T, (R_kv + D_rope) / TP)
216+
auto compressed_kv_with_pe = (*modules_[kLinearKVDownProjLayerName])({x[0]})[0];
217+
const auto kv_down_proj_out_dim = kv_lora_rank_ + qk_rope_head_dim_;
218+
const bool kv_down_proj_output_is_sharded = compressed_kv_with_pe->Dims().back() != kv_down_proj_out_dim;
219+
if (kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded) {
220+
// Gather latent+RoPE dim: (B, T, (R_kv + D_rope) / TP) -> (B, T, R_kv + D_rope)
221+
compressed_kv_with_pe = nn::parallel::GatherFromTPRegionFunc(compressed_kv_with_pe)[0];
222+
}
139223

140-
// (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key.
141-
auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName])({x[0]})[0];
224+
// compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope)
142225
auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_);
143-
auto k_pe = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)
144-
->Contiguous();
145-
if (nn::parallel::global::GetSequenceParallelEnabled()) {
146-
k_pe = nn::parallel::GatherFromSPRegionFunc(k_pe)[0];
226+
auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous();
227+
const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded
228+
&& sequence_parallel_enabled;
229+
if (k_pos_emb_has_full_sequence) {
230+
// k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj.
231+
// compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv)
232+
compressed_kv = nn::parallel::ScatterToSPRegionFunc(compressed_kv)[0];
233+
} else if (sequence_parallel_enabled) {
234+
// Replicated down-proj path produces local k_pos_emb; gather it for attention.
235+
// k_pos_emb: (B, T_local, D_rope) -> (B, T, D_rope)
236+
k_pos_emb = nn::parallel::GatherFromSPRegionFunc(k_pos_emb)[0];
147237
}
148-
k_pe = k_pe->View({B, T, 1, qk_rope_head_dim_});
238+
// k_pos_emb: (B, T, D_rope) -> (B, T, 1, D_rope), shared across local heads.
239+
k_pos_emb = k_pos_emb->View({B, T, 1, qk_rope_head_dim_});
149240

150-
// (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v))
151-
auto kv = (*modules_[kKVANormLayerName])({compressed_kv})[0];
152-
kv = (*modules_[kKVBProjLayerName])({kv})[0];
241+
// (B, T, R_kv) -> kv_layernorm -> linear_kv_up_proj -> (B, T, H_local * (D_nope + D_v))
242+
// kv_layernorm preserves compressed_kv shape: (B, T_local, R_kv)
243+
auto kv = (*modules_[kKVLayerNormLayerName])({compressed_kv})[0];
244+
// linear_kv_up_proj: (B, T_local, R_kv) -> (B, T, H_local * (D_nope + D_v))
245+
kv = (*modules_[kLinearKVUpProjLayerName])({kv})[0];
246+
// kv: (B, T, H_local * (D_nope + D_v)) -> (B, T, H_local, D_nope + D_v)
153247
kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_});
248+
// k_nope: (B, T, H_local, D_nope), v: (B, T, H_local, D_v)
154249
auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_);
155250
auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_);
156251

157252
if (config_.attention_type == AttentionType::kRoPE) {
158-
std::tie(q_pe, k_pe) = ApplyRotaryEmbedding(q_pe, k_pe, freqs_cis);
253+
// q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope)
254+
std::tie(q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding(q_pos_emb, k_pos_emb, freqs_cis);
159255
}
160256

161-
k_pe = k_pe->RepeatInterleave(local_n_head_, 2);
162-
q = nn::function::Concat(std::vector<std::shared_ptr<Tensor>>{q_nope, q_pe}, -1);
163-
auto k = nn::function::Concat(std::vector<std::shared_ptr<Tensor>>{k_nope, k_pe}, -1);
257+
// k_pos_emb: (B, T, 1, D_rope) -> (B, T, H_local, D_rope)
258+
k_pos_emb = k_pos_emb->RepeatInterleave(local_n_head_, 2);
259+
// q: (B, T, H_local, D_qk), k: (B, T, H_local, D_qk)
260+
q = nn::function::Concat(std::vector<std::shared_ptr<Tensor>>{q_nope, q_pos_emb}, -1);
261+
auto k = nn::function::Concat(std::vector<std::shared_ptr<Tensor>>{k_nope, k_pos_emb}, -1);
164262

165-
// (B, T, H_local, D) -> (B, H_local, T, D)
263+
// ----------- CORE ATTN -----------
264+
// q/k: (B, T, H_local, D_qk) -> (B, H_local, T, D_qk)
265+
// v: (B, T, H_local, D_v) -> (B, H_local, T, D_v)
166266
q = q->Transpose(1, 2);
167267
k = k->Transpose(1, 2);
168268
v = v->Transpose(1, 2);
169269

270+
// att: (B, H_local, T, T)
170271
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(qk_head_dim_)));
171272
if (external_mask) {
172273
att = att->MaskedFill(external_mask, std::numeric_limits<float>::lowest());
173274
} else {
275+
// mask: (1, 1, T, T)
174276
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
175277
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
176278
}
279+
// att: (B, H_local, T, T)
177280
att = nn::function::Softmax(att, -1);
178281

282+
// y: (B, H_local, T, D_v)
179283
auto y = att->Matmul(v);
284+
// y: (B, H_local, T, D_v) -> (B, T, H_local, D_v) -> (B, T, H_local * D_v)
180285
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_});
181-
y = (*modules_[kCProjLayerName])({y})[0];
286+
// linear_proj: (B, T, H_local * D_v) -> (B, T, C)
287+
y = (*modules_[kLinearProjLayerName])({y})[0];
288+
182289
return {y};
183290
}
184291

0 commit comments

Comments
 (0)