Skip to content

Commit afcad51

Browse files
committed
fix: address review comments (iteration #5)
1 parent f865a27 commit afcad51

File tree

1 file changed

+242
-0
lines changed

1 file changed

+242
-0
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Integration tests for IC-level deploy parameters (data_cache_config, variant_name)."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import uuid
18+
import time
19+
import random
20+
import logging
21+
22+
import boto3
23+
import pytest
24+
25+
from sagemaker.serve.model_builder import ModelBuilder
26+
from sagemaker.core.jumpstart.configs import JumpStartConfig
27+
from sagemaker.core.inference_config import ResourceRequirements
28+
from sagemaker.core.resources import (
29+
Endpoint,
30+
EndpointConfig,
31+
InferenceComponent,
32+
)
33+
from sagemaker.train.configs import Compute
34+
35+
logger = logging.getLogger(__name__)
36+
37+
# Use the same JumpStart model as test_jumpstart_integration.py
38+
MODEL_ID = "huggingface-llm-falcon-7b-bf16"
39+
40+
# Training job for model customization path (same as test_model_customization_deployment.py)
41+
TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445"
42+
43+
44+
def _cleanup_endpoint(endpoint_name, sagemaker_client, ic_delete_timeout=600, poll_interval=15):
    """Best-effort teardown of an endpoint, its inference components, and its config.

    Every step logs a warning instead of raising, so a cleanup problem never
    masks the outcome of the test that called this helper.

    Args:
        endpoint_name: Name of the endpoint to tear down.
        sagemaker_client: boto3 SageMaker client used for every API call.
        ic_delete_timeout: Maximum number of seconds to keep polling until the
            endpoint has no inference components left. ``0`` performs a single
            delete pass without waiting.
        poll_interval: Seconds to sleep between two polls.
    """
    # Look up the real config name while the endpoint still exists; fall back to
    # the endpoint name (the previous assumption) when the lookup is not possible.
    config_name = endpoint_name
    try:
        described = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        config_name = described.get("EndpointConfigName") or endpoint_name
    except Exception as e:
        logger.warning(
            "Could not resolve endpoint config for %s, assuming it shares the endpoint name: %s",
            endpoint_name,
            e,
        )

    # NOTE(review): inference component deletion appears to be asynchronous (list
    # results carry an InferenceComponentStatus such as "Deleting"), and
    # delete_endpoint may be rejected while components still exist - confirm
    # against the SageMaker API docs. Polling until the list is empty also retries
    # components whose delete failed on an earlier pass.
    deadline = time.monotonic() + ic_delete_timeout
    try:
        paginator = sagemaker_client.get_paginator("list_inference_components")
        while True:
            remaining = [
                ic
                for page in paginator.paginate(EndpointNameEquals=endpoint_name)
                for ic in page.get("InferenceComponents", [])
            ]
            if not remaining:
                break

            for ic in remaining:
                ic_name = ic["InferenceComponentName"]
                if ic.get("InferenceComponentStatus") == "Deleting":
                    # Deletion already requested on an earlier pass; just wait.
                    continue
                try:
                    sagemaker_client.delete_inference_component(
                        InferenceComponentName=ic_name
                    )
                    logger.info("Deleted inference component: %s", ic_name)
                except Exception as e:
                    logger.warning("Failed to delete IC %s: %s", ic_name, e)

            if time.monotonic() >= deadline:
                logger.warning(
                    "Stopped waiting for %d inference component(s) on %s to be deleted",
                    len(remaining),
                    endpoint_name,
                )
                break
            time.sleep(poll_interval)
    except Exception as e:
        logger.warning("Failed to list/delete ICs for %s: %s", endpoint_name, e)

    try:
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
        logger.info("Deleted endpoint: %s", endpoint_name)
    except Exception as e:
        logger.warning("Failed to delete endpoint %s: %s", endpoint_name, e)

    try:
        sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name)
        logger.info("Deleted endpoint config: %s", config_name)
    except Exception as e:
        logger.warning("Failed to delete endpoint config %s: %s", config_name, e)
73+
74+
75+
def _cleanup_model(model_name, sagemaker_client):
    """Remove the SageMaker model ``model_name``, logging instead of raising on failure.

    Args:
        model_name: Name of the model to delete.
        sagemaker_client: boto3 SageMaker client used for the delete call.
    """
    try:
        sagemaker_client.delete_model(ModelName=model_name)
    except Exception as err:
        # Cleanup is best-effort: report the problem and carry on.
        logger.warning("Failed to delete model %s: %s", model_name, err)
    else:
        logger.info("Deleted model: %s", model_name)
