Skip to content

Commit d246064

Browse files
author
Roja Reddy Sareddy
committed
Nova, LLMFT training support
1 parent 34f53aa commit d246064

File tree

4 files changed

+334
-8
lines changed

4 files changed

+334
-8
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
_get_args_from_recipe,
108108
_determine_device_type,
109109
_is_nova_recipe,
110+
_is_llmft_recipe,
110111
_load_base_recipe,
111112
)
112113

@@ -258,6 +259,7 @@ class ModelTrainer(BaseModel):
258259
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
259260

260261
_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
262+
_is_llmft_recipe: Optional[bool] = PrivateAttr(default=None)
261263
# Private Attributes for Recipes
262264
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
263265

@@ -582,12 +584,12 @@ def _create_training_job_args(
582584

583585
final_input_data_config = list(existing_channels.values()) + new_channels
584586

585-
if self._is_nova_recipe:
587+
if self._is_nova_recipe or self._is_llmft_recipe:
586588
for input_data in final_input_data_config:
587589
if input_data.channel_name == SM_RECIPE:
588590
raise ValueError(
589591
"Cannot use reserved channel name 'recipe' as an input channel name "
590-
" for Nova Recipe"
592+
" for Nova or LLMFT Recipe"
591593
)
592594
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
593595
recipe_channel = self.create_input_data_channel(
@@ -596,7 +598,8 @@ def _create_training_job_args(
596598
key_prefix=input_data_key_prefix,
597599
)
598600
final_input_data_config.append(recipe_channel)
599-
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
601+
if self._is_nova_recipe or self._is_llmft_recipe:
602+
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
600603

601604
if final_input_data_config:
602605
final_input_data_config = self._get_input_data_config(
@@ -1166,14 +1169,15 @@ def from_recipe(
11661169
training_recipe=training_recipe, recipe_overrides=recipe_overrides
11671170
)
11681171
is_nova = _is_nova_recipe(recipe=recipe)
1169-
if device_type == "cpu" and not is_nova:
1172+
is_llmft = _is_llmft_recipe(recipe=recipe)
1173+
if device_type == "cpu" and not (is_nova or is_llmft):
11701174
raise ValueError(
11711175
"Training recipes are not supported for CPU instances. "
11721176
"Please provide a GPU or Trainium instance type."
11731177
)
11741178

1175-
if training_image is None and is_nova:
1176-
raise ValueError("training_image must be provided when using recipe for Nova.")
1179+
if training_image is None and (is_nova or is_llmft):
1180+
raise ValueError("training_image must be provided when using recipe for Nova or LLMFT")
11771181

11781182
if training_image_config and training_image is None:
11791183
raise ValueError("training_image must be provided when using training_image_config.")
@@ -1200,7 +1204,7 @@ def from_recipe(
12001204

12011205
if hyperparameters and not is_nova:
12021206
logger.warning(
1203-
"Hyperparameters are not supported for general training recipes. "
1207+
"Hyperparameters are not supported for general and LLMFT training recipes. "
12041208
+ "Ignoring hyperparameters input."
12051209
)
12061210
if is_nova:
@@ -1227,6 +1231,7 @@ def from_recipe(
12271231
)
12281232

12291233
model_trainer._is_nova_recipe = is_nova
1234+
model_trainer._is_llmft_recipe = is_llmft
12301235
model_trainer._temp_recipe_train_dir = tmp_dir
12311236
return model_trainer
12321237

sagemaker-train/src/sagemaker/train/sm_recipes/utils.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def _get_args_from_recipe(
280280
if _is_nova_recipe(recipe):
281281
args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role)
282282
return args, recipe_local_dir
283+
if _is_llmft_recipe(recipe):
284+
args, recipe_local_dir = _get_args_from_llmft_recipe(recipe, compute)
285+
return args, recipe_local_dir
283286

284287
if "trainer" not in recipe:
285288
raise ValueError("Supplied recipe does not contain required field trainer.")
@@ -456,4 +459,74 @@ def _get_args_from_nova_recipe(
456459
"distributed": None,
457460
}
458461
)
459-
return args, recipe_local_dir
462+
return args, recipe_local_dir
463+
464+
def _resolve_final_recipe(recipe: dictconfig.DictConfig):
465+
"""Resolve final recipe."""
466+
final_recipe = _try_resolve_recipe(recipe)
467+
if final_recipe is None:
468+
final_recipe = _try_resolve_recipe(recipe, "recipes")
469+
if final_recipe is None:
470+
final_recipe = _try_resolve_recipe(recipe, "training")
471+
if final_recipe is None:
472+
raise RuntimeError("Could not resolve provided recipe.")
473+
474+
return final_recipe
475+
476+
def _is_llmft_recipe(
477+
recipe: dictconfig.DictConfig,
478+
) -> bool:
479+
"""Check if the recipe is an LLMFT recipe.
480+
481+
A recipe is considered an LLMFT recipe if it meets the following conditions:
482+
1. Having a run section
483+
2. The model_type in run is llm_finetuning_aws or verl
484+
3. Having a training_config section
485+
486+
Args:
487+
recipe (DictConfig): The loaded recipe configuration
488+
489+
Returns:
490+
bool: True if the recipe is a LLMFT recipe, False otherwise
491+
"""
492+
run_config = recipe.get("run", {})
493+
model_type = run_config.get("model_type", "").lower()
494+
has_llmft_model = model_type == "llm_finetuning_aws"
495+
has_verl_model = model_type == "verl"
496+
return (bool(has_llmft_model) or bool(has_verl_model)) and bool(recipe.get("training_config"))
497+
498+
def _get_args_from_llmft_recipe(
499+
recipe: dictconfig.DictConfig,
500+
compute: Compute,
501+
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
502+
503+
if not compute.instance_count and not recipe.get("trainer", {}).get("num_nodes", None):
504+
raise ValueError(
505+
"Must set ``instance_count`` in compute or ``num_nodes`` in trainer in recipe."
506+
)
507+
if compute.instance_count and recipe.get("trainer", {}).get("num_nodes", None) is not None:
508+
logger.warning(
509+
f"Using Compute to set instance_count:\n{compute}."
510+
"\nIgnoring trainer -> num_nodes in recipe."
511+
)
512+
compute.instance_count = compute.instance_count or recipe.get("trainer", {}).get("num_nodes")
513+
514+
args = dict()
515+
516+
_register_custom_resolvers()
517+
final_recipe = _resolve_final_recipe(recipe)
518+
519+
# Save Final Recipe to tmp dir
520+
recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_")
521+
final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML)
522+
OmegaConf.save(config=final_recipe, f=final_recipe_path)
523+
524+
args.update(
525+
{
526+
"compute": compute,
527+
"training_image": None,
528+
"source_code": None,
529+
"distributed": None,
530+
}
531+
)
532+
return args, recipe_local_dir

sagemaker-train/tests/unit/train/sm_recipes/test_utils.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest.mock import patch, MagicMock
1818

1919
import yaml
20+
from omegaconf import OmegaConf
2021
from urllib.request import urlretrieve
2122
from tempfile import NamedTemporaryFile
2223

@@ -28,7 +29,9 @@
2829
_configure_trainium_args,
2930
_get_trainining_recipe_gpu_model_name_and_script,
3031
_is_nova_recipe,
32+
_is_llmft_recipe,
3133
_get_args_from_nova_recipe,
34+
_get_args_from_llmft_recipe,
3235
)
3336
from sagemaker.train.utils import _run_clone_command_silent
3437
from sagemaker.train.configs import Compute
@@ -272,3 +275,158 @@ def test_get_args_from_recipe_with_evaluation(temporary_recipe):
272275
assert args["hyperparameters"]["lambda_arn"] == "arn:aws:lambda:us-east-1:123456789012:function:MyFunc"
273276
finally:
274277
os.unlink(recipe_path)
278+
279+
@pytest.mark.parametrize(
280+
"test_case",
281+
[
282+
{
283+
"recipe": {
284+
"run": {
285+
"name": "dummy-model",
286+
"model_type": "llm_finetuning_aws",
287+
},
288+
"trainer": {"num_nodes": "12"},
289+
"training_config": {"model_save_name": "xyz"},
290+
},
291+
"is_llmft": True,
292+
},
293+
{
294+
"recipe": {
295+
"run": {
296+
"name": "dummy-model",
297+
"model_type": "llm_finetuning_aws",
298+
},
299+
"training_config": {"model_save_name": "xyz"},
300+
},
301+
"is_llmft": True,
302+
},
303+
{
304+
"recipe": {
305+
"run": {
306+
"name": "dummy-model",
307+
"model_type": "llm_finetuning_aws",
308+
},
309+
},
310+
"is_llmft": False,
311+
},
312+
{
313+
"recipe": {
314+
"run": {
315+
"name": "dummy-model",
316+
"model_type": "xyz",
317+
},
318+
"training_config": {"model_save_name": "xyz"},
319+
},
320+
"is_llmft": False,
321+
},
322+
{
323+
"recipe": {
324+
"run": {
325+
"name": "verl-grpo-llama",
326+
"model_type": "verl",
327+
},
328+
"trainer": {"num_nodes": "1"},
329+
"training_config": {"trainer": {"total_epochs": 2}},
330+
},
331+
"is_llmft": True,
332+
},
333+
{
334+
"recipe": {
335+
"run": {
336+
"name": "verl-grpo-llama",
337+
"model_type": "verl",
338+
},
339+
},
340+
"is_llmft": False,
341+
},
342+
],
343+
ids=[
344+
"llmft_model",
345+
"llmft_model_subtype",
346+
"llmft_missing_training_config",
347+
"non_llmft_model",
348+
"verl_model",
349+
"verl_missing_training_config",
350+
],
351+
)
352+
def test_is_llmft_recipe(test_case):
353+
recipe = OmegaConf.create(test_case["recipe"])
354+
is_llmft = _is_llmft_recipe(recipe)
355+
assert is_llmft == test_case["is_llmft"]
356+
357+
358+
@patch("sagemaker.train.sm_recipes.utils._get_args_from_llmft_recipe")
359+
def test_get_args_from_recipe_with_llmft_and_role(mock_get_args_from_llmft_recipe):
360+
# Set up mock return value
361+
mock_args = {}
362+
mock_dir = MagicMock()
363+
mock_get_args_from_llmft_recipe.return_value = (mock_args, mock_dir)
364+
365+
recipe = {
366+
"run": {
367+
"name": "dummy-model",
368+
"model_type": "llm_finetuning_aws",
369+
},
370+
"trainer": {"num_nodes": "12"},
371+
"training_config": {"model_save_name": "xyz"},
372+
}
373+
compute = Compute(instance_type="ml.g5.xlarge")
374+
role = "arn:aws:iam::123456789012:role/SageMakerRole"
375+
376+
# Mock the LLMFT recipe detection to return True
377+
with patch("sagemaker.train.sm_recipes.utils._is_llmft_recipe", return_value=True):
378+
_get_args_from_recipe(
379+
training_recipe=recipe,
380+
compute=compute,
381+
region_name="us-west-2",
382+
recipe_overrides=None,
383+
requirements=None,
384+
role=role,
385+
)
386+
387+
# Verify _get_args_from_llmft_recipe was called
388+
mock_get_args_from_llmft_recipe.assert_called_once_with(recipe, compute)
389+
390+
391+
@pytest.mark.parametrize(
392+
"test_case",
393+
[
394+
{
395+
"recipe": {
396+
"run": {
397+
"name": "dummy-model",
398+
"model_type": "llm_finetuning_aws",
399+
},
400+
"trainer": {"num_nodes": "12"},
401+
"training_config": {"model_save_name": "xyz"},
402+
},
403+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
404+
"expected_args": {
405+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
406+
"training_image": None,
407+
"source_code": None,
408+
"distributed": None,
409+
},
410+
},
411+
{
412+
"recipe": {
413+
"run": {
414+
"name": "dummy-model",
415+
"model_type": "llm_finetuning_aws",
416+
},
417+
"training_config": {"model_save_name": "xyz"},
418+
},
419+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
420+
"expected_args": {
421+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
422+
"training_image": None,
423+
"source_code": None,
424+
"distributed": None,
425+
},
426+
},
427+
],
428+
)
429+
def test_get_args_from_llmft_recipe(test_case):
430+
recipe = OmegaConf.create(test_case["recipe"])
431+
args, _ = _get_args_from_llmft_recipe(recipe=recipe, compute=test_case["compute"])
432+
assert args == test_case["expected_args"]

0 commit comments

Comments
 (0)