@@ -444,96 +444,140 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
444444
@classmethod
def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs):
    """Prepare ModelTrainer before tuning by building sm_drivers and code channels.

    This method replicates the channel-building logic from ModelTrainer._create_training_job()
    to ensure the sm_drivers channel (containing torchrun_driver.py, distributed config, and
    sm_train.sh) is included in the tuning job definition. Without this, the framework
    container falls back to the legacy entry point (python train.py) instead of using the
    V3 driver (torchrun), breaking distributed training.

    Preparation only happens when ``model_trainer.source_code`` is present and
    carries a string ``entry_script``; otherwise this is a no-op.

    Args:
        model_trainer: ModelTrainer instance to prepare
        inputs: Training inputs (unused, for V2 compatibility)
        job_name: Job name (unused, for V2 compatibility)
        **kwargs: Additional arguments (unused, for V2 compatibility)
    """
    source_code = getattr(model_trainer, "source_code", None)
    # Proceed only when source_code exists and names a real entry script.
    if source_code is not None and isinstance(getattr(source_code, "entry_script", None), str):
        cls._build_driver_and_code_channels(model_trainer)
@classmethod
def _build_driver_and_code_channels(cls, model_trainer):
    """Build the sm_drivers input channel and script-mode hyperparameters for tuning.

    Replicates the channel-building logic from ModelTrainer._create_training_job()
    so that the tuning job gets the same execution environment as a standalone
    training job (distributed drivers, source code, train script).

    Side effects on ``model_trainer``:
        * ``_tuner_channels`` -- list holding the uploaded sm_drivers channel,
          later merged into the job definition by _build_training_job_definition.
        * ``hyperparameters`` -- gains ``sagemaker_program`` and, when
          ``source_code.source_dir`` is set, ``sagemaker_submit_directory``.
        * ``_tuner_temp_dir`` -- keeps the TemporaryDirectory object alive so
          its contents are not cleaned up while still referenced.

    Args:
        model_trainer: ModelTrainer instance with source_code configured
    """
    import json
    import os
    import shutil
    import time
    from tempfile import TemporaryDirectory

    # Only the names actually used below are imported (SM_CODE and the default
    # container entrypoint/arguments constants are not needed here).
    from sagemaker.train.constants import (
        SM_DRIVERS,
        SM_DRIVERS_LOCAL_PATH,
    )

    source_code = model_trainer.source_code
    base_name = model_trainer.base_job_name or "tuning"
    key_prefix = f"{base_name}/tuning-{int(time.time())}/input"

    # Build sm_drivers channel (same as ModelTrainer._create_training_job)
    temp_dir = TemporaryDirectory()
    shutil.copytree(SM_DRIVERS_LOCAL_PATH, temp_dir.name, dirs_exist_ok=True)

    # If distributed config is set, copy distributed drivers
    if model_trainer.distributed:
        driver_dir = os.path.join(temp_dir.name, "distributed_drivers")
        shutil.copytree(model_trainer.distributed.driver_dir, driver_dir, dirs_exist_ok=True)

    # Write sourcecode.json (read by the container-side driver scripts)
    source_code_json_path = os.path.join(temp_dir.name, "sourcecode.json")
    with open(source_code_json_path, "w") as f:
        json.dump(source_code.model_dump() if source_code else {}, f)

    # Write distributed.json
    distributed_json_path = os.path.join(temp_dir.name, "distributed.json")
    with open(distributed_json_path, "w") as f:
        json.dump(model_trainer.distributed.model_dump() if model_trainer.distributed else {}, f)

    # Prepare the train script (sm_train.sh).
    # NOTE(review): passes the TemporaryDirectory object itself rather than
    # temp_dir.name -- confirm this matches _prepare_train_script's signature
    # as called from ModelTrainer._create_training_job.
    model_trainer._prepare_train_script(
        tmp_dir=temp_dir,
        source_code=source_code,
        distributed=model_trainer.distributed,
    )

    # Upload sm_drivers channel
    sm_drivers_channel = model_trainer.create_input_data_channel(
        channel_name=SM_DRIVERS,
        data_source=temp_dir.name,
        key_prefix=key_prefix,
        ignore_patterns=source_code.ignore_patterns,
    )

    # Store channels on model_trainer so _build_training_job_definition can pick them up
    model_trainer._tuner_channels = [sm_drivers_channel]

    # Set script mode hyperparameters required by framework containers.
    # The framework container (PyTorch, TF) uses sagemaker_program to find the entry script
    # and sagemaker_submit_directory to download source code to /opt/ml/code/.
    if model_trainer.hyperparameters is None:
        model_trainer.hyperparameters = {}
    model_trainer.hyperparameters["sagemaker_program"] = source_code.entry_script

    # Upload sourcedir.tar.gz for the legacy framework container path.
    # The HPT API doesn't support container_entrypoint, so the framework container
    # uses sagemaker_submit_directory to download and extract code to /opt/ml/code/.
    if source_code.source_dir and not _is_valid_s3_uri(source_code.source_dir):
        import tarfile
        import tempfile

        session = model_trainer.sagemaker_session
        bucket = session.default_bucket()
        s3_key = f"{key_prefix}/sourcedir/sourcedir.tar.gz"

        with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
            tar_path = tmp.name
        try:
            with tarfile.open(tar_path, "w:gz") as tar:
                # Archive paths relative to source_dir so the container
                # extracts the original directory layout under /opt/ml/code/.
                for root, _dirs, files in os.walk(source_code.source_dir):
                    for fname in files:
                        fpath = os.path.join(root, fname)
                        arcname = os.path.relpath(fpath, source_code.source_dir)
                        tar.add(fpath, arcname=arcname)
            s3_client = session.boto_session.client(
                "s3", region_name=session.boto_region_name
            )
            s3_client.upload_file(tar_path, bucket, s3_key)
            model_trainer.hyperparameters["sagemaker_submit_directory"] = (
                f"s3://{bucket}/{s3_key}"
            )
        finally:
            # Always remove the local archive, even if the upload failed.
            if os.path.exists(tar_path):
                os.remove(tar_path)
    elif source_code.source_dir and _is_valid_s3_uri(source_code.source_dir):
        # Source already staged in S3 -- reference it directly.
        model_trainer.hyperparameters["sagemaker_submit_directory"] = source_code.source_dir

    # Store the temp dir reference to prevent cleanup
    model_trainer._tuner_temp_dir = temp_dir