82+
83+
84+
@pytest.mark.slow_test
def test_deploy_with_data_cache_config_and_variant_name_via_ic_path():
    """IC-path deploy of a JumpStart model honours data_cache_config and variant_name.

    The model is built from ``MODEL_ID`` and deployed with ``ResourceRequirements``
    so that the inference-component (IC) based endpoint path is taken. The created
    IC is then described through boto3 and checked for:

    - ``Specification.DataCacheConfig.EnableCaching`` being ``True``
    - ``VariantName`` being the custom value passed to ``deploy`` (not 'AllTraffic')

    The endpoint and the model are cleaned up whether or not the checks pass.
    """
    unique_id = uuid.uuid4().hex[:8]
    model_name = f"ic-params-test-model-{unique_id}"
    endpoint_name = f"ic-params-test-ep-{unique_id}"
    custom_variant = f"Variant-{unique_id}"

    sagemaker_client = boto3.client("sagemaker")
    ic_name = None

    try:
        # --- Build -----------------------------------------------------------
        model_builder = ModelBuilder.from_jumpstart_config(
            compute=Compute(instance_type="ml.g5.2xlarge"),
            jumpstart_config=JumpStartConfig(model_id=MODEL_ID),
        )
        built_model = model_builder.build(model_name=model_name)
        logger.info("Model created: %s", built_model.model_name)

        # --- Deploy ----------------------------------------------------------
        # Passing ResourceRequirements is what selects the IC-based endpoint path.
        ic_requests = {
            "memory": 8192,
            "num_accelerators": 1,
            "num_cpus": 2,
            "copies": 1,
        }
        deployed_endpoint = model_builder.deploy(
            endpoint_name=endpoint_name,
            initial_instance_count=1,
            inference_config=ResourceRequirements(requests=ic_requests),
            data_cache_config={"enable_caching": True},
            variant_name=custom_variant,
        )
        logger.info("Endpoint created: %s", deployed_endpoint.endpoint_name)

        # --- Verify ----------------------------------------------------------
        ic_name = model_builder.inference_component_name
        assert ic_name is not None, "inference_component_name should be set after deploy"

        description = sagemaker_client.describe_inference_component(
            InferenceComponentName=ic_name
        )

        # Caching must have been enabled on the component specification.
        data_cache = description.get("Specification", {}).get("DataCacheConfig", {})
        assert data_cache.get("EnableCaching") is True, (
            f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}"
        )

        # The component must sit on the custom variant.
        actual_variant = description.get("VariantName")
        assert actual_variant == custom_variant, (
            f"Expected VariantName='{custom_variant}', got '{actual_variant}'"
        )

        logger.info(
            "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True and VariantName='%s'",
            ic_name,
            custom_variant,
        )

    finally:
        _cleanup_endpoint(endpoint_name, sagemaker_client)
        _cleanup_model(model_name, sagemaker_client)
159+
160+
161+
@pytest.mark.slow_test
def test_deploy_with_data_cache_config_via_model_customization_path():
    """Deploy a fine-tuned model via _deploy_model_customization with data_cache_config.

    The model is built from the training job ``TRAINING_JOB_NAME`` and deployed
    without an explicit ``variant_name``.

    Verifies:
    - The endpoint has an inference component named ``<endpoint_name>-inference-component``
    - That IC was created with DataCacheConfig.EnableCaching == True
    - The variant_name defaults to endpoint_name (backward compat) when not explicitly provided

    The endpoint and the model are cleaned up whether or not the checks pass.
    """
    from sagemaker.core.resources import TrainingJob

    unique_id = uuid.uuid4().hex[:8]
    model_name = f"ic-mc-test-model-{unique_id}"
    endpoint_name = f"ic-mc-test-ep-{unique_id}"

    sagemaker_client = boto3.client("sagemaker")

    try:
        training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME)
        model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
        model_builder.accept_eula = True
        core_model = model_builder.build(model_name=model_name)
        logger.info("Model created: %s", core_model.model_name)

        # Deploy with data_cache_config but WITHOUT explicit variant_name
        # so it should default to endpoint_name for model customization path
        endpoint = model_builder.deploy(
            endpoint_name=endpoint_name,
            initial_instance_count=1,
            data_cache_config={"enable_caching": True},
        )
        logger.info("Endpoint created: %s", endpoint.endpoint_name)

        # Find inference components on this endpoint
        paginator = sagemaker_client.get_paginator("list_inference_components")
        ic_names = [
            ic["InferenceComponentName"]
            for page in paginator.paginate(EndpointNameEquals=endpoint_name)
            for ic in page.get("InferenceComponents", [])
        ]
        assert ic_names, (
            f"Expected at least one inference component on endpoint '{endpoint_name}'"
        )

        # The IC carrying data_cache_config is named <endpoint_name>-inference-component
        # for both LORA (base IC) and non-LORA (single IC) deployments, so no PEFT
        # lookup is needed to pick it.
        base_ic_name = f"{endpoint_name}-inference-component"
        assert base_ic_name in ic_names, (
            f"Expected inference component '{base_ic_name}' on endpoint "
            f"'{endpoint_name}', found {ic_names}"
        )

        ic_desc = sagemaker_client.describe_inference_component(
            InferenceComponentName=base_ic_name
        )

        # Verify DataCacheConfig.EnableCaching == True
        spec = ic_desc.get("Specification", {})
        data_cache = spec.get("DataCacheConfig", {})
        assert data_cache.get("EnableCaching") is True, (
            f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}"
        )

        # Verify variant_name defaults to endpoint_name (backward compat)
        actual_variant = ic_desc.get("VariantName")
        assert actual_variant == endpoint_name, (
            f"Expected VariantName='{endpoint_name}' (backward compat default), "
            f"got '{actual_variant}'"
        )

        logger.info(
            "Test passed: IC '%s' has DataCacheConfig.EnableCaching=True "
            "and VariantName='%s' (backward compat default)",
            base_ic_name,
            endpoint_name,
        )

    finally:
        _cleanup_endpoint(endpoint_name, sagemaker_client)
        _cleanup_model(model_name, sagemaker_client)

0 commit comments

Comments
 (0)