|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Integration tests for IC-level deploy parameters (data_cache_config, variant_name).""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import json |
| 17 | +import uuid |
| 18 | +import time |
| 19 | +import random |
| 20 | +import logging |
| 21 | + |
| 22 | +import boto3 |
| 23 | +import pytest |
| 24 | + |
| 25 | +from sagemaker.serve.model_builder import ModelBuilder |
| 26 | +from sagemaker.core.jumpstart.configs import JumpStartConfig |
| 27 | +from sagemaker.core.inference_config import ResourceRequirements |
| 28 | +from sagemaker.core.resources import ( |
| 29 | + Endpoint, |
| 30 | + EndpointConfig, |
| 31 | + InferenceComponent, |
| 32 | +) |
| 33 | +from sagemaker.train.configs import Compute |
| 34 | + |
# Module-level logger shared by the cleanup helpers and the tests below.
logger = logging.getLogger(__name__)

# Use the same JumpStart model as test_jumpstart_integration.py
MODEL_ID = "huggingface-llm-falcon-7b-bf16"

# Training job for model customization path (same as test_model_customization_deployment.py)
TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445"
| 42 | + |
| 43 | + |
def _cleanup_endpoint(endpoint_name, sagemaker_client):
    """Best-effort teardown of an endpoint and the resources attached to it.

    Attempts, in order: deleting every inference component listed for the
    endpoint, deleting the endpoint, and deleting the endpoint config whose
    name equals ``endpoint_name``. Each step is tried independently; any
    exception is logged as a warning and swallowed, so this function never
    raises and always returns ``None``.

    Args:
        endpoint_name: Name of the endpoint. The same value is used as the
            endpoint config name.
        sagemaker_client: Client object exposing ``get_paginator``,
            ``delete_inference_component``, ``delete_endpoint`` and
            ``delete_endpoint_config`` (a boto3 SageMaker client in the tests).
    """
    try:
        # Delete inference components first
        paginator = sagemaker_client.get_paginator("list_inference_components")
        for page in paginator.paginate(EndpointNameEquals=endpoint_name):
            for ic in page.get("InferenceComponents", []):
                ic_name = ic["InferenceComponentName"]
                # Per-component guard: one failed delete must not stop the rest.
                try:
                    sagemaker_client.delete_inference_component(
                        InferenceComponentName=ic_name
                    )
                    logger.info("Deleted inference component: %s", ic_name)
                except Exception as e:
                    logger.warning("Failed to delete IC %s: %s", ic_name, e)
    except Exception as e:
        # Listing itself failed (or a page was malformed); carry on with the endpoint.
        logger.warning("Failed to list/delete ICs for %s: %s", endpoint_name, e)

    try:
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
        logger.info("Deleted endpoint: %s", endpoint_name)
    except Exception as e:
        logger.warning("Failed to delete endpoint %s: %s", endpoint_name, e)

    try:
        sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
        logger.info("Deleted endpoint config: %s", endpoint_name)
    except Exception as e:
        logger.warning("Failed to delete endpoint config %s: %s", endpoint_name, e)
| 73 | + |
| 74 | + |
def _cleanup_model(model_name, sagemaker_client):
    """Delete a SageMaker model, logging instead of raising on failure.

    Args:
        model_name: Name passed as ``ModelName`` to ``delete_model``.
        sagemaker_client: Client object exposing ``delete_model`` (a boto3
            SageMaker client in the tests).
    """
    try:
        sagemaker_client.delete_model(ModelName=model_name)
        logger.info("Deleted model: %s", model_name)
    except Exception as e:
        # Best-effort cleanup: report the problem but never raise.
        logger.warning("Failed to delete model %s: %s", model_name, e)
