Skip to content

Commit 723455f

Browse files
fix(train): Skip default instance_type/instance_count when instance_groups is set (#5564)
Guard the default injection of instance_type and instance_count in TrainDefaults.get_compute() and JumpStartTrainDefaults.get_compute() so that these values are not populated when instance_groups is configured. The SageMaker API treats instance_type/instance_count and instance_groups as mutually exclusive in ResourceConfig, and unconditionally setting defaults causes a ValidationException.

Fixes #5555

Co-authored-by: Mufaddal Rohawala <mufi@amazon.com>
1 parent 7732ecf commit 723455f

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

sagemaker-train/src/sagemaker/train/defaults.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,15 @@ def get_compute(compute: Optional[Compute] = None) -> Compute:
100100
volume_size_in_gb=DEFAULT_VOLUME_SIZE,
101101
)
102102
logger.info(f"Compute not provided. Using default:\n{compute}")
103-
if compute.instance_type is None:
104-
compute.instance_type = DEFAULT_INSTANCE_TYPE
105-
logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}")
106-
if compute.instance_count is None:
107-
compute.instance_count = DEFAULT_INSTANCE_COUNT
108-
logger.info(f"Instance count not provided. Using default:\n{compute.instance_count}")
103+
if not compute.instance_groups:
104+
if compute.instance_type is None:
105+
compute.instance_type = DEFAULT_INSTANCE_TYPE
106+
logger.info(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}")
107+
if compute.instance_count is None:
108+
compute.instance_count = DEFAULT_INSTANCE_COUNT
109+
logger.info(
110+
f"Instance count not provided. Using default:\n{compute.instance_count}"
111+
)
109112
if compute.volume_size_in_gb is None:
110113
compute.volume_size_in_gb = DEFAULT_VOLUME_SIZE
111114
logger.info(f"Volume size not provided. Using default:\n{compute.volume_size_in_gb}")
@@ -225,21 +228,28 @@ def get_compute(
225228
),
226229
)
227230
logger.info(f"Compute not provided. Using default compute:\n{compute}")
228-
if compute.instance_type is None and training_components_model.DefaultTrainingInstanceType:
229-
compute.instance_type = training_components_model.DefaultTrainingInstanceType
230-
logger.info(
231-
f"Instance type not provided. Using default instance type:\n{compute.instance_type}"
232-
)
231+
if not compute.instance_groups:
232+
if (
233+
compute.instance_type is None
234+
and training_components_model.DefaultTrainingInstanceType
235+
):
236+
compute.instance_type = training_components_model.DefaultTrainingInstanceType
237+
logger.info(
238+
f"Instance type not provided. Using default instance type:"
239+
f"\n{compute.instance_type}"
240+
)
241+
if compute.instance_count is None:
242+
compute.instance_count = DEFAULT_INSTANCE_COUNT
243+
logger.info(
244+
f"Instance count not provided. Using default instance count:\n{compute}"
245+
)
233246
if compute.volume_size_in_gb is None:
234247
compute.volume_size_in_gb = (
235248
training_components_model.TrainingVolumeSize or DEFAULT_VOLUME_SIZE
236249
)
237250
logger.info(
238251
f"Volume size not provided. Using default volume size:\n{compute.volume_size_in_gb}"
239252
)
240-
if compute.instance_count is None:
241-
compute.instance_count = DEFAULT_INSTANCE_COUNT
242-
logger.info(f"Instance count not provided. Using default instance count:\n{compute}")
243253
return compute
244254

245255
def get_networking(

0 commit comments

Comments
 (0)