Skip to content

Commit e416483

Browse files
committed
Refactor output data configuration handling in tuner
1 parent 5221bfb commit e416483

2 files changed

Lines changed: 3 additions & 10 deletions

File tree

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def __init__(
253253
self.instance_configs_dict = None
254254
self.instance_configs = None
255255
self.autotune = autotune
256+
self.output_data_config = model_trainer.output_data_config
256257

257258
def override_resource_config(
258259
self,
@@ -1331,7 +1332,6 @@ def _build_training_job_definition(self, inputs):
13311332
from sagemaker.core.shapes import (
13321333
HyperParameterTrainingJobDefinition,
13331334
HyperParameterAlgorithmSpecification,
1334-
OutputDataConfig,
13351335
ResourceConfig,
13361336
StoppingCondition,
13371337
Channel,
@@ -1422,14 +1422,6 @@ def _build_training_job_definition(self, inputs):
14221422
if not any(c.channel_name == channel.channel_name for c in input_data_config):
14231423
input_data_config.append(channel)
14241424

1425-
# Build output data config
1426-
output_config = OutputDataConfig(
1427-
s3_output_path=(
1428-
model_trainer.output_data_config.s3_output_path
1429-
if model_trainer.output_data_config
1430-
else None
1431-
)
1432-
)
14331425

14341426
# Build resource config
14351427
resource_config = ResourceConfig(
@@ -1456,7 +1448,7 @@ def _build_training_job_definition(self, inputs):
14561448
algorithm_specification=algorithm_spec,
14571449
role_arn=model_trainer.role,
14581450
input_data_config=input_data_config if input_data_config else None,
1459-
output_data_config=output_config,
1451+
output_data_config=self.output_data_config,
14601452
resource_config=resource_config,
14611453
stopping_condition=stopping_condition,
14621454
static_hyper_parameters=self.static_hyperparameters or {},

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
5353
trainer.training_input_mode = "File"
5454
trainer.role = "arn:aws:iam::123456789012:role/SageMakerRole"
5555
trainer.output_data_config = MagicMock()
56+
trainer.output_data_config.kms_key_id = None
5657
trainer.output_data_config.s3_output_path = "s3://bucket/output"
5758
trainer.compute = MagicMock()
5859
trainer.compute.instance_type = "ml.m5.xlarge"

0 commit comments

Comments
 (0)