1919import shutil
2020import tempfile
2121from urllib .request import urlretrieve
22- from typing import Dict , Any , Optional , Tuple
22+ from typing import Dict , Any , Optional , Tuple , Union
2323
2424import omegaconf
2525from omegaconf import OmegaConf , dictconfig
3030from sagemaker .train .utils import _run_clone_command_silent
3131from sagemaker .train .configs import Compute , SourceCode
3232from sagemaker .train .distributed import Torchrun , SMP
33+ from sagemaker .train .constants import SM_RECIPE_YAML
3334
3435
3536def _try_resolve_recipe (recipe , key = None ):
@@ -86,6 +87,8 @@ def _load_base_recipe(
8687 )
8788 else :
8889 recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
90+ if training_recipes_cfg is None :
91+ training_recipes_cfg = _load_recipes_cfg ()
8992
9093 launcher_repo = os .environ .get ("TRAINING_LAUNCHER_GIT" , None ) or training_recipes_cfg .get (
9194 "launcher_repo"
@@ -149,7 +152,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
149152def _configure_gpu_args (
150153 training_recipes_cfg : Dict [str , Any ],
151154 region_name : str ,
152- recipe : OmegaConf ,
155+ recipe : dictconfig . DictConfig ,
153156 recipe_train_dir : tempfile .TemporaryDirectory ,
154157) -> Dict [str , Any ]:
155158 """Configure arguments specific to GPU."""
@@ -234,11 +237,12 @@ def _configure_trainium_args(
234237
235238
236239def _get_args_from_recipe (
237- training_recipe : str ,
240+ training_recipe : Union [ str , dictconfig . DictConfig ] ,
238241 compute : Compute ,
239242 region_name : str ,
240243 recipe_overrides : Optional [Dict [str , Any ]],
241244 requirements : Optional [str ],
245+ role : Optional [str ] = None ,
242246) -> Tuple [Dict [str , Any ], tempfile .TemporaryDirectory ]:
243247 """Get arguments for ModelTrainer from a training recipe.
244248
@@ -254,8 +258,8 @@ def _get_args_from_recipe(
254258 ```
255259
256260 Args:
257- training_recipe (str):
258- Name of the training recipe or path to the recipe file.
261+ training_recipe (Union[ str, Dict[str, Any]] ):
262+ Name of the training recipe or path to the recipe file or loaded recipe Dict .
259263 compute (Compute):
260264 Compute configuration for training.
261265 region_name (str):
@@ -269,7 +273,13 @@ def _get_args_from_recipe(
269273 raise ValueError ("Must set `instance_type` in compute when using training recipes." )
270274
271275 training_recipes_cfg = _load_recipes_cfg ()
272- recipe = _load_base_recipe (training_recipe , recipe_overrides , training_recipes_cfg )
276+ if isinstance (training_recipe , str ):
277+ recipe = _load_base_recipe (training_recipe , recipe_overrides , training_recipes_cfg )
278+ else :
279+ recipe = training_recipe
280+ if _is_nova_recipe (recipe ):
281+ args , recipe_local_dir = _get_args_from_nova_recipe (recipe , compute , role = role )
282+ return args , recipe_local_dir
273283
274284 if "trainer" not in recipe :
275285 raise ValueError ("Supplied recipe does not contain required field trainer." )
@@ -283,7 +293,7 @@ def _get_args_from_recipe(
283293 if compute .instance_count is None :
284294 if "num_nodes" not in recipe ["trainer" ]:
285295 raise ValueError (
286- "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
296+ "Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
287297 )
288298 compute .instance_count = recipe ["trainer" ]["num_nodes" ]
289299
@@ -313,7 +323,7 @@ def _get_args_from_recipe(
313323
314324 # Save Final Recipe to source_dir
315325 OmegaConf .save (
316- config = final_recipe , f = os .path .join (args ["source_code" ].source_dir , "recipe.yaml" )
326+ config = final_recipe , f = os .path .join (args ["source_code" ].source_dir , SM_RECIPE_YAML )
317327 )
318328
319329 # If recipe_requirements is provided, copy it to source_dir
@@ -322,7 +332,7 @@ def _get_args_from_recipe(
322332 args ["source_code" ].requirements = os .path .basename (requirements )
323333
324334 # Update args with compute and hyperparameters
325- hyperparameters = {"config-path" : "." , "config-name" : "recipe.yaml" }
335+ hyperparameters = {"config-path" : "." , "config-name" : SM_RECIPE_YAML }
326336
327337 # Handle eval custom lambda configuration
328338 if recipe .get ("evaluation" , {}):
@@ -339,3 +349,111 @@ def _get_args_from_recipe(
339349 )
340350
341351 return args , recipe_train_dir
352+
353+ def _is_nova_recipe (
354+ recipe : dictconfig .DictConfig ,
355+ ) -> bool :
356+ """Check if the recipe is a Nova recipe.
357+
358+ A recipe is considered a Nova recipe if it meets either of the following conditions:
359+
360+ 1. It has a run section with:
361+ - A model_type that includes "amazon.nova"
362+ - A model_name_or_path field
363+
364+ OR
365+
366+ 2. It has a training_config section with:
367+ - A distillation_data field
368+
369+ Args:
370+ recipe (DictConfig): The loaded recipe configuration
371+
372+ Returns:
373+ bool: True if the recipe is a Nova recipe, False otherwise
374+ """
375+ run_config = recipe .get ("run" , {})
376+ model_type = run_config .get ("model_type" , "" ).lower ()
377+ has_nova_model = (
378+ model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config
379+ )
380+
381+ # Check for distillation data
382+ training_config = recipe .get ("training_config" , {})
383+ has_distillation = training_config .get ("distillation_data" ) is not None
384+ return bool (has_nova_model ) or bool (has_distillation )
385+
def _get_args_from_nova_recipe(
    recipe: dictconfig.DictConfig,
    compute: Compute,
    role: Optional[str] = None,
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
    """Build ModelTrainer arguments from a Nova training recipe.

    Args:
        recipe (DictConfig): The loaded Nova recipe configuration.
        compute (Compute): Compute configuration. ``instance_count`` falls
            back to ``run.replicas`` from the recipe when not set.
        role (Optional[str]): IAM role ARN. Required for distillation recipes
            because it is forwarded as the ``role_arn`` hyperparameter.

    Returns:
        Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: The args dict for
        ModelTrainer and the temporary directory holding the resolved recipe.

    Raises:
        ValueError: If neither ``compute.instance_count`` nor ``run.replicas``
            is set, or a distillation recipe is missing ``role`` or
            ``kms_key``.
        RuntimeError: If the recipe cannot be resolved.
    """
    run_config = recipe.get("run", {})

    # BUG FIX: the original message said ``instance_type`` although the guard
    # (and the fallback below) operate on the instance *count* / replicas.
    if not compute.instance_count and not run_config.get("replicas", None):
        raise ValueError("Must set ``instance_count`` in compute or ``replicas`` in recipe.")
    compute.instance_count = compute.instance_count or run_config.get("replicas")

    args: Dict[str, Any] = {"hyperparameters": {}}

    model_name_or_path = run_config.get("model_name_or_path")
    if model_name_or_path:
        # S3 URIs point at a stored model artifact; anything else is a model id.
        if model_name_or_path.startswith("s3://"):
            args["hyperparameters"]["base_model_location"] = model_name_or_path
        else:
            args["hyperparameters"]["base_model"] = model_name_or_path

    # Handle distillation configuration
    training_config = recipe.get("training_config", {})
    distillation_data = training_config.get("distillation_data")
    if distillation_data:
        args["hyperparameters"]["distillation_data"] = distillation_data
        if not role:
            raise ValueError("Must provide 'role' parameter when using Nova distillation")
        args["hyperparameters"]["role_arn"] = role

        kms_key = training_config.get("kms_key")
        if kms_key is None:
            raise ValueError(
                'Nova distillation job recipe requires "kms_key" field in "training_config"'
            )
        args["hyperparameters"]["kms_key"] = kms_key

    # Handle eval custom lambda configuration
    if recipe.get("evaluation", {}):
        lambda_arn = recipe.get("processor", {}).get("lambda_arn", "")
        if lambda_arn:
            args["hyperparameters"]["eval_lambda_arn"] = lambda_arn

    # Handle reward lambda configuration (run_config already fetched above;
    # the original re-fetched it redundantly here).
    reward_lambda_arn = run_config.get("reward_lambda_arn", "")
    if reward_lambda_arn:
        args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn

    _register_custom_resolvers()

    # Resolve the final recipe, trying the known top-level wrapper keys in turn.
    final_recipe = _try_resolve_recipe(recipe)
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "recipes")
    if final_recipe is None:
        final_recipe = _try_resolve_recipe(recipe, "training")
    if final_recipe is None:
        raise RuntimeError("Could not resolve provided recipe.")

    # Persist the resolved recipe where the training job can pick it up.
    recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_")
    final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML)
    OmegaConf.save(config=final_recipe, f=final_recipe_path)

    # Nova jobs use a service-managed image; no source code or distribution config.
    args.update(
        {
            "compute": compute,
            "training_image": None,
            "source_code": None,
            "distributed": None,
        }
    )
    return args, recipe_local_dir
0 commit comments