@@ -67,6 +67,7 @@ def test_init_with_all_params(self, mock_session):
6767 base_transform_job_name = "test-job" ,
6868 sagemaker_session = mock_session ,
6969 volume_kms_key = "volume-key" ,
70+ transform_ami_version = "al2-ami-sagemaker-batch-gpu-535" ,
7071 )
7172
7273 assert transformer .strategy == "MultiRecord"
@@ -77,6 +78,7 @@ def test_init_with_all_params(self, mock_session):
7778 assert transformer .max_concurrent_transforms == 4
7879 assert transformer .max_payload == 10
7980 assert transformer .volume_kms_key == "volume-key"
81+ assert transformer .transform_ami_version == "al2-ami-sagemaker-batch-gpu-535"
8082
8183 def test_format_inputs_to_input_config (self , mock_session ):
8284 """Test _format_inputs_to_input_config method"""
@@ -179,6 +181,27 @@ def test_prepare_resource_config(self, mock_session):
179181 assert config ["instance_type" ] == "ml.m5.xlarge"
180182 assert config ["volume_kms_key_id" ] == "volume-key"
181183
184+ def test_prepare_resource_config_with_ami_version (self , mock_session ):
185+ """Test _prepare_resource_config with transform_ami_version"""
186+ transformer = Transformer (
187+ model_name = "test-model" ,
188+ instance_count = 1 ,
189+ instance_type = "ml.m5.xlarge" ,
190+ sagemaker_session = mock_session ,
191+ )
192+
193+ config = transformer ._prepare_resource_config (
194+ instance_count = 2 ,
195+ instance_type = "ml.g4dn.xlarge" ,
196+ volume_kms_key = "volume-key" ,
197+ transform_ami_version = "al2-ami-sagemaker-batch-gpu-535" ,
198+ )
199+
200+ assert config ["instance_count" ] == 2
201+ assert config ["instance_type" ] == "ml.g4dn.xlarge"
202+ assert config ["volume_kms_key_id" ] == "volume-key"
203+ assert config ["transform_ami_version" ] == "al2-ami-sagemaker-batch-gpu-535"
204+
182205 def test_prepare_resource_config_no_kms (self , mock_session ):
183206 """Test _prepare_resource_config without KMS key"""
184207 transformer = Transformer (
@@ -195,6 +218,7 @@ def test_prepare_resource_config_no_kms(self, mock_session):
195218 assert config ["instance_count" ] == 1
196219 assert config ["instance_type" ] == "ml.m5.xlarge"
197220 assert "volume_kms_key_id" not in config
221+ assert "transform_ami_version" not in config
198222
199223 def test_prepare_data_processing_all_params (self , mock_session ):
200224 """Test _prepare_data_processing with all parameters"""
@@ -438,6 +462,87 @@ def test_prepare_init_params_from_job_description(self, mock_session):
438462 assert init_params ["volume_kms_key" ] == "volume-key"
439463 assert init_params ["base_transform_job_name" ] == "test-job-456"
440464
465+ def test_prepare_init_params_from_job_description_with_ami_version (self , mock_session ):
466+ """Test _prepare_init_params_from_job_description with transform_ami_version"""
467+ job_details = {
468+ "model_name" : "test-model" ,
469+ "transform_resources" : Mock (
470+ instance_count = 2 ,
471+ instance_type = "ml.g4dn.xlarge" ,
472+ volume_kms_key_id = "volume-key" ,
473+ transform_ami_version = "al2-ami-sagemaker-batch-gpu-535" ,
474+ ),
475+ "batch_strategy" : "SingleRecord" ,
476+ "transform_output" : Mock (
477+ assemble_with = "None" ,
478+ s3_output_path = "s3://bucket/output" ,
479+ kms_key_id = "output-key" ,
480+ accept = "text/csv" ,
481+ ),
482+ "max_concurrent_transforms" : 8 ,
483+ "max_payload_in_mb" : 20 ,
484+ "transform_job_name" : "test-job-789" ,
485+ }
486+
487+ init_params = Transformer ._prepare_init_params_from_job_description (job_details )
488+
489+ assert init_params ["model_name" ] == "test-model"
490+ assert init_params ["instance_count" ] == 2
491+ assert init_params ["instance_type" ] == "ml.g4dn.xlarge"
492+ assert init_params ["volume_kms_key" ] == "volume-key"
493+ assert init_params ["transform_ami_version" ] == "al2-ami-sagemaker-batch-gpu-535"
494+ assert init_params ["base_transform_job_name" ] == "test-job-789"
495+
496+ def test_init_with_transform_ami_version (self , mock_session ):
497+ """Test initialization with transform_ami_version parameter"""
498+ transformer = Transformer (
499+ model_name = "test-model" ,
500+ instance_count = 1 ,
501+ instance_type = "ml.g4dn.xlarge" ,
502+ transform_ami_version = "al2-ami-sagemaker-batch-gpu-535" ,
503+ sagemaker_session = mock_session ,
504+ )
505+
506+ assert transformer .model_name == "test-model"
507+ assert transformer .instance_count == 1
508+ assert transformer .instance_type == "ml.g4dn.xlarge"
509+ assert transformer .transform_ami_version == "al2-ami-sagemaker-batch-gpu-535"
510+
511+ def test_init_without_transform_ami_version (self , mock_session ):
512+ """Test initialization without transform_ami_version parameter"""
513+ transformer = Transformer (
514+ model_name = "test-model" ,
515+ instance_count = 1 ,
516+ instance_type = "ml.g4dn.xlarge" ,
517+ sagemaker_session = mock_session ,
518+ )
519+
520+ assert transformer .transform_ami_version is None
521+
522+ def test_load_config_with_transform_ami_version (self , mock_session ):
523+ """Test _load_config includes transform_ami_version in resource_config"""
524+ transformer = Transformer (
525+ model_name = "test-model" ,
526+ instance_count = 2 ,
527+ instance_type = "ml.g4dn.xlarge" ,
528+ output_path = "s3://bucket/output" ,
529+ transform_ami_version = "al2-ami-sagemaker-batch-gpu-535" ,
530+ sagemaker_session = mock_session ,
531+ )
532+
533+ config = transformer ._load_config (
534+ data = "s3://bucket/input" ,
535+ data_type = "S3Prefix" ,
536+ content_type = "text/csv" ,
537+ compression_type = None ,
538+ split_type = "Line" ,
539+ )
540+
541+ assert "resource_config" in config
542+ assert config ["resource_config" ]["instance_count" ] == 2
543+ assert config ["resource_config" ]["instance_type" ] == "ml.g4dn.xlarge"
544+ assert config ["resource_config" ]["transform_ami_version" ] == "al2-ami-sagemaker-batch-gpu-535"
545+
441546 def test_delete_model (self , mock_session ):
442547 """Test delete_model method"""
443548 transformer = Transformer (
0 commit comments