Skip to content

Commit 0a089ae

Browse files
committed
fix(scripts): adapt run_models_and_profile.bash to nested test_groups format
1 parent b32b662 commit 0a089ae

6 files changed

Lines changed: 120 additions & 81 deletions

File tree

example/gpt2/main.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ DEFINE_string(resume_from, "", "checkpoint directory to resume from");
8686
DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
8787
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8888
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
89-
DEFINE_string(checkpoint_format, "pth",
90-
"checkpoint format: bin|pth. "
89+
DEFINE_string(checkpoint_format, "ckpt",
90+
"checkpoint format: bin|ckpt. "
9191
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
92-
"'pth' generates model.pth/optimizer.pth (native StateDict binary).");
92+
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
9393
// precision check
9494
DEFINE_string(
9595
precision_check, "",

example/llama3/main.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ DEFINE_string(resume_from, "", "checkpoint directory to resume from");
8484
DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints");
8585
DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep");
8686
DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints");
87-
DEFINE_string(checkpoint_format, "pth",
88-
"checkpoint format: bin|pth. "
87+
DEFINE_string(checkpoint_format, "ckpt",
88+
"checkpoint format: bin|ckpt. "
8989
"'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); "
90-
"'pth' generates model.pth/optimizer.pth (native StateDict binary).");
90+
"'ckpt' generates model.ckpt/optimizer.ckpt (native StateDict binary).");
9191
// precision check
9292
DEFINE_string(
9393
precision_check, "",

infini_train/src/checkpoint.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ template <typename T> T ExtractNumberField(const std::string &content, const std
8686

8787
void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
8888
const TrainerState &state, const CheckpointOptions &options) {
89-
CHECK(options.format == "bin" || options.format == "pth") << "Unsupported checkpoint format: " << options.format;
89+
CHECK(options.format == "bin" || options.format == "ckpt") << "Unsupported checkpoint format: " << options.format;
9090
std::filesystem::create_directories(checkpoint_dir);
9191
LOG(ERROR) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", format=" << options.format
9292
<< ", global_step=" << state.global_step;
9393

94-
const auto model_path = checkpoint_dir / (options.format == "pth" ? "model.pth" : "model.bin");
94+
const auto model_path = checkpoint_dir / (options.format == "ckpt" ? "model.ckpt" : "model.bin");
9595
if (options.format == "bin" && options.model_bin_writer) {
9696
options.model_bin_writer(model, model_path);
9797
} else {
@@ -101,7 +101,7 @@ void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Mod
101101
if (options.save_optimizer_state) {
102102
auto opt_state = optimizer.StateDict();
103103
if (!opt_state.empty()) {
104-
const auto opt_path = checkpoint_dir / (options.format == "pth" ? "optimizer.pth" : "optimizer.bin");
104+
const auto opt_path = checkpoint_dir / (options.format == "ckpt" ? "optimizer.ckpt" : "optimizer.bin");
105105
SaveStateDictBinary(opt_path, opt_state);
106106
}
107107
}
@@ -116,7 +116,7 @@ void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module *m
116116
CHECK(state != nullptr);
117117

118118
const std::string format = InferFormat(checkpoint_dir);
119-
const auto model_path = checkpoint_dir / (format == "pth" ? "model.pth" : "model.bin");
119+
const auto model_path = checkpoint_dir / (format == "ckpt" ? "model.ckpt" : "model.bin");
120120
LOG(ERROR) << "[CKPT] Load begin: dir=" << checkpoint_dir << ", format=" << format;
121121
LOG(ERROR) << "[CKPT] Loading model: " << model_path;
122122
if (format == "bin" && options.model_bin_loader) {
@@ -134,7 +134,7 @@ void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module *m
134134
}
135135

136136
if (optimizer != nullptr && options.load_optimizer_state) {
137-
const auto opt_path = checkpoint_dir / (format == "pth" ? "optimizer.pth" : "optimizer.bin");
137+
const auto opt_path = checkpoint_dir / (format == "ckpt" ? "optimizer.ckpt" : "optimizer.bin");
138138
if (std::filesystem::exists(opt_path)) {
139139
LOG(ERROR) << "[CKPT] Loading optimizer: " << opt_path;
140140
optimizer->LoadStateDict(LoadStateDictBinary(opt_path));
@@ -264,8 +264,8 @@ TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) {
264264
}
265265

266266
std::string Checkpoint::InferFormat(const std::filesystem::path &checkpoint_dir) {
267-
if (std::filesystem::exists(checkpoint_dir / "model.pth")) {
268-
return "pth";
267+
if (std::filesystem::exists(checkpoint_dir / "model.ckpt")) {
268+
return "ckpt";
269269
}
270270
if (std::filesystem::exists(checkpoint_dir / "model.bin")) {
271271
return "bin";

scripts/run_models_and_profile.bash

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -210,26 +210,29 @@ move_profile_logs() {
210210
# For checkpoint-related args, automatically isolate by model and run mode
211211
# (resume/no_resume) to avoid cross-test overwrites in one-click runs.
212212
args_string_for_test() {
213-
local idx="$1"
214-
local model_name="$2"
215-
jq -r --argjson i "$idx" --arg model "$model_name" '
216-
def namespaced_path($p; $model; $mode):
217-
if ($p | test("/checkpoint_step_[0-9]+($|/)")) then
218-
($p | capture("^(?<prefix>.*)/(?<step>checkpoint_step_[0-9]+(?:/.*)?)$")) as $m
219-
| ($m.prefix + "/" + $model + "/" + $mode + "/" + $m.step)
220-
else
221-
($p + "/" + $model + "/" + $mode)
222-
end;
223-
224-
.tests[$i].args as $args
225-
| (if ($args | has("resume_from")) then "resume" else "no_resume" end) as $run_mode
226-
| (if (($args.resume_from // "") | test("(^|/)no_resume(/|$)")) then "no_resume" else "resume" end) as $resume_src_mode
227-
| $args
228-
| (if has("checkpoint_dir") then .checkpoint_dir = namespaced_path(.checkpoint_dir; $model; $run_mode) else . end)
229-
| (if has("resume_from") then .resume_from = namespaced_path(.resume_from; $model; $resume_src_mode) else . end)
230-
| to_entries[]
231-
| "--\(.key) \(.value|tostring)"
232-
' "$CONFIG_FILE" | paste -sd' ' -
213+
local group_idx="$1"
214+
local test_idx="$2"
215+
local model_name="$3"
216+
local test_id="$4"
217+
218+
jq -r --argjson g "$group_idx" --argjson t "$test_idx" --arg model "$model_name" --arg test_id "$test_id" '
219+
def namespaced_path($p; $model; $mode):
220+
if ($p | test("/checkpoint_step_[0-9]+($|/)")) then
221+
($p | capture("^(?<prefix>.*)/(?<step>checkpoint_step_[0-9]+(?:/.*)?)$")) as $m
222+
| ($m.prefix + "/" + $model + "/" + $mode + "/" + $m.step)
223+
else
224+
($p + "/" + $model + "/" + $mode)
225+
end;
226+
227+
.test_groups[$g].tests[$t].args as $args
228+
| (if ($args | has("resume_from")) then "resume" else "no_resume" end) as $run_mode
229+
| (if (($args.resume_from // "") | test("no_resume")) then "no_resume" else "resume" end) as $resume_src_mode
230+
| $args
231+
| (if has("checkpoint_dir") then .checkpoint_dir = namespaced_path(.checkpoint_dir; $model; $run_mode) else . end)
232+
| (if has("resume_from") then .resume_from = namespaced_path(.resume_from; $model; $resume_src_mode) else . end)
233+
| to_entries[]
234+
| "--\(.key) \(.value|tostring)"
235+
' "$CONFIG_FILE" | paste -sd' ' -
233236
}
234237

235238
# Run tests
@@ -268,18 +271,28 @@ for ((id=0; id<num_builds; ++id)); do
268271
log_suffix="_profile"
269272
fi
270273

271-
for ((ti=0; ti<num_tests; ++ti)); do
272-
test_id=$(jq -r ".tests[$ti].id" "$CONFIG_FILE")
273-
gpt2_arg_str="$(args_string_for_test "$ti" "gpt2")"
274-
llama3_arg_str="$(args_string_for_test "$ti" "llama3")"
274+
for ((gi=0; gi<num_groups; ++gi)); do
275+
group_tag=$(jq -r ".test_groups[$gi].tag" "$CONFIG_FILE")
276+
if [[ ${#SELECTED_TAGS[@]} -gt 0 && -z "${SELECTED_TAGS[$group_tag]}" ]]; then
277+
continue
278+
fi
279+
280+
num_tests=$(jq ".test_groups[$gi].tests | length" "$CONFIG_FILE")
281+
echo -e "\033[1;36m[TEST GROUP] tag=${group_tag}, cases=${num_tests}\033[0m"
282+
283+
for ((ti=0; ti<num_tests; ++ti)); do
284+
test_id=$(jq -r ".test_groups[$gi].tests[$ti].id" "$CONFIG_FILE")
285+
gpt2_arg_str="$(args_string_for_test "$gi" "$ti" "gpt2" "$test_id")"
286+
llama3_arg_str="$(args_string_for_test "$gi" "$ti" "llama3" "$test_id")"
275287

276-
# gpt2
277-
gpt2_cmd="${prefix}./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${gpt2_arg_str}"
278-
run_and_log "$gpt2_cmd" "gpt2_${test_id}${log_suffix}" "$profile_flag"
288+
# gpt2
289+
gpt2_cmd="${prefix}./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${gpt2_arg_str}"
290+
run_and_log "$gpt2_cmd" "gpt2_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
279291

280-
# llama3
281-
llama3_cmd="${prefix}./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${llama3_arg_str}"
282-
run_and_log "$llama3_cmd" "llama3_${test_id}${log_suffix}" "$profile_flag"
292+
# llama3
293+
llama3_cmd="${prefix}./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${llama3_arg_str}"
294+
run_and_log "$llama3_cmd" "llama3_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
295+
done
283296
done
284297
done
285298

scripts/test_config.json

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,70 @@
526526
}
527527
}
528528
]
529+
},
530+
{
531+
"tag": "checkpoint",
532+
"tests": [
533+
{
534+
"id": "ckpt_no_resume",
535+
"args": {
536+
"num_iteration": 50,
537+
"save_steps": 10,
538+
"checkpoint_dir": "/data1/ckpt/no_resume"
539+
}
540+
},
541+
{
542+
"id": "ckpt_resume",
543+
"args": {
544+
"num_iteration": 50,
545+
"save_steps": 10,
546+
"resume_from": "/data1/ckpt/no_resume/checkpoint_step_000030",
547+
"checkpoint_dir": "/data1/ckpt/base_resume"
548+
}
549+
},
550+
{
551+
"id": "ckpt_bf16",
552+
"args": {
553+
"dtype": "bfloat16",
554+
"num_iteration": 50,
555+
"save_steps": 10,
556+
"checkpoint_dir": "/data1/ckpt/bf16_no_resume"
557+
}
558+
},
559+
{
560+
"id": "ckpt_bf16_resume",
561+
"args": {
562+
"dtype": "bfloat16",
563+
"num_iteration": 50,
564+
"save_steps": 10,
565+
"resume_from": "/data1/ckpt/bf16_no_resume/checkpoint_step_000030",
566+
"checkpoint_dir": "/data1/ckpt/bf16_resume"
567+
}
568+
},
569+
{
570+
"id": "ckpt_lora",
571+
"args": {
572+
"num_iteration": 50,
573+
"save_steps": 10,
574+
"lora_rank": 8,
575+
"lora_alpha": 16.0,
576+
"lora_target_modules": "c_attn,attn.c_proj",
577+
"checkpoint_dir": "/data1/ckpt/lora_no_resume"
578+
}
579+
},
580+
{
581+
"id": "ckpt_lora_resume",
582+
"args": {
583+
"num_iteration": 50,
584+
"save_steps": 10,
585+
"lora_rank": 8,
586+
"lora_alpha": 16.0,
587+
"lora_target_modules": "c_attn,attn.c_proj",
588+
"resume_from": "/data1/ckpt/lora_no_resume/checkpoint_step_000030",
589+
"checkpoint_dir": "/data1/ckpt/lora_resume"
590+
}
591+
}
592+
]
529593
}
530594
]
531595
}

scripts/test_resume.json

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)