Commit 6ba15c3

Apply PR comment fixes

1 parent 1e43784

25 files changed: +841 −668 lines changed


example/gpt2/net.cc (31 additions, 29 deletions)

@@ -12,8 +12,10 @@
 
 #include "example/common/utils.h"
 #include "infini_train/include/core/models/decode_only_transformer/model.h"
-#include "infini_train/include/core/transformer/transformer_block.h"
+#include "infini_train/include/core/transformer/attention/causal_self_attention.h"
+#include "infini_train/include/core/transformer/mlp.h"
 #include "infini_train/include/core/transformer/transformer_config.h"
+#include "infini_train/include/core/transformer/transformer_layer.h"
 #include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/sparse.h"
 #include "infini_train/include/nn/parallel/global.h"
@@ -133,7 +135,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     // local: (vocab_size_per_partition, n_embd)
     if (is_first_stage) {
         auto &transformer_wte_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWTELayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWTELayerName,
                                      nn::parallel::VocabParallelEmbedding::kParamWeightName)];
         ReadMatrixRowShardFloat(ifs, static_cast<float *>(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd,
                                 v_start, vpp);
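
For context on what the renamed constant controls: every parameter lookup in this file builds a dotted state_dict key out of these name constants, so the kTransformerLayerName → kTransformerModelName rename changes which C++ identifier supplies the "transformer" prefix, not the key string itself. A minimal sketch of the key construction, using hypothetical constant values inferred from the "transformer.wpe.weight"-style comments later in the diff (the real values are defined on GPT2 and the transformer stage classes):

#include <format>
#include <iostream>
#include <string>

int main() {
    // Hypothetical values; in net.cc these come from GPT2::kTransformerModelName,
    // nn::TransformerFirstStage::kWTELayerName, etc.
    const std::string transformer_model_name = "transformer";
    const std::string wte_layer_name = "wte";
    const std::string param_weight_name = "weight";

    // Same pattern as the state_dict lookup in the hunk above.
    const std::string key
        = std::format("{}.{}.{}", transformer_model_name, wte_layer_name, param_weight_name);
    std::cout << key << '\n'; // prints "transformer.wte.weight"
}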
@@ -155,7 +157,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_first_stage) {
         // transformer.wpe.weight
         auto &transformer_wpe_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerFirstStage::kWPELayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerFirstStage::kWPELayerName,
                                      nn::Embedding::kParamWeightName)];
         ReadMatrixAllFloat(ifs, static_cast<float *>(transformer_wpe_weight->DataPtr()), block_size, n_embd);
     } else {
@@ -168,9 +170,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                          nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamWeightName)];
+                                         nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -183,9 +185,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kLn1LayerName, nn::LayerNorm::kParamBiasName)];
+                                                  nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -199,8 +201,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)];
             // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim,
             // i.e. [Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T
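
The NOTE above is why c_attn.weight cannot be loaded with a plain row-shard read: the .bin file stores Q, K, and V concatenated along the output dimension, so each tensor-parallel rank must take its slice from each of the three segments separately. A minimal sketch of that index arithmetic on an in-memory buffer, assuming a row-major [3 * n_embd, n_embd] layout (hypothetical helper, not the one net.cc uses):

#include <cstddef>
#include <vector>

// Copy this TP rank's rows out of a [Q|K|V]-concatenated weight.
// full   : row-major [3 * n_embd, n_embd], as laid out in the .bin file
// n_embd : hidden size; assumes n_embd % tp_size == 0
std::vector<float> ShardQKVRows(const std::vector<float> &full, std::size_t n_embd,
                                std::size_t tp_rank, std::size_t tp_size) {
    const std::size_t shard = n_embd / tp_size; // rows per rank, per segment
    std::vector<float> local;
    local.reserve(3 * shard * n_embd);
    for (std::size_t seg = 0; seg < 3; ++seg) { // 0 = Q, 1 = K, 2 = V
        const std::size_t base = seg * n_embd + tp_rank * shard;
        for (std::size_t r = 0; r < shard; ++r) {
            const float *row = full.data() + (base + r) * n_embd;
            local.insert(local.end(), row, row + n_embd);
        }
    }
    return local; // [3 * shard, n_embd] slice for this rank
}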
@@ -242,8 +244,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)];
             // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated
             // i.e. [Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn]
@@ -284,8 +286,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp,
                                     in_pp);
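
ReadMatrixColShardFloat here keeps only columns [tp_rank * in_pp, tp_rank * in_pp + in_pp) of the row-major [n_embd, n_embd] c_proj weight, which in a row-major file means one small read per row rather than a single contiguous read. A plausible sketch of that access pattern (an assumed implementation written for illustration; the project's actual helper may differ):

#include <cstddef>
#include <fstream>

// Read columns [col_start, col_start + cols_pp) of a row-major [rows, cols]
// float matrix whose data begins at the current stream position; leave the
// stream positioned just past the whole matrix.
void ReadColShard(std::ifstream &ifs, float *dst, std::size_t rows, std::size_t cols,
                  std::size_t col_start, std::size_t cols_pp) {
    const std::streampos base = ifs.tellg();
    for (std::size_t r = 0; r < rows; ++r) {
        // Seek to this row's slice of the wanted columns, then read it whole.
        ifs.seekg(base + std::streamoff((r * cols + col_start) * sizeof(float)));
        ifs.read(reinterpret_cast<char *>(dst + r * cols_pp), cols_pp * sizeof(float));
    }
    // Skip past the rest of the matrix so subsequent reads line up.
    ifs.seekg(base + std::streamoff(rows * cols * sizeof(float)));
}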
@@ -301,8 +303,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor = state_dict[std::format(
-                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerChunk::kHLayerName,
-                std::to_string(local_layer_index), nn::TransformerBlock::kAttnLayerName,
+                "{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerChunk::kHLayerName,
+                std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName,
                 nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
@@ -317,9 +319,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
             auto &tensor
-                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+                = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                          nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                         nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamWeightName)];
+                                         nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -332,9 +334,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kLn2LayerName, nn::LayerNorm::kParamBiasName)];
+                                                  nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
         } else {
@@ -347,9 +349,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
                                                   nn::parallel::ColumnParallelLinear::kParamWeightName)];
             ReadMatrixRowShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp);
             ++local_layer_index;
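
By contrast, ReadMatrixRowShardFloat-style loads are cheap: a row shard of a row-major matrix (here rows [fc_start, fc_start + fc_pp) of the [fc_out, n_embd] c_fc weight) is one contiguous block in the file. A minimal sketch under the same assumptions as the column-shard example above (hypothetical helper, not the project's implementation):

#include <cstddef>
#include <fstream>

// Read rows [row_start, row_start + rows_pp) of a row-major [rows, cols]
// float matrix whose data begins at the current stream position; leave the
// stream positioned just past the whole matrix.
void ReadRowShard(std::ifstream &ifs, float *dst, std::size_t rows, std::size_t cols,
                  std::size_t row_start, std::size_t rows_pp) {
    const std::streampos base = ifs.tellg();
    ifs.seekg(base + std::streamoff(row_start * cols * sizeof(float)));
    ifs.read(reinterpret_cast<char *>(dst), rows_pp * cols * sizeof(float));
    ifs.seekg(base + std::streamoff(rows * cols * sizeof(float)));
}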
@@ -363,9 +365,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCFcLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName,
                                                   nn::parallel::ColumnParallelLinear::kParamBiasName)];
             ReadVectorShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), fc_out, fc_start, fc_pp);
             ++local_layer_index;
@@ -379,9 +381,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
                                                   nn::parallel::RowParallelLinear::kParamWeightName)];
             ReadMatrixColShardFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp,
                                     in4_pp);
@@ -396,9 +398,9 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     local_layer_index = 0;
     for (int idx = 0; idx < n_layer; ++idx) {
         if (owned_layers[idx]) {
-            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName,
+            auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerModelName,
                                                   nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index),
-                                                  nn::TransformerBlock::kMlpLayerName, nn::MLP::kCProjLayerName,
+                                                  nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName,
                                                   nn::parallel::RowParallelLinear::kParamBiasName)];
             ReadVectorAllFloat(ifs, static_cast<float *>(tensor->DataPtr()), n_embd);
             ++local_layer_index;
@@ -411,12 +413,12 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
     if (is_last_stage) {
         // transformer.ln_f.weight
        auto &transformer_ln_f_weight
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
                                      nn::LayerNorm::kParamWeightName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_weight->DataPtr()), n_embd);
         // transformer.ln_f.bias
         auto &transformer_ln_f_bias
-            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, nn::TransformerLastStage::kLnFLayerName,
+            = state_dict[std::format("{}.{}.{}", GPT2::kTransformerModelName, nn::TransformerLastStage::kLnFLayerName,
                                      nn::LayerNorm::kParamBiasName)];
         ReadVectorAllFloat(ifs, static_cast<float *>(transformer_ln_f_bias->DataPtr()), n_embd);
     } else {
