Skip to content

Commit f1171d4

Browse files
author
Roja Reddy Sareddy
committed
Nova training support
1 parent 33bf993 commit f1171d4

File tree

5 files changed

+279
-15
lines changed

5 files changed

+279
-15
lines changed

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@
6262
"mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
6363
"amazon.nova-pro-v1:0": ["us-east-1"]
6464
}
65+
66+
SM_RECIPE = "recipe"
67+
SM_RECIPE_YAML = "recipe.yaml"
68+
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@
8383
SM_CODE_CONTAINER_PATH,
8484
SM_DRIVERS,
8585
SM_DRIVERS_LOCAL_PATH,
86+
SM_RECIPE,
87+
SM_RECIPE_YAML,
88+
SM_RECIPE_CONTAINER_PATH,
8689
TRAIN_SCRIPT,
8790
DEFAULT_CONTAINER_ENTRYPOINT,
8891
DEFAULT_CONTAINER_ARGUMENTS,
@@ -100,7 +103,12 @@
100103
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
101104
from sagemaker.core.telemetry.constants import Feature
102105
from sagemaker.train import logger
103-
from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
106+
from sagemaker.train.sm_recipes.utils import (
107+
_get_args_from_recipe,
108+
_determine_device_type,
109+
_is_nova_recipe,
110+
_load_base_recipe,
111+
)
104112

105113
from sagemaker.core.jumpstart.configs import JumpStartConfig
106114
from sagemaker.core.jumpstart.document import get_hub_content_and_document
@@ -249,6 +257,7 @@ class ModelTrainer(BaseModel):
249257
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
250258
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
251259

260+
_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
252261
# Private Attributes for Recipes
253262
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
254263

