Skip to content

Commit 6004e8b

Browse files
committed
fixes for model builder
1 parent a0b3b99 commit 6004e8b

2 files changed

Lines changed: 177 additions & 120 deletions

File tree

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base):
85778577
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
85788578
"""
85798579

8580-
min_memory_required_in_mb: int
8580+
min_memory_required_in_mb: Optional[int] = Unassigned()
85818581
number_of_cpu_cores_required: Optional[float] = Unassigned()
85828582
number_of_accelerator_devices_required: Optional[float] = Unassigned()
85838583
max_memory_required_in_mb: Optional[int] = Unassigned()

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 176 additions & 119 deletions
Original file line number | Diff line number | Diff line change
@@ -957,6 +957,10 @@ def _fetch_and_cache_recipe_config(self):
957957
if not self.image_uri:
958958
self.image_uri = config.get("EcrAddress")
959959

960+
# Cache environment variables from recipe config
961+
if not self.env_vars:
962+
self.env_vars = config.get("Environment", {})
963+
960964
# Infer instance type from JumpStart metadata if not provided
961965
# This is only called for model_customization deployments
962966
if not self.instance_type:
@@ -2211,21 +2215,57 @@ def _build_single_modelbuilder(
22112215
"Only SageMaker Endpoint Mode is supported for Model Customization use cases"
22122216
)
22132217
model_package = self._fetch_model_package()
2214-
# Fetch recipe config first to set image_uri, instance_type, and s3_upload_path
2218+
# Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path
22152219
self._fetch_and_cache_recipe_config()
2216-
self.s3_upload_path = model_package.inference_specification.containers[
2217-
0
2218-
].model_data_source.s3_data_source.s3_uri
2219-
container_def = ContainerDefinition(
2220-
image=self.image_uri,
2221-
model_data_source={
2222-
"s3_data_source": {
2223-
"s3_uri": f"{self.s3_upload_path}/",
2224-
"s3_data_type": "S3Prefix",
2225-
"compression_type": "None",
2226-
}
2227-
},
2228-
)
2220+
peft_type = self._fetch_peft()
2221+
2222+
if peft_type == "LORA":
2223+
# For LORA: Model points at JumpStart base model, not training output
2224+
hub_document = self._fetch_hub_document_for_custom_model()
2225+
hosting_artifact_uri = hub_document.get("HostingArtifactUri")
2226+
if not hosting_artifact_uri:
2227+
raise ValueError(
2228+
"HostingArtifactUri not found in JumpStart hub metadata. "
2229+
"Cannot deploy LORA adapter without base model artifacts."
2230+
)
2231+
container_def = ContainerDefinition(
2232+
image=self.image_uri,
2233+
environment=self.env_vars,
2234+
model_data_source={
2235+
"s3_data_source": {
2236+
"s3_uri": hosting_artifact_uri,
2237+
"s3_data_type": "S3Prefix",
2238+
"compression_type": "None",
2239+
"model_access_config": {"accept_eula": True},
2240+
}
2241+
},
2242+
)
2243+
# Store adapter path for use during deploy
2244+
if isinstance(self.model, TrainingJob):
2245+
self._adapter_s3_uri = (
2246+
f"{self.model.model_artifacts.s3_model_artifacts}/checkpoints/hf/"
2247+
)
2248+
elif isinstance(self.model, ModelTrainer):
2249+
self._adapter_s3_uri = (
2250+
f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}"
2251+
"/checkpoints/hf/"
2252+
)
2253+
else:
2254+
# Non-LORA: Model points at training output
2255+
self.s3_upload_path = model_package.inference_specification.containers[
2256+
0
2257+
].model_data_source.s3_data_source.s3_uri
2258+
container_def = ContainerDefinition(
2259+
image=self.image_uri,
2260+
model_data_source={
2261+
"s3_data_source": {
2262+
"s3_uri": self.s3_upload_path.rstrip("/") + "/",
2263+
"s3_data_type": "S3Prefix",
2264+
"compression_type": "None",
2265+
}
2266+
},
2267+
)
2268+
22292269
model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}"
22302270
# Create model
22312271
self.built_model = Model.create(
@@ -4142,17 +4182,13 @@ def _deploy_model_customization(
41424182
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.
41434183
41444184
This method handles the special deployment flow for fine-tuned models, creating:
4145-
1. Core Model resource
4146-
2. EndpointConfig
4147-
3. Endpoint
4148-
4. InferenceComponent
4185+
1. EndpointConfig and Endpoint
4186+
2. Base model InferenceComponent (for LORA: from JumpStart base model)
4187+
3. Adapter InferenceComponent (for LORA: referencing base IC with adapter weights)
41494188
41504189
Args:
41514190
endpoint_name (str): Name of the endpoint to create or update
4152-
instance_type (str): EC2 instance type for deployment
41534191
initial_instance_count (int): Number of instances (default: 1)
4154-
wait (bool): Whether to wait for deployment to complete (default: True)
4155-
container_timeout_in_seconds (int): Container timeout in seconds (default: 300)
41564192
inference_component_name (Optional[str]): Name for the inference component
41574193
inference_config (Optional[ResourceRequirements]): Inference configuration including
41584194
resource requirements (accelerator count, memory, CPU cores)
@@ -4161,31 +4197,22 @@ def _deploy_model_customization(
41614197
Returns:
41624198
Endpoint: The deployed sagemaker.core.resources.Endpoint
41634199
"""
4164-
from sagemaker.core.resources import (
4165-
Model as CoreModel,
4166-
EndpointConfig as CoreEndpointConfig,
4167-
)
4168-
from sagemaker.core.shapes import ContainerDefinition, ProductionVariant
41694200
from sagemaker.core.shapes import (
41704201
InferenceComponentSpecification,
41714202
InferenceComponentContainerSpecification,
41724203
InferenceComponentRuntimeConfig,
41734204
InferenceComponentComputeResourceRequirements,
4174-
ModelDataSource,
4175-
S3ModelDataSource,
41764205
)
4206+
from sagemaker.core.shapes import ProductionVariant
41774207
from sagemaker.core.resources import InferenceComponent
4178-
from sagemaker.core.utils.utils import Unassigned
4208+
from sagemaker.core.resources import Tag as CoreTag
41794209

