Skip to content

Commit 0c44833

Browse files
committed
ci(e2e): fix torch 2.9.1 cudnn memory issue in e2e test
1 parent e2601f0 commit 0c44833

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

autotest/module/train.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ def get_cmd(config):
2626
)
2727
config["work_dir"] = work_dir
2828

29+
# this patch is for torch 2.9.1 Conv3d memory issue fix
30+
cudnn_patch = """TORCH_VERSION=$(python -c "import torch;print(torch.__version__.split('+')[0])"); if [[ $TORCH_VERSION == "2.9.1" ]]; then pip install nvidia-cudnn-cu12==9.15.1.9; fi; """
31+
2932
if train_type == "sft":
3033
command = (
3134
f"cd {current_dir}; pwd; pip install -e .[all]; pip install more-itertools; export GITHUB_RUN_ID={config.get('run_id')}; export WORK_DIR={work_dir}; "
35+
+ cudnn_patch
3236
+ f"torchrun --nproc-per-node {nproc_per_node} --master_addr=${{MASTER_ADDR}} --master_port=${{MASTER_PORT}} --nnodes=${{WORLD_SIZE}} --node_rank=${{RANK}} "
3337
+ f"xtuner/v1/train/cli/{train_type}.py"
3438
)
@@ -52,6 +56,7 @@ def get_cmd(config):
5256
infer_type = config.get("parameters", {}).get("infer_backend", "lmdeploy")
5357
command = (
5458
f"cd {current_dir}; pwd; pip install -e .[all]; export GITHUB_RUN_ID={config.get('run_id')}; export WORK_DIR={work_dir}; "
59+
+ cudnn_patch
5560
+ f"bash -x examples/v1/scripts/run_rl.sh {config_path} {infer_type} ${{MODEL_PATH}} ${{DATA_PATH}} ${{EVAL_DATA_PATH}}"
5661
)
5762
return command, config
@@ -64,11 +69,11 @@ def validate(config):
6469
config.get("base_path").get("base_baseline_path"), config.get("assert_info", {}).get("base_metric", None)
6570
)
6671
train_type = config.get("type")
67-
if train_type == 'sft':
72+
if train_type == "sft":
6873
cur_path = os.path.join(get_latest_subdir(work_dir), "logs/exp_tracking/rank0/tracker.jsonl")
6974
check_metrics = config.get("assert_info", {}).get("check_metrics", {})
7075
return check_result(config["case_name"], base_path, cur_path, check_metrics)
71-
elif train_type == 'rl':
76+
elif train_type == "rl":
7277
cur_path = os.path.join(get_latest_subdir(work_dir), "exp_tracking/tracker.jsonl")
7378
check_metrics = config.get("assert_info", {})
7479
return check_rl_result(config["case_name"], base_path, cur_path, check_metrics)
@@ -88,7 +93,9 @@ def post_action(config=None):
8893

8994

9095
def get_latest_subdir(work_dir):
91-
dirs = [d for d in os.listdir(work_dir) if os.path.isdir(os.path.join(work_dir, d)) and len(d) == 14 and d.isdigit()]
96+
dirs = [
97+
d for d in os.listdir(work_dir) if os.path.isdir(os.path.join(work_dir, d)) and len(d) == 14 and d.isdigit()
98+
]
9299

93100
if not dirs:
94101
return None

0 commit comments

Comments
 (0)