@@ -573,6 +582,22 @@ def _create_training_job_args(
573582

574583
final_input_data_config = list(existing_channels.values()) + new_channels
575584

585+
if self._is_nova_recipe:
586+
for input_data in final_input_data_config:
587+
if input_data.channel_name == SM_RECIPE:
588+
raise ValueError(
589+
"Cannot use reserved channel name 'recipe' as an input channel name "
590+
" for Nova Recipe"
591+
)
592+
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
593+
recipe_channel = self.create_input_data_channel(
594+
channel_name=SM_RECIPE,
595+
data_source=recipe_file_path,
596+
key_prefix=input_data_key_prefix,
597+
)
598+
final_input_data_config.append(recipe_channel)
599+
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
600+
576601
if final_input_data_config:
577602
final_input_data_config = self._get_input_data_config(
578603
final_input_data_config, input_data_key_prefix
@@ -1039,6 +1064,7 @@ def from_recipe(
10391064
checkpoint_config: Optional[shapes.CheckpointConfig] = None,
10401065
training_input_mode: Optional[str] = "File",
10411066
environment: Optional[Dict[str, str]] = None,
1067+
hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
10421068
tags: Optional[List[Tag]] = None,
10431069
sagemaker_session: Optional[Session] = None,
10441070
role: Optional[str] = None,
@@ -1136,12 +1162,19 @@ def from_recipe(
11361162
if compute.instance_type is None:
11371163
raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
11381164
device_type = _determine_device_type(compute.instance_type)
1139-
if device_type == "cpu":
1165+
recipe = _load_base_recipe(
1166+
training_recipe=training_recipe, recipe_overrides=recipe_overrides
1167+
)
1168+
is_nova = _is_nova_recipe(recipe=recipe)
1169+
if device_type == "cpu" and not is_nova:
11401170
raise ValueError(
11411171
"Training recipes are not supported for CPU instances. "
11421172
"Please provide a GPU or Tranium instance type."
11431173
)
11441174

1175+
if training_image is None and is_nova:
1176+
raise ValueError("training_image must be provided when using recipe for Nova.")
1177+
11451178
if training_image_config and training_image is None:
11461179
raise ValueError("training_image must be provided when using training_image_config.")
11471180

@@ -1154,16 +1187,29 @@ def from_recipe(
11541187
# - distributed
11551188
# - compute
11561189
# - hyperparameters
1157-
model_trainer_args, recipe_train_dir = _get_args_from_recipe(
1158-
training_recipe=training_recipe,
1190+
model_trainer_args, tmp_dir = _get_args_from_recipe(
1191+
training_recipe=recipe,
11591192
recipe_overrides=recipe_overrides,
11601193
requirements=requirements,
11611194
compute=compute,
11621195
region_name=sagemaker_session.boto_region_name,
1196+
role=role,
11631197
)
11641198
if training_image is not None:
11651199
model_trainer_args["training_image"] = training_image
11661200

1201+
if hyperparameters and not is_nova:
1202+
logger.warning(
1203+
"Hyperparameters are not supported for general training recipes. "
1204+
+ "Ignoring hyperparameters input."
1205+
)
1206+
if is_nova:
1207+
if hyperparameters and isinstance(hyperparameters, str):
1208+
hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
1209+
model_trainer_args["hyperparameters"].update(hyperparameters)
1210+
elif hyperparameters and isinstance(hyperparameters, dict):
1211+
model_trainer_args["hyperparameters"].update(hyperparameters)
1212+
11671213
model_trainer = cls(
11681214
sagemaker_session=sagemaker_session,
11691215
role=role,
@@ -1180,7 +1226,8 @@ def from_recipe(
11801226
**model_trainer_args,
11811227
)
11821228

1183-
model_trainer._temp_recipe_train_dir = recipe_train_dir
1229+
model_trainer._is_nova_recipe = is_nova
1230+
model_trainer._temp_recipe_train_dir = tmp_dir
11841231
return model_trainer
11851232

11861233
@classmethod

sagemaker-train/src/sagemaker/train/sm_recipes/utils.py

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import shutil
2020
import tempfile
2121
from urllib.request import urlretrieve
22-
from typing import Dict, Any, Optional, Tuple
22+
from typing import Dict, Any, Optional, Tuple, Union
2323

2424
import omegaconf
2525
from omegaconf import OmegaConf, dictconfig
@@ -30,6 +30,7 @@
3030
from sagemaker.train.utils import _run_clone_command_silent
3131
from sagemaker.train.configs import Compute, SourceCode
3232
from sagemaker.train.distributed import Torchrun, SMP
33+
from sagemaker.train.constants import SM_RECIPE_YAML
3334

3435

3536
def _try_resolve_recipe(recipe, key=None):
@@ -86,6 +87,8 @@ def _load_base_recipe(
8687
)
8788
else:
8889
recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
90+
if training_recipes_cfg is None:
91+
training_recipes_cfg = _load_recipes_cfg()
8992

9093
launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get(
9194
"launcher_repo"
@@ -149,7 +152,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
149152
def _configure_gpu_args(
150153
training_recipes_cfg: Dict[str, Any],
151154
region_name: str,
152-
recipe: OmegaConf,
155+
recipe: dictconfig.DictConfig,
153156
recipe_train_dir: tempfile.TemporaryDirectory,
154157
) -> Dict[str, Any]:
155158
"""Configure arguments specific to GPU."""
@@ -234,11 +237,12 @@ def _configure_trainium_args(
234237

235238

236239
def _get_args_from_recipe(
237-
training_recipe: str,
240+
training_recipe: Union[str, dictconfig.DictConfig],
238241
compute: Compute,
239242
region_name: str,
240243
recipe_overrides: Optional[Dict[str, Any]],
241244
requirements: Optional[str],
245+
role: Optional[str] = None,
242246
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
243247
"""Get arguments for ModelTrainer from a training recipe.
244248
@@ -254,8 +258,8 @@ def _get_args_from_recipe(
254258
```
255259
256260
Args:
257-
training_recipe (str):
258-
Name of the training recipe or path to the recipe file.
261+
training_recipe (Union[str, Dict[str, Any]]):
262+
Name of the training recipe or path to the recipe file or loaded recipe Dict.
259263
compute (Compute):
260264
Compute configuration for training.
261265
region_name (str):
@@ -269,7 +273,13 @@ def _get_args_from_recipe(
269273
raise ValueError("Must set `instance_type` in compute when using training recipes.")
270274

271275
training_recipes_cfg = _load_recipes_cfg()
272-
recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
276+
if isinstance(training_recipe, str):
277+
recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
278+
else:
279+
recipe = training_recipe
280+
if _is_nova_recipe(recipe):
281+
args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role)
282+
return args, recipe_local_dir
273283

274284
if "trainer" not in recipe:
275285
raise ValueError("Supplied recipe does not contain required field trainer.")
@@ -283,7 +293,7 @@ def _get_args_from_recipe(
283293
if compute.instance_count is None:
284294
if "num_nodes" not in recipe["trainer"]:
285295
raise ValueError(
286-
"Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
296+
"Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
287297
)
288298
compute.instance_count = recipe["trainer"]["num_nodes"]
289299

@@ -313,7 +323,7 @@ def _get_args_from_recipe(
313323

314324
# Save Final Recipe to source_dir
315325
OmegaConf.save(
316-
config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml")
326+
config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML)
317327
)
318328

319329
# If recipe_requirements is provided, copy it to source_dir
@@ -322,7 +332,7 @@ def _get_args_from_recipe(
322332
args["source_code"].requirements = os.path.basename(requirements)
323333

324334
# Update args with compute and hyperparameters
325-
hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"}
335+
hyperparameters = {"config-path": ".", "config-name": SM_RECIPE_YAML}
326336

327337
# Handle eval custom lambda configuration
328338
if recipe.get("evaluation", {}):
@@ -339,3 +349,111 @@ def _get_args_from_recipe(
339349
)
340350

341351
return args, recipe_train_dir
352+
353+
def _is_nova_recipe(
    recipe: dictconfig.DictConfig,
) -> bool:
    """Check if the recipe is a Nova recipe.

    A recipe is considered a Nova recipe if it meets either of the following conditions:

    1. It has a run section with:
       - A model_type that includes "amazon.nova" (compared case-insensitively)
       - A model_name_or_path field

    OR

    2. It has a training_config section with:
       - A distillation_data field whose value is not None

    A ``run`` / ``training_config`` section or a ``model_type`` value that is missing,
    null, or (for ``model_type``) not a string is treated as absent instead of raising,
    so this predicate can be called safely on any loaded recipe.

    Args:
        recipe (DictConfig): The loaded recipe configuration

    Returns:
        bool: True if the recipe is a Nova recipe, False otherwise
    """
    # ``or {}`` covers both a missing section and a section explicitly set to null.
    run_config = recipe.get("run") or {}
    model_type = run_config.get("model_type")
    has_nova_model = (
        isinstance(model_type, str)
        and "amazon.nova" in model_type.lower()
        and "model_name_or_path" in run_config
    )

    # Check for distillation data
    training_config = recipe.get("training_config") or {}
    has_distillation = training_config.get("distillation_data") is not None

    return has_nova_model or has_distillation
385+
386+
def _get_args_from_nova_recipe(
    recipe: dictconfig.DictConfig,
    compute: Compute,
    role: Optional[str] = None,
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
    """Get arguments for ModelTrainer from a Nova training recipe.

    Builds the hyperparameters derived from the recipe, resolves the recipe, and
    writes the resolved recipe to a fresh temporary directory.

    Note:
        ``compute`` is modified in place: when ``compute.instance_count`` is not set,
        it is filled from ``run.replicas`` in the recipe.

    Args:
        recipe (DictConfig):
            The loaded Nova recipe configuration.
        compute (Compute):
            Compute configuration for training.
        role (Optional[str]):
            Role passed through as the ``role_arn`` hyperparameter. Required when the
            recipe sets ``training_config.distillation_data``.

    Returns:
        Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
            The ModelTrainer arguments (``hyperparameters`` and ``compute``, with
            ``training_image``, ``source_code`` and ``distributed`` set to None) and the
            temporary directory holding the resolved recipe file.

    Raises:
        ValueError: If neither ``compute.instance_count`` nor ``run.replicas`` is set,
            or if a distillation recipe is missing ``role`` or ``kms_key``.
        RuntimeError: If the recipe could not be resolved.
    """
    run_config = recipe.get("run", {})

    replicas = run_config.get("replicas", None)
    if not compute.instance_count and not replicas:
        raise ValueError("Must set ``instance_count`` in compute or ``replicas`` in recipe.")
    compute.instance_count = compute.instance_count or replicas

    hyperparameters: Dict[str, Any] = {}
    args: Dict[str, Any] = {"hyperparameters": hyperparameters}

    # An S3 URI points at model artifacts; anything else is passed as a base model name.
    model_name_or_path = run_config.get("model_name_or_path")
    if model_name_or_path:
        if model_name_or_path.startswith("s3://"):
            hyperparameters["base_model_location"] = model_name_or_path
        else:
            hyperparameters["base_model"] = model_name_or_path

    # Handle distillation configuration
    training_config = recipe.get("training_config", {})
    distillation_data = training_config.get("distillation_data")
    if distillation_data:
        hyperparameters["distillation_data"] = distillation_data
        if not role:
            raise ValueError("Must provide 'role' parameter when using Nova distillation")
        hyperparameters["role_arn"] = role

        kms_key = training_config.get("kms_key")
        if kms_key is None:
            raise ValueError(
                'Nova distillation job recipe requires "kms_key" field in "training_config"'
            )
        hyperparameters["kms_key"] = kms_key

    # Handle eval custom lambda configuration
    if recipe.get("evaluation", {}):
        lambda_arn = recipe.get("processor", {}).get("lambda_arn", "")
        if lambda_arn:
            hyperparameters["eval_lambda_arn"] = lambda_arn

    # Handle reward lambda configuration
    reward_lambda_arn = run_config.get("reward_lambda_arn", "")
    if reward_lambda_arn:
        hyperparameters["reward_lambda_arn"] = reward_lambda_arn

    _register_custom_resolvers()

    # Resolve Final Recipe: try the recipe root first, then the known nested keys.
    final_recipe = _try_resolve_recipe(recipe)
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "recipes")
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "training")
    if final_recipe is None:
        raise RuntimeError("Could not resolve provided recipe.")

    # Save Final Recipe to tmp dir; the caller receives the TemporaryDirectory object.
    recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_")
    final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML)
    OmegaConf.save(config=final_recipe, f=final_recipe_path)

    args.update(
        {
            "compute": compute,
            "training_image": None,
            "source_code": None,
            "distributed": None,
        }
    )
    return args, recipe_local_dir

sagemaker-train/tests/unit/train/sm_recipes/test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
_configure_gpu_args,
2828
_configure_trainium_args,
2929
_get_trainining_recipe_gpu_model_name_and_script,
30+
_is_nova_recipe,
31+
_get_args_from_nova_recipe,
3032
)
3133
from sagemaker.train.utils import _run_clone_command_silent
3234
from sagemaker.train.configs import Compute

0 commit comments

Comments
 (0)