-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Refactor training framework with modular architecture #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
965545f
d836949
f7c9a3b
a3d0936
30df0e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,252 @@ | ||
| #!/usr/bin/env python3 | ||
| import json | ||
| import subprocess | ||
| import argparse | ||
| import os | ||
| import re | ||
| import time | ||
| import math | ||
| import threading | ||
| import uuid | ||
| import torch | ||
|
|
||
# Shared peak-GPU-memory accumulator (MiB). A one-element list so the
# monitor thread can mutate it in place without a `global` statement.
gpu_peak_mem_mb = [0]


def monitor_gpu_memory(proc, poll_interval=0.5):
    """Poll nvidia-smi while *proc* is running and record peak memory use.

    Repeatedly queries ``memory.used`` for every visible GPU and folds the
    largest MiB value seen into the module-level ``gpu_peak_mem_mb[0]``.
    Returns once ``proc.poll()`` reports the process has exited.
    """
    query = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]
    while proc.poll() is None:
        try:
            raw = subprocess.check_output(query, text=True, stderr=subprocess.DEVNULL)
            readings = [int(tok.strip()) for tok in raw.strip().splitlines() if tok.strip()]
            if readings:
                gpu_peak_mem_mb[0] = max(gpu_peak_mem_mb[0], *readings)
        except Exception:
            pass  # transient nvidia-smi failures (or no nvidia-smi at all) are ignored
        time.sleep(poll_interval)