537581
538582 @runnable_by_pipeline
539583 def tune (
@@ -1422,6 +1466,12 @@ def _build_training_job_definition(self, inputs):
14221466 if not any (c .channel_name == channel .channel_name for c in input_data_config ):
14231467 input_data_config .append (channel )
14241468
1469+ # Include channels built by _prepare_model_trainer_for_tuning (sm_drivers, code)
1470+ if hasattr (model_trainer , "_tuner_channels" ) and model_trainer ._tuner_channels :
1471+ for channel in model_trainer ._tuner_channels :
1472+ if not any (c .channel_name == channel .channel_name for c in input_data_config ):
1473+ input_data_config .append (channel )
1474+
14251475 # Build output data config
14261476 output_config = OutputDataConfig (
14271477 s3_output_path = (
@@ -1459,7 +1509,22 @@ def _build_training_job_definition(self, inputs):
14591509 output_data_config = output_config ,
14601510 resource_config = resource_config ,
14611511 stopping_condition = stopping_condition ,
1462- static_hyper_parameters = self . static_hyperparameters or {},
1512+ static_hyper_parameters = getattr ( self , " static_hyperparameters" , None ) or {},
14631513 )
14641514
1515+ # Pass through environment variables from model_trainer
1516+ env = getattr (model_trainer , "environment" , None )
1517+ if env and isinstance (env , dict ):
1518+ definition .environment = env
1519+
1520+ # Pass through VPC config from model_trainer
1521+ networking = getattr (model_trainer , "networking" , None )
1522+ if networking and hasattr (networking , "_to_vpc_config" ):
1523+ try :
1524+ vpc_config = networking ._to_vpc_config ()
1525+ if vpc_config :
1526+ definition .vpc_config = vpc_config
1527+ except Exception :
1528+ pass
1529+
14651530 return definition
0 commit comments