File tree Expand file tree Collapse file tree 3 files changed +28
-4
lines changed
src/sagemaker/modules/train
tests/unit/sagemaker/modules/train/sm_recipes Expand file tree Collapse file tree 3 files changed +28
-4
lines changed Original file line number Diff line number Diff line change @@ -649,7 +649,7 @@ def _create_training_job_args(
649649 key_prefix = input_data_key_prefix ,
650650 )
651651 final_input_data_config .append (recipe_channel )
652- if self ._is_nova_recipe :
652+ if self ._is_nova_recipe or self . _is_llmft_recipe :
653653 self .hyperparameters .update (
654654 {"sagemaker_recipe_local_path" : SM_RECIPE_CONTAINER_PATH }
655655 )
Original file line number Diff line number Diff line change @@ -289,7 +289,7 @@ def _is_llmft_recipe(
289289
290290 A recipe is considered a LLMFT recipe if it meets the following conditions:
291291 1. Having a run section
292- 2. The model_type in run is llmft
292+ 2. The model_type in run is llm_finetuning_aws or verl
293293 3. Having a training_config section
294294
295295 Args:
@@ -299,8 +299,10 @@ def _is_llmft_recipe(
299299 bool: True if the recipe is a LLMFT recipe, False otherwise
300300 """
301301 run_config = recipe .get ("run" , {})
302- has_llmft_model = run_config .get ("model_type" , "" ).lower () == "llm_finetuning_aws"
303- return bool (has_llmft_model ) and bool (recipe .get ("training_config" ))
302+ model_type = run_config .get ("model_type" , "" ).lower ()
303+ has_llmft_model = model_type == "llm_finetuning_aws"
304+ has_verl_model = model_type == "verl"
305+ return (bool (has_llmft_model ) or bool (has_verl_model )) and bool (recipe .get ("training_config" ))
304306
305307
306308def _get_args_from_nova_recipe (
Original file line number Diff line number Diff line change @@ -545,12 +545,34 @@ def test_get_args_from_nova_recipe_with_evaluation(test_case):
545545 },
546546 "is_llmft" : False ,
547547 },
548+ {
549+ "recipe" : {
550+ "run" : {
551+ "name" : "verl-grpo-llama" ,
552+ "model_type" : "verl" ,
553+ },
554+ "trainer" : {"num_nodes" : "1" },
555+ "training_config" : {"trainer" : {"total_epochs" : 2 }},
556+ },
557+ "is_llmft" : True ,
558+ },
559+ {
560+ "recipe" : {
561+ "run" : {
562+ "name" : "verl-grpo-llama" ,
563+ "model_type" : "verl" ,
564+ },
565+ },
566+ "is_llmft" : False ,
567+ },
548568 ],
549569 ids = [
550570 "llmft_model" ,
551571 "llmft_model_subtype" ,
552572 "llmft_missing_training_config" ,
553573 "non_llmft_model" ,
574+ "verl_model" ,
575+ "verl_missing_training_config" ,
554576 ],
555577)
556578def test_is_llmft_recipe (test_case ):
You can’t perform that action at this time.
0 commit comments