| 82 | + |
| 83 | + |
@pytest.mark.slow_test
def test_deploy_with_data_cache_config_and_variant_name_via_ic_path():
    """Deploy a JumpStart model via the IC-based path with custom deploy parameters.

    The deploy call passes ``data_cache_config`` and a custom ``variant_name``.

    Verifies:
        - The IC was created with DataCacheConfig.EnableCaching == True
        - The variant name matches the custom value (not 'AllTraffic')

    The endpoint, its inference components, the endpoint config and the model
    are cleaned up in ``finally`` whether or not the assertions pass.
    """
    # Random suffix keeps resource names unique across runs.
    unique_id = uuid.uuid4().hex[:8]
    model_name = f"ic-params-test-model-{unique_id}"
    endpoint_name = f"ic-params-test-ep-{unique_id}"
    custom_variant = f"Variant-{unique_id}"

    sagemaker_client = boto3.client("sagemaker")

    try:
        # Build
        compute = Compute(instance_type="ml.g5.2xlarge")
        jumpstart_config = JumpStartConfig(model_id=MODEL_ID)
        model_builder = ModelBuilder.from_jumpstart_config(
            jumpstart_config=jumpstart_config, compute=compute
        )
        core_model = model_builder.build(model_name=model_name)
        logger.info("Model created: %s", core_model.model_name)

        # Deploy with IC path (ResourceRequirements triggers IC-based endpoint)
        resources = ResourceRequirements(
            requests={
                "memory": 8192,
                "num_accelerators": 1,
                "num_cpus": 2,
                "copies": 1,
            }
        )
        core_endpoint = model_builder.deploy(
            endpoint_name=endpoint_name,
            initial_instance_count=1,
            inference_config=resources,
            data_cache_config={"enable_caching": True},
            variant_name=custom_variant,
        )
        logger.info("Endpoint created: %s", core_endpoint.endpoint_name)

        # Find the inference component that was created
        ic_name = model_builder.inference_component_name
        assert ic_name is not None, "inference_component_name should be set after deploy"

        # Describe the inference component via boto3
        ic_desc = sagemaker_client.describe_inference_component(
            InferenceComponentName=ic_name
        )

        # Verify DataCacheConfig.EnableCaching == True
        spec = ic_desc.get("Specification", {})
        data_cache = spec.get("DataCacheConfig", {})
        assert data_cache.get("EnableCaching") is True, (
            f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}"
        )

        # Verify variant name matches custom value
        actual_variant = ic_desc.get("VariantName")
        assert actual_variant == custom_variant, (
            f"Expected VariantName='{custom_variant}', got '{actual_variant}'"
        )

        logger.info(
            "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True and VariantName='%s'",
            ic_name,
            custom_variant,
        )

    finally:
        _cleanup_endpoint(endpoint_name, sagemaker_client)
        _cleanup_model(model_name, sagemaker_client)
| 159 | + |
| 160 | + |
@pytest.mark.slow_test
def test_deploy_with_data_cache_config_via_model_customization_path():
    """Deploy a fine-tuned model via the model customization path with data_cache_config.

    Verifies:
        - The IC was created with DataCacheConfig.EnableCaching == True
        - The variant_name defaults to endpoint_name (backward compat) when not
          explicitly provided

    The endpoint, its inference components, the endpoint config and the model
    are cleaned up in ``finally`` whether or not the assertions pass.
    """
    from sagemaker.core.resources import TrainingJob

    # Random suffix keeps resource names unique across runs.
    unique_id = uuid.uuid4().hex[:8]
    model_name = f"ic-mc-test-model-{unique_id}"
    endpoint_name = f"ic-mc-test-ep-{unique_id}"

    sagemaker_client = boto3.client("sagemaker")

    try:
        training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME)
        model_builder = ModelBuilder(
            model=training_job, instance_type="ml.g5.4xlarge"
        )
        model_builder.accept_eula = True
        core_model = model_builder.build(model_name=model_name)
        logger.info("Model created: %s", core_model.model_name)

        # Deploy with data_cache_config but WITHOUT explicit variant_name
        # so it should default to endpoint_name for model customization path
        endpoint = model_builder.deploy(
            endpoint_name=endpoint_name,
            initial_instance_count=1,
            data_cache_config={"enable_caching": True},
        )
        logger.info("Endpoint created: %s", endpoint.endpoint_name)

        # Find inference components on this endpoint
        paginator = sagemaker_client.get_paginator("list_inference_components")
        ic_names = [
            ic["InferenceComponentName"]
            for page in paginator.paginate(EndpointNameEquals=endpoint_name)
            for ic in page.get("InferenceComponents", [])
        ]

        assert len(ic_names) > 0, (
            f"Expected at least one inference component on endpoint '{endpoint_name}'"
        )

        # The IC carrying data_cache_config is the base IC for LORA and the single
        # IC otherwise; both are named <endpoint_name>-inference-component, so no
        # PEFT-type lookup is needed to pick the name.
        base_ic_name = f"{endpoint_name}-inference-component"
        # Fail with a readable message here rather than with a client error from
        # the describe call below when the expected IC is absent.
        assert base_ic_name in ic_names, (
            f"Expected inference component '{base_ic_name}' on endpoint "
            f"'{endpoint_name}', found {ic_names}"
        )

        ic_desc = sagemaker_client.describe_inference_component(
            InferenceComponentName=base_ic_name
        )

        # Verify DataCacheConfig.EnableCaching == True
        spec = ic_desc.get("Specification", {})
        data_cache = spec.get("DataCacheConfig", {})
        assert data_cache.get("EnableCaching") is True, (
            f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}"
        )

        # Verify variant_name defaults to endpoint_name (backward compat)
        actual_variant = ic_desc.get("VariantName")
        assert actual_variant == endpoint_name, (
            f"Expected VariantName='{endpoint_name}' (backward compat default), "
            f"got '{actual_variant}'"
        )

        logger.info(
            "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True "
            "and VariantName='%s' (backward compat default)",
            base_ic_name,
            endpoint_name,
        )

    finally:
        _cleanup_endpoint(endpoint_name, sagemaker_client)
        _cleanup_model(model_name, sagemaker_client)
0 commit comments