Skip to content

Commit 8f44fb1

Browse files
author
Harsh Thakkar
committed
Adding verl support
1 parent 00124a9 commit 8f44fb1

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def _create_training_job_args(
649649
key_prefix=input_data_key_prefix,
650650
)
651651
final_input_data_config.append(recipe_channel)
652-
if self._is_nova_recipe:
652+
if self._is_nova_recipe or self._is_llmft_recipe:
653653
self.hyperparameters.update(
654654
{"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}
655655
)

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def _is_llmft_recipe(
289289
290290
A recipe is considered an LLMFT recipe if it meets all of the following conditions:
291291
1. Having a run section
292-
2. The model_type in run is llmft
292+
2. The model_type in run is llm_finetuning_aws or verl (compared case-insensitively)
293293
3. Having a training_config section
294294
295295
Args:
@@ -299,8 +299,10 @@ def _is_llmft_recipe(
299299
bool: True if the recipe is an LLMFT recipe, False otherwise
300300
"""
301301
run_config = recipe.get("run", {})
302-
has_llmft_model = run_config.get("model_type", "").lower() == "llm_finetuning_aws"
303-
return bool(has_llmft_model) and bool(recipe.get("training_config"))
302+
model_type = run_config.get("model_type", "").lower()
303+
has_llmft_model = model_type == "llm_finetuning_aws"
304+
has_verl_model = model_type == "verl"
305+
return (bool(has_llmft_model) or bool(has_verl_model)) and bool(recipe.get("training_config"))
304306

305307

306308
def _get_args_from_nova_recipe(

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,34 @@ def test_get_args_from_nova_recipe_with_evaluation(test_case):
545545
},
546546
"is_llmft": False,
547547
},
548+
{
549+
"recipe": {
550+
"run": {
551+
"name": "verl-grpo-llama",
552+
"model_type": "verl",
553+
},
554+
"trainer": {"num_nodes": "1"},
555+
"training_config": {"trainer": {"total_epochs": 2}},
556+
},
557+
"is_llmft": True,
558+
},
559+
{
560+
"recipe": {
561+
"run": {
562+
"name": "verl-grpo-llama",
563+
"model_type": "verl",
564+
},
565+
},
566+
"is_llmft": False,
567+
},
548568
],
549569
ids=[
550570
"llmft_model",
551571
"llmft_model_subtype",
552572
"llmft_missing_training_config",
553573
"non_llmft_model",
574+
"verl_model",
575+
"verl_missing_training_config",
554576
],
555577
)
556578
def test_is_llmft_recipe(test_case):

0 commit comments

Comments
 (0)