|
|
||
# CLI
def parse_args():
    """Parse the command line; ``--config`` (path to a JSON config) is required."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    return parser.parse_args()
|
|
||
def _write_series(path, column, values):
    """Write a per-iteration metric CSV with header ``iteration,<column>``."""
    with open(path, "w") as f:
        f.write(f"iteration,{column}\n")
        for i, v in enumerate(values, start=1):
            f.write(f"{i},{v}\n")


def main():
    """Launch a Megatron-LM pretraining run described by a JSON config.

    Builds a ``torchrun`` + ``pretrain_gpt.py`` command from the config's
    ``train_args``, streams the training log while parsing loss and
    per-iteration timing, monitors peak GPU memory in a background thread,
    and writes loss/ppl/throughput CSVs plus a summary result JSON under
    ``./train``.
    """
    args = parse_args()

    # Load the JSON experiment config.
    with open(args.config, "r") as f:
        cfg = json.load(f)

    train = cfg["config"].get("train_args", {})
    parallel = train.get("parallel", {})

    # Parallelism degrees (data / tensor / pipeline / sequence).
    dp = parallel.get("dp", 1)
    tp = parallel.get("tp", 1)
    pp = parallel.get("pp", {}).get("value", 1)
    sp = parallel.get("sp", 0)

    # Training hyper-parameters with defaults.
    mbs = train.get("mbs", 1)
    gbs = train.get("gbs", 1)
    seq = train.get("seq_len", 128)
    lr = train.get("lr", 0.00015)
    step = train.get("step", 10)
    num_layers = train.get("num_layers", 2)
    hidden_size = train.get("hidden_size", 512)
    num_attention_heads = train.get("num_attention_heads", 8)
    max_position_embeddings = train.get("max_position_embeddings", seq)
    vocab_size = train.get("vocab_size", 128256)

    # Generate a unique run_id for all output artifacts.
    random_uuid = str(uuid.uuid4())
    run_id = f"train.megatron.SFT.{random_uuid}"
    print(f"Generated run_id: {run_id}")

    # Determine world size / nproc_per_node, capped by the GPUs we can see.
    available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    desired_world = max(1, dp * tp * pp)
    if available_gpus > 0:
        nproc_per_node = min(desired_world, available_gpus)
    else:
        nproc_per_node = max(1, desired_world)

    # FIX: the rendezvous port is no longer hard-coded; it can be overridden
    # via the MASTER_PORT environment variable (default keeps the old 29501).
    master_port = os.environ.get("MASTER_PORT", "29501")

    torchrun_cmd = [
        "torchrun",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
    ]

    megatron_args = [
        # NOTE(review): pretrain_gpt.py is the GPT-family entry point only;
        # other model families would need a different Megatron script.
        "pretrain_gpt.py",
        f"--tensor-model-parallel-size={tp}",
        f"--pipeline-model-parallel-size={pp}",
        f"--micro-batch-size={mbs}",
        f"--global-batch-size={gbs}",
        f"--seq-length={seq}",
        f"--lr={lr}",
        f"--train-iters={step}",
        f"--num-layers={num_layers}",
        f"--hidden-size={hidden_size}",
        f"--num-attention-heads={num_attention_heads}",
        f"--max-position-embeddings={max_position_embeddings}",
        f"--vocab-size={vocab_size}",
        # NOTE(review): always trains on mock data — a dataset specified in
        # the config is currently ignored. TODO: wire train_dataset through.
        "--mock-data",
        "--tokenizer-type", "NullTokenizer",
        "--transformer-impl", "local",
        "--bf16",
        "--no-gradient-accumulation-fusion",
        "--no-persist-layer-norm",
        "--log-interval", "1",
        "--log-throughput",
    ]

    if sp == 1:
        megatron_args.append("--sequence-parallel")

    cmd = torchrun_cmd + megatron_args
    print("Launching:", " ".join(cmd))

    # Output paths, all keyed by run_id.
    output_dir = "./train"
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, f"{run_id}_train.log")
    loss_csv = os.path.join(output_dir, f"{run_id}_train_loss.csv")
    ppl_csv = os.path.join(output_dir, f"{run_id}_train_ppl.csv")
    throughput_csv = os.path.join(output_dir, f"{run_id}_train_throughput.csv")
    result_json = os.path.join(output_dir, f"{run_id}_result.json")

    # Regex patterns for the Megatron training log.
    loss_pattern = re.compile(r"lm loss:\s*([+\-]?\d+(?:\.\d+)?(?:[Ee][+\-]?\d+)?)", re.IGNORECASE)
    elapsed_pattern = re.compile(r"elapsed time per iteration \(ms\):\s*([0-9]*\.?[0-9]+)")

    losses = []
    throughputs = []

    # Launch the training process, merging stderr into stdout so a single
    # reader captures everything; line-buffered for live streaming.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)

    # Start the GPU-memory monitor thread (daemon: never blocks interpreter exit).
    monitor_thread = threading.Thread(target=monitor_gpu_memory, args=(process,), daemon=True)
    monitor_thread.start()

    # Read stdout line-by-line: tee to console + log file, and parse metrics.
    with open(log_file, "w") as flog:
        for line in process.stdout:
            print(line, end="")
            flog.write(line)

            # NOTE(review): every logged iteration is recorded, including any
            # warmup iterations — TODO confirm whether warmup should be skipped.
            m = loss_pattern.search(line)
            if m:
                try:
                    losses.append(float(m.group(1)))
                except ValueError:
                    pass  # malformed number in the log line; skip it

            me = elapsed_pattern.search(line)
            if me:
                try:
                    elapsed_ms = float(me.group(1))
                    # FIX: "elapsed time per iteration" covers one full training
                    # iteration, i.e. one *global* batch (gbs * seq tokens), and
                    # the reported unit is tokens/s/gpu — so normalize by the
                    # number of local ranks. The old mbs*seq value matched the
                    # unit only in the pure single-GPU data-parallel case.
                    tokens_per_iter = gbs * seq
                    if elapsed_ms > 0:
                        throughput = tokens_per_iter / (elapsed_ms / 1000.0) / nproc_per_node
                    else:
                        throughput = 0.0
                    throughputs.append(throughput)
                except ValueError:
                    pass  # malformed number in the log line; skip it

    # Wait for the process to end; the monitor thread exits once poll() is set.
    process.wait()
    monitor_thread.join()

    peak_memory_gb = gpu_peak_mem_mb[0] / 1024.0

    # Derive perplexity from loss (exp(loss)), saturating on overflow.
    ppls = []
    for loss in losses:
        try:
            ppls.append(float(math.exp(loss)))
        except OverflowError:
            ppls.append(float("inf"))

    _write_series(loss_csv, "loss", losses)
    _write_series(ppl_csv, "ppl", ppls)
    _write_series(throughput_csv, "throughput", throughputs)

    # Assemble the result JSON (config echo + metric descriptors).
    result = {
        "config": {
            "command": " ".join(cmd),
            "model": cfg.get("config", {}).get("model", "Megatron-GPT"),
            "model_config": cfg.get("config", {}).get("model_config", ""),
            "train_dataset": cfg.get("config", {}).get("train_dataset", "mock"),
            "validation_dataset": cfg.get("config", {}).get("validation_dataset", None),
            "test_dataset": cfg.get("config", {}).get("test_dataset", None),
            "train_args": train,
            "timeout_ms": train.get("timeout_ms", 10000),
            "warmup_iterations": train.get("warmup_iterations", 100),
            "measured_iterations": train.get("measured_iterations", step)
        },
        "metrics": [
            {
                "name": "train.throughput",
                "type": "timeseries",
                "raw_data_url": throughput_csv,
                "unit": "tokens/s/gpu"
            },
            {
                "name": "train.peak_memory_usage",
                "type": "scalar",
                "value": peak_memory_gb,
                "unit": "GB"
            },
            {
                "name": "train.loss",
                "type": "timeseries",
                "raw_data_url": loss_csv,
                "unit": ""
            },
            {
                "name": "train.ppl",
                "type": "timeseries",
                "raw_data_url": ppl_csv,
                "unit": None
            }
        ]
    }

    with open(result_json, "w") as f:
        json.dump(result, f, indent=2)

    print(f"\nResult JSON written to {result_json}")
    print("Log written to", log_file)
    print("CSV files:", loss_csv, ppl_csv, throughput_csv)
    print(f"Peak GPU memory (GiB): {peak_memory_gb:.6f}")


if __name__ == "__main__":
    main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果pp的类型不是 G-pipe,应该怎么办呢?