@@ -26,9 +26,13 @@ def get_cmd(config):
2626 )
2727 config ["work_dir" ] = work_dir
2828
29+ # this patch is for torch 2.9.1 Conv3d memory issue fix
30+ cudnn_patch = """TORCH_VERSION=$(python -c "import torch;print(torch.__version__.split('+')[0])"); if [[ $TORCH_VERSION == "2.9.1" ]]; then pip install nvidia-cudnn-cu12==9.15.1.9; fi; """
31+
2932 if train_type == "sft" :
3033 command = (
3134 f"cd { current_dir } ; pwd; pip install -e .[all]; pip install more-itertools; export GITHUB_RUN_ID={ config .get ('run_id' )} ; export WORK_DIR={ work_dir } ; "
35+ + cudnn_patch
3236 + f"torchrun --nproc-per-node { nproc_per_node } --master_addr=${{MASTER_ADDR}} --master_port=${{MASTER_PORT}} --nnodes=${{WORLD_SIZE}} --node_rank=${{RANK}} "
3337 + f"xtuner/v1/train/cli/{ train_type } .py"
3438 )
@@ -52,6 +56,7 @@ def get_cmd(config):
5256 infer_type = config .get ("parameters" , {}).get ("infer_backend" , "lmdeploy" )
5357 command = (
5458 f"cd { current_dir } ; pwd; pip install -e .[all]; export GITHUB_RUN_ID={ config .get ('run_id' )} ; export WORK_DIR={ work_dir } ; "
59+ + cudnn_patch
5560 + f"bash -x examples/v1/scripts/run_rl.sh { config_path } { infer_type } ${{MODEL_PATH}} ${{DATA_PATH}} ${{EVAL_DATA_PATH}}"
5661 )
5762 return command , config
@@ -64,11 +69,11 @@ def validate(config):
6469 config .get ("base_path" ).get ("base_baseline_path" ), config .get ("assert_info" , {}).get ("base_metric" , None )
6570 )
6671 train_type = config .get ("type" )
67- if train_type == ' sft' :
72+ if train_type == " sft" :
6873 cur_path = os .path .join (get_latest_subdir (work_dir ), "logs/exp_tracking/rank0/tracker.jsonl" )
6974 check_metrics = config .get ("assert_info" , {}).get ("check_metrics" , {})
7075 return check_result (config ["case_name" ], base_path , cur_path , check_metrics )
71- elif train_type == 'rl' :
76+ elif train_type == "rl" :
7277 cur_path = os .path .join (get_latest_subdir (work_dir ), "exp_tracking/tracker.jsonl" )
7378 check_metrics = config .get ("assert_info" , {})
7479 return check_rl_result (config ["case_name" ], base_path , cur_path , check_metrics )
@@ -88,7 +93,9 @@ def post_action(config=None):
8893
8994
9095def get_latest_subdir (work_dir ):
91- dirs = [d for d in os .listdir (work_dir ) if os .path .isdir (os .path .join (work_dir , d )) and len (d ) == 14 and d .isdigit ()]
96+ dirs = [
97+ d for d in os .listdir (work_dir ) if os .path .isdir (os .path .join (work_dir , d )) and len (d ) == 14 and d .isdigit ()
98+ ]
9299
93100 if not dirs :
94101 return None
0 commit comments