Skip to content

Commit 2c32e62

Browse files
JiwaniZakirclaude
andcommitted
Fix content_type dropped when converting InputData to Channel in tuner
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e161199 commit 2c32e62

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,7 @@ def _build_training_job_definition(self, inputs):
14301430
input_data_config.append(
14311431
Channel(
14321432
channel_name=inp.channel_name,
1433+
content_type=inp.content_type,
14331434
data_source=DataSource(
14341435
s3_data_source=S3DataSource(
14351436
s3_data_type="S3Prefix",

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,36 @@ def test_build_training_job_definition_includes_internal_channels(self):
574574
assert "train" in channel_names, "User 'train' channel should be included"
575575
assert "validation" in channel_names, "User 'validation' channel should be included"
576576
assert len(channel_names) == 4, "Should have exactly 4 channels"
577+
578+
def test_build_training_job_definition_preserves_content_type(self):
579+
"""Test that InputData content_type is preserved when converting to Channel.
580+
581+
This test verifies the fix for GitHub issue #5632 where content_type was
582+
dropped during InputData -> Channel conversion in HyperparameterTuner.
583+
"""
584+
from sagemaker.core.training.configs import InputData
585+
586+
mock_trainer = _create_mock_model_trainer(with_internal_channels=False)
587+
588+
user_inputs = [
589+
InputData(
590+
channel_name="train",
591+
data_source="s3://bucket/train",
592+
content_type="text/csv",
593+
),
594+
]
595+
596+
tuner = HyperparameterTuner(
597+
model_trainer=mock_trainer,
598+
objective_metric_name="accuracy",
599+
hyperparameter_ranges=_create_single_hp_range(),
600+
)
601+
602+
definition = tuner._build_training_job_definition(user_inputs)
603+
604+
train_channel = next(
605+
ch for ch in definition.input_data_config if ch.channel_name == "train"
606+
)
607+
assert train_channel.content_type == "text/csv", (
608+
"content_type should be preserved when converting InputData to Channel"
609+
)

0 commit comments

Comments
 (0)