41804210
# Fetch model package
41814211
model_package = self._fetch_model_package()
41824212

41834213
# Check if endpoint exists
41844214
is_existing_endpoint = self._does_endpoint_exist(endpoint_name)
41854215

4186-
# Generate model name if not set
4187-
model_name = self.model_name or f"model-{uuid.uuid4().hex[:10]}"
4188-
41894216
if not is_existing_endpoint:
41904217
EndpointConfig.create(
41914218
endpoint_config_name=endpoint_name,
@@ -4206,114 +4233,145 @@ def _deploy_model_customization(
42064233
else:
42074234
endpoint = Endpoint.get(endpoint_name=endpoint_name)
42084235

4209-
# Set inference component name
4210-
if not inference_component_name:
4211-
if not is_existing_endpoint:
4212-
inference_component_name = f"{endpoint_name}-inference-component"
4213-
else:
4214-
inference_component_name = f"{endpoint_name}-inference-component-adapter"
4215-
4216-
# Get PEFT type and base model recipe name
42174236
peft_type = self._fetch_peft()
42184237
base_model_recipe_name = model_package.inference_specification.containers[
42194238
0
42204239
].base_model.recipe_name
4221-
base_inference_component_name = None
4222-
tag = None
4223-
4224-
# Resolve the correct model artifact URI based on deployment type
4225-
artifact_url = self._resolve_model_artifact_uri()
4226-
4227-
# Determine if this is a base model deployment
4228-
# A base model deployment uses HostingArtifactUri from JumpStart (not from model package)
4229-
is_base_model_deployment = False
4230-
if artifact_url and not peft_type:
4231-
# Check if artifact_url comes from JumpStart (not from model package)
4232-
# If model package has model_data_source, it's a full fine-tuned model
4233-
if (
4234-
hasattr(model_package.inference_specification.containers[0], "model_data_source")
4235-
and model_package.inference_specification.containers[0].model_data_source
4236-
):
4237-
is_base_model_deployment = False # Full fine-tuned model
4238-
else:
4239-
is_base_model_deployment = True # Base model from JumpStart
4240-
4241-
# Handle tagging and base component lookup
4242-
if not is_existing_endpoint and is_base_model_deployment:
4243-
# Only tag as "Base" if we're actually deploying a base model
4244-
from sagemaker.core.resources import Tag as CoreTag
42454240

4246-
tag = CoreTag(key="Base", value=base_model_recipe_name)
4247-
elif peft_type == "LORA":
4248-
# For LORA adapters, look up the existing base component
4249-
from sagemaker.core.resources import Tag as CoreTag
4241+
if peft_type == "LORA":
4242+
# LORA deployment: base IC + adapter IC
42504243

4244+
# Find or create base IC
4245+
base_ic_name = None
42514246
for component in InferenceComponent.get_all(
42524247
endpoint_name_equals=endpoint_name, status_equals="InService"
42534248
):
42544249
component_tags = CoreTag.get_all(resource_arn=component.inference_component_arn)
42554250
if any(
42564251
t.key == "Base" and t.value == base_model_recipe_name for t in component_tags
42574252
):
4258-
base_inference_component_name = component.inference_component_name
4253+
base_ic_name = component.inference_component_name
42594254
break
42604255

4261-
ic_spec = InferenceComponentSpecification(
4262-
container=InferenceComponentContainerSpecification(
4263-
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
4256+
if not base_ic_name:
4257+
# Deploy base model IC
4258+
base_ic_name = f"{endpoint_name}-inference-component"
4259+
4260+
base_ic_spec = InferenceComponentSpecification(
4261+
model_name=self.built_model.model_name,
4262+
)
4263+
if inference_config is not None:
4264+
base_ic_spec.compute_resource_requirements = (
4265+
InferenceComponentComputeResourceRequirements(
4266+
min_memory_required_in_mb=inference_config.min_memory,
4267+
max_memory_required_in_mb=inference_config.max_memory,
4268+
number_of_cpu_cores_required=inference_config.num_cpus,
4269+
number_of_accelerator_devices_required=inference_config.num_accelerators,
4270+
)
4271+
)
4272+
else:
4273+
base_ic_spec.compute_resource_requirements = self._cached_compute_requirements
4274+
4275+
InferenceComponent.create(
4276+
inference_component_name=base_ic_name,
4277+
endpoint_name=endpoint_name,
4278+
variant_name=endpoint_name,
4279+
specification=base_ic_spec,
4280+
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
4281+
tags=[{"key": "Base", "value": base_model_recipe_name}],
4282+
)
4283+
logger.info("Created base model InferenceComponent: '%s'", base_ic_name)
4284+
4285+
# Wait for base IC to be InService before creating adapter
4286+
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
4287+
base_ic.wait_for_status("InService")
4288+
4289+
# Deploy adapter IC
4290+
adapter_ic_name = inference_component_name or f"{endpoint_name}-adapter"
4291+
adapter_s3_uri = getattr(self, "_adapter_s3_uri", None)
4292+
4293+
adapter_ic_spec = InferenceComponentSpecification(
4294+
base_inference_component_name=base_ic_name,
4295+
container=InferenceComponentContainerSpecification(
4296+
artifact_url=adapter_s3_uri,
4297+
),
42644298
)
4265-
)
42664299

4267-
if peft_type == "LORA":
4268-
ic_spec.base_inference_component_name = base_inference_component_name
4269-
4270-
# Use inference_config if provided, otherwise fall back to cached requirements
4271-
if inference_config is not None:
4272-
# Extract compute requirements from inference_config (ResourceRequirements)
4273-
ic_spec.compute_resource_requirements = InferenceComponentComputeResourceRequirements(
4274-
min_memory_required_in_mb=inference_config.min_memory,
4275-
max_memory_required_in_mb=inference_config.max_memory,
4276-
number_of_cpu_cores_required=inference_config.num_cpus,
4277-
number_of_accelerator_devices_required=inference_config.num_accelerators,
4300+
InferenceComponent.create(
4301+
inference_component_name=adapter_ic_name,
4302+
endpoint_name=endpoint_name,
4303+
specification=adapter_ic_spec,
42784304
)
4305+
logger.info("Created adapter InferenceComponent: '%s'", adapter_ic_name)
4306+
42794307
else:
4280-
# Fall back to resolved compute requirements from build()
4281-
ic_spec.compute_resource_requirements = self._cached_compute_requirements
4308+
# Non-LORA deployment: single IC
4309+
if not inference_component_name:
4310+
inference_component_name = f"{endpoint_name}-inference-component"
42824311

4283-
InferenceComponent.create(
4284-
inference_component_name=inference_component_name,
4285-
endpoint_name=endpoint_name,
4286-
variant_name=endpoint_name,
4287-
specification=ic_spec,
4288-
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
4289-
tags=[{"key": tag.key, "value": tag.value}] if tag else [],
4290-
)
4312+
artifact_url = self._resolve_model_artifact_uri()
4313+
4314+
ic_spec = InferenceComponentSpecification(
4315+
container=InferenceComponentContainerSpecification(
4316+
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
4317+
)
4318+
)
4319+
4320+
if inference_config is not None:
4321+
ic_spec.compute_resource_requirements = (
4322+
InferenceComponentComputeResourceRequirements(
4323+
min_memory_required_in_mb=inference_config.min_memory,
4324+
max_memory_required_in_mb=inference_config.max_memory,
4325+
number_of_cpu_cores_required=inference_config.num_cpus,
4326+
number_of_accelerator_devices_required=inference_config.num_accelerators,
4327+
)
4328+
)
4329+
else:
4330+
ic_spec.compute_resource_requirements = self._cached_compute_requirements
4331+
4332+
InferenceComponent.create(
4333+
inference_component_name=inference_component_name,
4334+
endpoint_name=endpoint_name,
4335+
variant_name=endpoint_name,
4336+
specification=ic_spec,
4337+
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
4338+
)
42914339

42924340
# Create lineage tracking for new endpoints
42934341
if not is_existing_endpoint:
4294-
from sagemaker.core.resources import Action, Association, Artifact
4295-
from sagemaker.core.shapes import ActionSource, MetadataProperties
4342+
try:
4343+
from sagemaker.core.resources import Action, Association, Artifact
4344+
from sagemaker.core.shapes import ActionSource, MetadataProperties
42964345

4297-
inference_component = InferenceComponent.get(
4298-
inference_component_name=inference_component_name
4299-
)
4346+
ic_name = (
4347+
inference_component_name
4348+
if not peft_type == "LORA"
4349+
else adapter_ic_name
4350+
)
4351+
inference_component = InferenceComponent.get(
4352+
inference_component_name=ic_name
4353+
)
43004354

4301-
action = Action.create(
4302-
source=ActionSource(
4303-
source_uri=self._fetch_model_package_arn(), source_type="SageMaker"
4304-
),
4305-
action_name=f"{endpoint_name}-action",
4306-
action_type="ModelDeployment",
4307-
properties={"EndpointConfigName": endpoint_name},
4308-
metadata_properties=MetadataProperties(
4309-
generated_by=inference_component.inference_component_arn
4310-
),
4311-
)
4355+
action = Action.create(
4356+
source=ActionSource(
4357+
source_uri=self._fetch_model_package_arn(), source_type="SageMaker"
4358+
),
4359+
action_name=f"{endpoint_name}-action",
4360+
action_type="ModelDeployment",
4361+
properties={"EndpointConfigName": endpoint_name},
4362+
metadata_properties=MetadataProperties(
4363+
generated_by=inference_component.inference_component_arn
4364+
),
4365+
)
43124366

4313-
artifacts = Artifact.get_all(source_uri=model_package.model_package_arn)
4314-
for artifact in artifacts:
4315-
Association.add(source_arn=artifact.artifact_arn, destination_arn=action.action_arn)
4316-
break
4367+
artifacts = Artifact.get_all(source_uri=model_package.model_package_arn)
4368+
for artifact in artifacts:
4369+
Association.add(
4370+
source_arn=artifact.artifact_arn, destination_arn=action.action_arn
4371+
)
4372+
break
4373+
except Exception as e:
4374+
logger.warning(f"Failed to create lineage tracking: {e}")
43174375

43184376
logger.info("✅ Model customization deployment successful: Endpoint '%s'", endpoint_name)
43194377
return endpoint
@@ -4329,11 +4387,10 @@ def _fetch_peft(self) -> Optional[str]:
43294387

43304388
from sagemaker.core.utils.utils import Unassigned
43314389

4332-
if (
4333-
training_job.serverless_job_config != Unassigned()
4334-
and training_job.serverless_job_config.job_spec != Unassigned()
4335-
):
4336-
return training_job.serverless_job_config.job_spec.get("PEFT")
4390+
if training_job.serverless_job_config != Unassigned():
4391+
peft = getattr(training_job.serverless_job_config, "peft", None)
4392+
if peft and not isinstance(peft, Unassigned):
4393+
return peft
43374394
return None
43384395

43394396
def _does_endpoint_exist(self, endpoint_name: str) -> bool:

0 commit comments

Comments
 (0)