
Commit 6a174f4

fix(tuner): Include sm_drivers channel in HyperparameterTuner jobs (#5634)
* fix(tuner): Include sm_drivers channel in HyperparameterTuner jobs

  When ModelTrainer has distributed=Torchrun(), the sm_drivers channel
  contains torchrun_driver.py and sm_train.sh, which are required for
  multi-GPU execution. The tuner was not building this channel, causing the
  framework container to fall back to the legacy single-GPU entry point
  (python train.py) instead of torchrun. This caused a tensor size mismatch
  (batch_size vs accumulated_batch) in TRL's compute_loss when
  gradient_accumulation_steps > 1, because the single-process path doesn't
  partition batches across ranks.

  Fix: Replace _upload_source_code_and_configure_hyperparameters with
  _build_driver_and_code_channels, which replicates ModelTrainer's channel
  building logic (sm_drivers, code, distributed.json, sourcecode.json,
  sm_train.sh). Also pass through environment and VPC config.

* fix(tuner): Harden _build_training_job_definition against missing attributes

  - Use getattr with a fallback for static_hyperparameters (fixes
    test_build_training_job_definition_includes_internal_channels)
  - Guard _prepare_model_trainer_for_tuning with an isinstance check on
    entry_script to avoid calling _build_driver_and_code_channels on
    MagicMock model trainers
  - Guard environment passthrough with an isinstance(env, dict) check
  - Guard VPC config passthrough with try/except for mock safety

* fix(test): Rewrite tuner distributed integ test to match CI patterns

  - Use the sagemaker_session fixture from conftest (auto-resolves role/region)
  - Use an ml.m5.xlarge CPU instance (cheaper, available in CI)
  - Remove the hardcoded role ARN and training_mode
  - Remove @pytest.mark.slow (not registered in CI config)
  - Use a module-level function instead of a class (matches other integ tests)
  - Use DEFAULT_CPU_IMAGE, consistent with test_model_trainer.py

* fix(tuner): Upload sourcedir.tar.gz for framework container compatibility

  The HPT API uses the legacy framework container path, which expects
  sagemaker_submit_directory (a tar.gz on S3) to be downloaded and extracted
  to /opt/ml/code/. The previous approach of using a 'code' input channel
  mounted the code at /opt/ml/input/data/code/ instead, causing 'No such
  file or directory' errors.

  Fix: Create and upload sourcedir.tar.gz to S3, and set both the
  sagemaker_program and sagemaker_submit_directory hyperparameters. Remove
  the separate 'code' input channel, since the framework container handles
  code extraction via sagemaker_submit_directory.

* test(tuner): Add unit tests for driver/code channel building

  Add 25 unit tests covering the tuner changes from PR #5634:

  - _prepare_model_trainer_for_tuning guard logic
  - _build_driver_and_code_channels sm_drivers channel creation
  - _build_training_job_definition _tuner_channels inclusion
  - Environment and VPC config passthrough
  - sourcedir.tar.gz upload and the sagemaker_submit_directory HP
  - static_hyperparameters getattr fallback
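The tensor size mismatch described in the first commit can be illustrated with simple arithmetic (a hedged sketch; the helper name and the numbers are illustrative, not TRL's actual code):

```python
# Illustrative sketch (not TRL's actual code): why falling back to a single
# process breaks loss computation when gradient_accumulation_steps > 1.
# Under torchrun, each of `world_size` ranks sees batch_size / world_size
# samples per step; the single-process fallback feeds the whole batch to a
# code path that still assumes the per-rank micro-batch size.

def per_rank_micro_batch(batch_size: int, world_size: int,
                         gradient_accumulation_steps: int) -> int:
    """Samples each rank processes per forward pass (assumes even division)."""
    return batch_size // world_size // gradient_accumulation_steps

# With the sm_drivers channel present, torchrun launches e.g. 4 ranks:
assert per_rank_micro_batch(32, world_size=4, gradient_accumulation_steps=2) == 4

# Without it, the container runs `python train.py` (world_size == 1), so the
# per-step batch is 4x larger than the multi-GPU path assumed:
assert per_rank_micro_batch(32, world_size=1, gradient_accumulation_steps=2) == 16
```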
1 parent 10df8a4 commit 6a174f4

File tree

3 files changed (+801, -60 lines)

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 125 additions & 60 deletions
@@ -444,96 +444,140 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
 
     @classmethod
     def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs):
-        """Prepare ModelTrainer before tuning by uploading source code and configuring hyperparameters.
+        """Prepare ModelTrainer before tuning by building sm_drivers and code channels.
 
-        This method mimics V2's _prepare_estimator_for_tuning() pattern, adapted for V3's
-        ModelTrainer architecture. It ensures that script mode hyperparameters are set before
-        the tuning job is created, which framework containers (PyTorch, TensorFlow) require.
+        This method replicates the channel-building logic from ModelTrainer._create_training_job()
+        to ensure the sm_drivers channel (containing torchrun_driver.py, distributed config, and
+        sm_train.sh) is included in the tuning job definition. Without this, the framework
+        container falls back to the legacy entry point (python train.py) instead of using the
+        V3 driver (torchrun), breaking distributed training.
 
         Args:
             model_trainer: ModelTrainer instance to prepare
             inputs: Training inputs (unused, for V2 compatibility)
             job_name: Job name (unused, for V2 compatibility)
             **kwargs: Additional arguments (unused, for V2 compatibility)
         """
-        # Only proceed if source_code is configured
-        if hasattr(model_trainer, "source_code") and model_trainer.source_code is not None:
-            cls._upload_source_code_and_configure_hyperparameters(model_trainer)
+        source_code = getattr(model_trainer, "source_code", None)
+        if source_code is None:
+            return
+        # Only proceed if source_code has a real entry_script string
+        entry_script = getattr(source_code, "entry_script", None)
+        if not isinstance(entry_script, str):
+            return
 
-    @classmethod
-    def _upload_source_code_and_configure_hyperparameters(cls, model_trainer):
-        """Upload source code to S3 and add script mode hyperparameters.
+        cls._build_driver_and_code_channels(model_trainer)
 
-        Framework containers (PyTorch, TensorFlow) expect sagemaker_program and
-        sagemaker_submit_directory hyperparameters for script mode execution. This method:
-        1. Checks if source_dir is a local path or S3 URI
-        2. Creates a tar.gz archive and uploads to S3
-        3. Adds required script mode hyperparameters to model_trainer.hyperparameters
+    @classmethod
+    def _build_driver_and_code_channels(cls, model_trainer):
+        """Build sm_drivers and code input channels for the tuning job.
 
-        This follows V2's pattern of creating sourcedir.tar.gz files.
+        Replicates the channel-building logic from ModelTrainer._create_training_job()
+        so that the tuning job gets the same execution environment as a standalone
+        training job (distributed drivers, source code, train script).
 
         Args:
             model_trainer: ModelTrainer instance with source_code configured
         """
+        import json
         import os
-        import tarfile
-        import tempfile
+        import shutil
         import time
+        from tempfile import TemporaryDirectory
+
+        from sagemaker.train.constants import (
+            SM_CODE,
+            SM_DRIVERS,
+            SM_DRIVERS_LOCAL_PATH,
+            DEFAULT_CONTAINER_ENTRYPOINT,
+            DEFAULT_CONTAINER_ARGUMENTS,
+        )
 
         source_code = model_trainer.source_code
+        base_name = model_trainer.base_job_name or "tuning"
+        key_prefix = f"{base_name}/tuning-{int(time.time())}/input"
+
+        # Build sm_drivers channel (same as ModelTrainer._create_training_job)
+        temp_dir = TemporaryDirectory()
+        shutil.copytree(SM_DRIVERS_LOCAL_PATH, temp_dir.name, dirs_exist_ok=True)
+
+        # If distributed config is set, copy distributed drivers
+        if model_trainer.distributed:
+            driver_dir = os.path.join(temp_dir.name, "distributed_drivers")
+            shutil.copytree(model_trainer.distributed.driver_dir, driver_dir, dirs_exist_ok=True)
+
+        # Write sourcecode.json
+        source_code_json_path = os.path.join(temp_dir.name, "sourcecode.json")
+        with open(source_code_json_path, "w") as f:
+            dump = source_code.model_dump() if source_code else {}
+            f.write(json.dumps(dump))
+
+        # Write distributed.json
+        distributed_json_path = os.path.join(temp_dir.name, "distributed.json")
+        with open(distributed_json_path, "w") as f:
+            dump = model_trainer.distributed.model_dump() if model_trainer.distributed else {}
+            f.write(json.dumps(dump))
+
+        # Prepare the train script (sm_train.sh)
+        model_trainer._prepare_train_script(
+            tmp_dir=temp_dir,
+            source_code=source_code,
+            distributed=model_trainer.distributed,
+        )
+
+        # Upload sm_drivers channel
+        sm_drivers_channel = model_trainer.create_input_data_channel(
+            channel_name=SM_DRIVERS,
+            data_source=temp_dir.name,
+            key_prefix=key_prefix,
+            ignore_patterns=source_code.ignore_patterns,
+        )
 
-        # Get source directory and entry script
-        source_dir = source_code.source_dir
-        entry_script = source_code.entry_script
+        # Store channels on model_trainer so _build_training_job_definition can pick them up
+        model_trainer._tuner_channels = [sm_drivers_channel]
 
-        # Check if already an S3 URI
-        if _is_valid_s3_uri(source_dir):
-            # Already uploaded, use as-is
-            source_s3_uri = source_dir
-        else:
-            # Local directory - need to create tar.gz and upload
-            session = model_trainer.sagemaker_session
-            bucket = session.default_bucket()
+        # Set script mode hyperparameters required by framework containers.
+        # The framework container (PyTorch, TF) uses sagemaker_program to find the entry script
+        # and sagemaker_submit_directory to download source code to /opt/ml/code/.
+        if model_trainer.hyperparameters is None:
+            model_trainer.hyperparameters = {}
+        model_trainer.hyperparameters["sagemaker_program"] = source_code.entry_script
 
-            # Generate S3 key
-            timestamp = int(time.time())
-            s3_key = (
-                f"{model_trainer.base_job_name or 'source'}/source-{timestamp}/sourcedir.tar.gz"
-            )
+        # Upload sourcedir.tar.gz for the legacy framework container path.
+        # The HPT API doesn't support container_entrypoint, so the framework container
+        # uses sagemaker_submit_directory to download and extract code to /opt/ml/code/.
+        if source_code.source_dir and not _is_valid_s3_uri(source_code.source_dir):
+            import tarfile
+            import tempfile
 
-            # Create tar.gz file
-            with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp_file:
-                tar_path = tmp_file.name
+            session = model_trainer.sagemaker_session
+            bucket = session.default_bucket()
+            s3_key = f"{key_prefix}/sourcedir/sourcedir.tar.gz"
 
+            with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
+                tar_path = tmp.name
             try:
-                # Create tar.gz archive
                 with tarfile.open(tar_path, "w:gz") as tar:
-                    # Add all files from source_dir
-                    for root, dirs, files in os.walk(source_dir):
-                        for file in files:
-                            file_path = os.path.join(root, file)
-                            # Calculate arcname to preserve directory structure
-                            arcname = os.path.relpath(file_path, source_dir)
-                            tar.add(file_path, arcname=arcname)
-
-                # Upload to S3
-                s3_client = session.boto_session.client("s3", region_name=session.boto_region_name)
+                    for root, _dirs, files in os.walk(source_code.source_dir):
+                        for f in files:
+                            fpath = os.path.join(root, f)
+                            arcname = os.path.relpath(fpath, source_code.source_dir)
+                            tar.add(fpath, arcname=arcname)
+                s3_client = session.boto_session.client(
+                    "s3", region_name=session.boto_region_name
+                )
                 s3_client.upload_file(tar_path, bucket, s3_key)
-
-                # Construct S3 URI
-                source_s3_uri = f"s3://{bucket}/{s3_key}"
+                model_trainer.hyperparameters["sagemaker_submit_directory"] = (
+                    f"s3://{bucket}/{s3_key}"
                )
             finally:
-                # Clean up temp file
                 if os.path.exists(tar_path):
                     os.remove(tar_path)
+        elif source_code.source_dir and _is_valid_s3_uri(source_code.source_dir):
+            model_trainer.hyperparameters["sagemaker_submit_directory"] = source_code.source_dir
 
-        # Initialize hyperparameters dict if None
-        if model_trainer.hyperparameters is None:
-            model_trainer.hyperparameters = {}
-
-        # Add script mode hyperparameters required by framework containers
-        model_trainer.hyperparameters["sagemaker_program"] = entry_script
-        model_trainer.hyperparameters["sagemaker_submit_directory"] = source_s3_uri
+        # Store the temp dir reference to prevent cleanup
+        model_trainer._tuner_temp_dir = temp_dir
 
     @runnable_by_pipeline
     def tune(
@@ -1422,6 +1466,12 @@ def _build_training_job_definition(self, inputs):
             if not any(c.channel_name == channel.channel_name for c in input_data_config):
                 input_data_config.append(channel)
 
+        # Include channels built by _prepare_model_trainer_for_tuning (sm_drivers, code)
+        if hasattr(model_trainer, "_tuner_channels") and model_trainer._tuner_channels:
+            for channel in model_trainer._tuner_channels:
+                if not any(c.channel_name == channel.channel_name for c in input_data_config):
+                    input_data_config.append(channel)
+
         # Build output data config
         output_config = OutputDataConfig(
             s3_output_path=(
@@ -1459,7 +1509,22 @@ def _build_training_job_definition(self, inputs):
             output_data_config=output_config,
             resource_config=resource_config,
             stopping_condition=stopping_condition,
-            static_hyper_parameters=self.static_hyperparameters or {},
+            static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
         )
 
+        # Pass through environment variables from model_trainer
+        env = getattr(model_trainer, "environment", None)
+        if env and isinstance(env, dict):
+            definition.environment = env
+
+        # Pass through VPC config from model_trainer
+        networking = getattr(model_trainer, "networking", None)
+        if networking and hasattr(networking, "_to_vpc_config"):
+            try:
+                vpc_config = networking._to_vpc_config()
+                if vpc_config:
+                    definition.vpc_config = vpc_config
+            except Exception:
+                pass
+
         return definition
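The sourcedir.tar.gz packaging this diff introduces can be sketched in isolation (the helper name is illustrative; the real code additionally uploads the archive to S3 via boto3 and sets the sagemaker_submit_directory hyperparameter):

```python
# Standalone sketch of the sourcedir.tar.gz packaging performed in the patch:
# archive every file under source_dir with paths relative to source_dir, so
# the framework container can extract the tree directly into /opt/ml/code/.
import os
import tarfile
import tempfile


def package_sourcedir(source_dir: str) -> str:
    """Create a sourcedir.tar.gz from source_dir; return the archive path."""
    fd, tar_path = tempfile.mkstemp(suffix=".tar.gz")
    os.close(fd)
    with tarfile.open(tar_path, "w:gz") as tar:
        for root, _dirs, files in os.walk(source_dir):
            for name in files:
                fpath = os.path.join(root, name)
                # arcname keeps the path relative to source_dir
                tar.add(fpath, arcname=os.path.relpath(fpath, source_dir))
    return tar_path


# Usage: a source dir containing train.py ends up at the archive root.
with tempfile.TemporaryDirectory() as src:
    with open(os.path.join(src, "train.py"), "w") as f:
        f.write("print('hello')\n")
    archive = package_sourcedir(src)
    with tarfile.open(archive) as tar:
        assert "train.py" in tar.getnames()
    os.remove(archive)
```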
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Integration test: HyperparameterTuner with Torchrun distributed training.
+
+Regression test for the bug where HyperparameterTuner dropped the sm_drivers
+channel, causing the container to fall back to single-GPU execution instead
+of using torchrun for multi-GPU distributed training.
+"""
+from __future__ import absolute_import
+
+import os
+import time
+import logging
+
+import pytest
+
+from sagemaker.train.model_trainer import ModelTrainer
+from sagemaker.train.configs import SourceCode, Compute
+from sagemaker.train.distributed import Torchrun
+from sagemaker.train.tuner import HyperparameterTuner
+from sagemaker.core.parameter import ContinuousParameter
+
+logger = logging.getLogger(__name__)
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "../..", "data")
+DEFAULT_CPU_IMAGE = (
+    "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
+)
+
+TRAIN_SCRIPT_CONTENT = """\
+import os
+import argparse
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--learning_rate", type=float, default=1e-4)
+    args, _ = parser.parse_known_args()
+
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    print(f"DISTRIBUTED_CHECK: world_size={world_size}")
+    print(f"DISTRIBUTED_CHECK: local_rank={local_rank}")
+    print(f"DISTRIBUTED_CHECK: learning_rate={args.learning_rate}")
+
+    # Emit metric for the tuner to parse
+    print(f"eval_loss: 0.42")
+
+
+if __name__ == "__main__":
+    main()
+"""
+
+
+@pytest.fixture(scope="module")
+def train_source_dir(tmp_path_factory):
+    """Create a temp directory with a minimal training script."""
+    d = tmp_path_factory.mktemp("tuner_dist_src")
+    (d / "train.py").write_text(TRAIN_SCRIPT_CONTENT)
+    return str(d)
+
+
+def test_tuner_includes_sm_drivers_channel(sagemaker_session, train_source_dir):
+    """Verify tuning jobs include sm_drivers channel for distributed training.
+
+    Uses a CPU instance with Torchrun to validate that the sm_drivers channel
+    (containing torchrun_driver.py and sm_train.sh) is included in the tuning
+    job definition. The training script logs WORLD_SIZE to confirm the V3
+    driver path is used instead of the legacy framework container fallback.
+    """
+    model_trainer = ModelTrainer(
+        sagemaker_session=sagemaker_session,
+        training_image=DEFAULT_CPU_IMAGE,
+        base_job_name="tuner-dist-test",
+        source_code=SourceCode(
+            source_dir=train_source_dir,
+            entry_script="train.py",
+        ),
+        compute=Compute(
+            instance_type="ml.m5.xlarge",
+            instance_count=1,
+            volume_size_in_gb=30,
+        ),
+        distributed=Torchrun(),
+        hyperparameters={"learning_rate": 1e-4},
+    )
+
+    tuner = HyperparameterTuner(
+        model_trainer=model_trainer,
+        objective_metric_name="eval_loss",
+        metric_definitions=[
+            {"Name": "eval_loss", "Regex": r"eval_loss: ([0-9\\.]+)"},
+        ],
+        hyperparameter_ranges={
+            "learning_rate": ContinuousParameter(
+                min_value=1e-5,
+                max_value=5e-4,
+                scaling_type="Logarithmic",
+            ),
+        },
+        objective_type="Minimize",
+        max_jobs=1,
+        max_parallel_jobs=1,
+    )
+
+    tuner.tune(wait=True)
+
+    job = tuner.latest_tuning_job.refresh()
+    assert job.hyper_parameter_tuning_job_status in (
+        "Completed",
+        "Stopped",
+    ), f"Tuning job failed: {job.hyper_parameter_tuning_job_status}"
+
+    best = tuner.best_training_job()
+    assert best is not None
+    logger.info("PASSED: tuner distributed training test - job: %s", best)
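The metric_definitions regex in the test above is applied by SageMaker to the container's log lines, with capture group 1 becoming the objective value. A minimal sketch of that extraction (using the single-backslash Python-level equivalent of the test's regex; the log line matches what the training script prints):

```python
# Sketch: how a SageMaker-style metric definition pulls the objective metric
# out of a training log line. Group 1 of the Regex is the metric value.
import re

metric_regex = r"eval_loss: ([0-9\.]+)"
log_line = "eval_loss: 0.42"

match = re.search(metric_regex, log_line)
assert match is not None
assert float(match.group(1)) == 0.42
```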
