|
37 | 37 | # Use the same JumpStart model as test_jumpstart_integration.py |
38 | 38 | MODEL_ID = "huggingface-llm-falcon-7b-bf16" |
39 | 39 |
|
40 | | -# Training job for model customization path (same as test_model_customization_deployment.py) |
41 | | -TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" |
42 | | - |
43 | 40 |
|
44 | 41 | def _cleanup_endpoint(endpoint_name, sagemaker_client): |
45 | 42 | """Delete endpoint, endpoint config, and all inference components.""" |
@@ -158,85 +155,4 @@ def test_deploy_with_data_cache_config_and_variant_name_via_ic_path(): |
158 | 155 | _cleanup_model(model_name, sagemaker_client) |
159 | 156 |
|
160 | 157 |
|
161 | | -@pytest.mark.slow_test |
162 | | -def test_deploy_with_data_cache_config_via_model_customization_path(): |
163 | | - """Deploy a fine-tuned model via _deploy_model_customization with data_cache_config. |
164 | | -
|
165 | | - Verifies: |
166 | | - - The IC was created with DataCacheConfig.EnableCaching == True |
167 | | - - The variant_name defaults to endpoint_name (backward compat) when not explicitly provided |
168 | | - """ |
169 | | - from sagemaker.core.resources import TrainingJob |
170 | | - |
171 | | - unique_id = uuid.uuid4().hex[:8] |
172 | | - model_name = f"ic-mc-test-model-{unique_id}" |
173 | | - endpoint_name = f"ic-mc-test-ep-{unique_id}" |
174 | | - |
175 | | - sagemaker_client = boto3.client("sagemaker") |
176 | | - |
177 | | - try: |
178 | | - training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME) |
179 | | - model_builder = ModelBuilder( |
180 | | - model=training_job, instance_type="ml.g5.4xlarge" |
181 | | - ) |
182 | | - model_builder.accept_eula = True |
183 | | - core_model = model_builder.build(model_name=model_name) |
184 | | - logger.info("Model created: %s", core_model.model_name) |
185 | | - |
186 | | - # Deploy with data_cache_config but WITHOUT explicit variant_name |
187 | | - # so it should default to endpoint_name for model customization path |
188 | | - endpoint = model_builder.deploy( |
189 | | - endpoint_name=endpoint_name, |
190 | | - initial_instance_count=1, |
191 | | - data_cache_config={"enable_caching": True}, |
192 | | - ) |
193 | | - logger.info("Endpoint created: %s", endpoint.endpoint_name) |
194 | | - |
195 | | - # Find inference components on this endpoint |
196 | | - paginator = sagemaker_client.get_paginator("list_inference_components") |
197 | | - ic_names = [] |
198 | | - for page in paginator.paginate(EndpointNameEquals=endpoint_name): |
199 | | - for ic in page.get("InferenceComponents", []): |
200 | | - ic_names.append(ic["InferenceComponentName"]) |
201 | | - |
202 | | - assert len(ic_names) > 0, ( |
203 | | - f"Expected at least one inference component on endpoint '{endpoint_name}'" |
204 | | - ) |
205 | | - |
206 | | - # Check the first (or base) IC for DataCacheConfig |
207 | | - # For LORA, the base IC should have data_cache_config; for non-LORA, the single IC. |
208 | | - peft_type = model_builder._fetch_peft() |
209 | | - if peft_type == "LORA": |
210 | | - # Base IC is named <endpoint_name>-inference-component |
211 | | - base_ic_name = f"{endpoint_name}-inference-component" |
212 | | - else: |
213 | | - base_ic_name = f"{endpoint_name}-inference-component" |
214 | | - |
215 | | - ic_desc = sagemaker_client.describe_inference_component( |
216 | | - InferenceComponentName=base_ic_name |
217 | | - ) |
218 | 158 |
|
219 | | - # Verify DataCacheConfig.EnableCaching == True |
220 | | - spec = ic_desc.get("Specification", {}) |
221 | | - data_cache = spec.get("DataCacheConfig", {}) |
222 | | - assert data_cache.get("EnableCaching") is True, ( |
223 | | - f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}" |
224 | | - ) |
225 | | - |
226 | | - # Verify variant_name defaults to endpoint_name (backward compat) |
227 | | - actual_variant = ic_desc.get("VariantName") |
228 | | - assert actual_variant == endpoint_name, ( |
229 | | - f"Expected VariantName='{endpoint_name}' (backward compat default), " |
230 | | - f"got '{actual_variant}'" |
231 | | - ) |
232 | | - |
233 | | - logger.info( |
234 | | - "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True " |
235 | | - "and VariantName='%s' (backward compat default)", |
236 | | - base_ic_name, |
237 | | - endpoint_name, |
238 | | - ) |
239 | | - |
240 | | - finally: |
241 | | - _cleanup_endpoint(endpoint_name, sagemaker_client) |
242 | | - _cleanup_model(model_name, sagemaker_client) |
0 commit comments