Skip to content

Commit 9afb079

Browse files
authored
Merge branch 'master' into mc-bug-fix
2 parents 38c9058 + 88963f8 commit 9afb079

File tree

2 files changed

+129
-2
lines changed

2 files changed

+129
-2
lines changed

sagemaker-core/src/sagemaker/core/transformer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
base_transform_job_name: Optional[str] = None,
8787
sagemaker_session: Optional[Session] = None,
8888
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
89+
transform_ami_version: Optional[Union[str, PipelineVariable]] = None,
8990
):
9091
"""Initialize a ``Transformer``.
9192
@@ -126,6 +127,15 @@ def __init__(
126127
AWS services needed.
127128
volume_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting
128129
the volume attached to the ML compute instance (default: None).
130+
transform_ami_version (str or PipelineVariable): Optional. Specifies an option
131+
from a collection of preconfigured Amazon Machine Image (AMI) images.
132+
Each image is configured by Amazon Web Services with a set of software
133+
and driver versions. Valid values include:
134+
135+
* 'al2-ami-sagemaker-batch-gpu-470' - GPU accelerator with NVIDIA driver 470
136+
* 'al2-ami-sagemaker-batch-gpu-535' - GPU accelerator with NVIDIA driver 535
137+
138+
(default: None).
129139
"""
130140
self.model_name = model_name
131141
self.strategy = strategy
@@ -162,6 +172,7 @@ def __init__(
162172
TRANSFORM_JOB_ENVIRONMENT_PATH,
163173
sagemaker_session=self.sagemaker_session,
164174
)
175+
self.transform_ami_version = transform_ami_version
165176

166177
@runnable_by_pipeline
167178
def transform(
@@ -517,6 +528,9 @@ def _prepare_init_params_from_job_description(cls, job_details):
517528
init_params["volume_kms_key"] = getattr(
518529
job_details["transform_resources"], "volume_kms_key_id", None
519530
)
531+
init_params["transform_ami_version"] = getattr(
532+
job_details["transform_resources"], "transform_ami_version", None
533+
)
520534
init_params["strategy"] = job_details.get("batch_strategy")
521535
if job_details.get("transform_output"):
522536
init_params["assemble_with"] = getattr(
@@ -584,7 +598,10 @@ def _load_config(self, data, data_type, content_type, compression_type, split_ty
584598
)
585599

586600
resource_config = self._prepare_resource_config(
587-
self.instance_count, self.instance_type, self.volume_kms_key
601+
self.instance_count,
602+
self.instance_type,
603+
self.volume_kms_key,
604+
self.transform_ami_version,
588605
)
589606

590607
return {
@@ -631,13 +648,18 @@ def _prepare_output_config(self, s3_path, kms_key_id, assemble_with, accept):
631648

632649
return config
633650

634-
def _prepare_resource_config(self, instance_count, instance_type, volume_kms_key):
651+
def _prepare_resource_config(
652+
self, instance_count, instance_type, volume_kms_key, transform_ami_version=None
653+
):
635654
"""Prepare resource config."""
636655
config = {"instance_count": instance_count, "instance_type": instance_type}
637656

638657
if volume_kms_key is not None:
639658
config["volume_kms_key_id"] = volume_kms_key
640659

660+
if transform_ami_version is not None:
661+
config["transform_ami_version"] = transform_ami_version
662+
641663
return config
642664

643665
def _prepare_data_processing(self, input_filter, output_filter, join_source):

sagemaker-core/tests/unit/test_transformer.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_init_with_all_params(self, mock_session):
6767
base_transform_job_name="test-job",
6868
sagemaker_session=mock_session,
6969
volume_kms_key="volume-key",
70+
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
7071
)
7172

7273
assert transformer.strategy == "MultiRecord"
@@ -77,6 +78,7 @@ def test_init_with_all_params(self, mock_session):
7778
assert transformer.max_concurrent_transforms == 4
7879
assert transformer.max_payload == 10
7980
assert transformer.volume_kms_key == "volume-key"
81+
assert transformer.transform_ami_version == "al2-ami-sagemaker-batch-gpu-535"
8082

8183
def test_format_inputs_to_input_config(self, mock_session):
8284
"""Test _format_inputs_to_input_config method"""
@@ -179,6 +181,27 @@ def test_prepare_resource_config(self, mock_session):
179181
assert config["instance_type"] == "ml.m5.xlarge"
180182
assert config["volume_kms_key_id"] == "volume-key"
181183

184+
def test_prepare_resource_config_with_ami_version(self, mock_session):
185+
"""Test _prepare_resource_config with transform_ami_version"""
186+
transformer = Transformer(
187+
model_name="test-model",
188+
instance_count=1,
189+
instance_type="ml.m5.xlarge",
190+
sagemaker_session=mock_session,
191+
)
192+
193+
config = transformer._prepare_resource_config(
194+
instance_count=2,
195+
instance_type="ml.g4dn.xlarge",
196+
volume_kms_key="volume-key",
197+
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
198+
)
199+
200+
assert config["instance_count"] == 2
201+
assert config["instance_type"] == "ml.g4dn.xlarge"
202+
assert config["volume_kms_key_id"] == "volume-key"
203+
assert config["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535"
204+
182205
def test_prepare_resource_config_no_kms(self, mock_session):
183206
"""Test _prepare_resource_config without KMS key"""
184207
transformer = Transformer(
@@ -195,6 +218,7 @@ def test_prepare_resource_config_no_kms(self, mock_session):
195218
assert config["instance_count"] == 1
196219
assert config["instance_type"] == "ml.m5.xlarge"
197220
assert "volume_kms_key_id" not in config
221+
assert "transform_ami_version" not in config
198222

199223
def test_prepare_data_processing_all_params(self, mock_session):
200224
"""Test _prepare_data_processing with all parameters"""
@@ -438,6 +462,87 @@ def test_prepare_init_params_from_job_description(self, mock_session):
438462
assert init_params["volume_kms_key"] == "volume-key"
439463
assert init_params["base_transform_job_name"] == "test-job-456"
440464

465+
def test_prepare_init_params_from_job_description_with_ami_version(self, mock_session):
466+
"""Test _prepare_init_params_from_job_description with transform_ami_version"""
467+
job_details = {
468+
"model_name": "test-model",
469+
"transform_resources": Mock(
470+
instance_count=2,
471+
instance_type="ml.g4dn.xlarge",
472+
volume_kms_key_id="volume-key",
473+
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
474+
),
475+
"batch_strategy": "SingleRecord",
476+
"transform_output": Mock(
477+
assemble_with="None",
478+
s3_output_path="s3://bucket/output",
479+
kms_key_id="output-key",
480+
accept="text/csv",
481+
),
482+
"max_concurrent_transforms": 8,
483+
"max_payload_in_mb": 20,
484+
"transform_job_name": "test-job-789",
485+
}
486+
487+
init_params = Transformer._prepare_init_params_from_job_description(job_details)
488+
489+
assert init_params["model_name"] == "test-model"
490+
assert init_params["instance_count"] == 2
491+
assert init_params["instance_type"] == "ml.g4dn.xlarge"
492+
assert init_params["volume_kms_key"] == "volume-key"
493+
assert init_params["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535"
494+
assert init_params["base_transform_job_name"] == "test-job-789"
495+
496+
def test_init_with_transform_ami_version(self, mock_session):
497+
"""Test initialization with transform_ami_version parameter"""
498+
transformer = Transformer(
499+
model_name="test-model",
500+
instance_count=1,
501+
instance_type="ml.g4dn.xlarge",
502+
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
503+
sagemaker_session=mock_session,
504+
)
505+
506+
assert transformer.model_name == "test-model"
507+
assert transformer.instance_count == 1
508+
assert transformer.instance_type == "ml.g4dn.xlarge"
509+
assert transformer.transform_ami_version == "al2-ami-sagemaker-batch-gpu-535"
510+
511+
def test_init_without_transform_ami_version(self, mock_session):
512+
"""Test initialization without transform_ami_version parameter"""
513+
transformer = Transformer(
514+
model_name="test-model",
515+
instance_count=1,
516+
instance_type="ml.g4dn.xlarge",
517+
sagemaker_session=mock_session,
518+
)
519+
520+
assert transformer.transform_ami_version is None
521+
522+
def test_load_config_with_transform_ami_version(self, mock_session):
523+
"""Test _load_config includes transform_ami_version in resource_config"""
524+
transformer = Transformer(
525+
model_name="test-model",
526+
instance_count=2,
527+
instance_type="ml.g4dn.xlarge",
528+
output_path="s3://bucket/output",
529+
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
530+
sagemaker_session=mock_session,
531+
)
532+
533+
config = transformer._load_config(
534+
data="s3://bucket/input",
535+
data_type="S3Prefix",
536+
content_type="text/csv",
537+
compression_type=None,
538+
split_type="Line",
539+
)
540+
541+
assert "resource_config" in config
542+
assert config["resource_config"]["instance_count"] == 2
543+
assert config["resource_config"]["instance_type"] == "ml.g4dn.xlarge"
544+
assert config["resource_config"]["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535"
545+
441546
def test_delete_model(self, mock_session):
442547
"""Test delete_model method"""
443548
transformer = Transformer(

0 commit comments

Comments
 (0)