@@ -957,6 +957,10 @@ def _fetch_and_cache_recipe_config(self):
957957 if not self .image_uri :
958958 self .image_uri = config .get ("EcrAddress" )
959959
960+ # Cache environment variables from recipe config
961+ if not self .env_vars :
962+ self .env_vars = config .get ("Environment" , {})
963+
960964 # Infer instance type from JumpStart metadata if not provided
961965 # This is only called for model_customization deployments
962966 if not self .instance_type :
@@ -2211,21 +2215,57 @@ def _build_single_modelbuilder(
22112215 "Only SageMaker Endpoint Mode is supported for Model Customization use cases"
22122216 )
22132217 model_package = self ._fetch_model_package ()
2214- # Fetch recipe config first to set image_uri, instance_type, and s3_upload_path
2218+ # Fetch recipe config first to set image_uri, instance_type, env_vars, and s3_upload_path
22152219 self ._fetch_and_cache_recipe_config ()
2216- self .s3_upload_path = model_package .inference_specification .containers [
2217- 0
2218- ].model_data_source .s3_data_source .s3_uri
2219- container_def = ContainerDefinition (
2220- image = self .image_uri ,
2221- model_data_source = {
2222- "s3_data_source" : {
2223- "s3_uri" : f"{ self .s3_upload_path } /" ,
2224- "s3_data_type" : "S3Prefix" ,
2225- "compression_type" : "None" ,
2226- }
2227- },
2228- )
2220+ peft_type = self ._fetch_peft ()
2221+
2222+ if peft_type == "LORA" :
2223+ # For LORA: Model points at JumpStart base model, not training output
2224+ hub_document = self ._fetch_hub_document_for_custom_model ()
2225+ hosting_artifact_uri = hub_document .get ("HostingArtifactUri" )
2226+ if not hosting_artifact_uri :
2227+ raise ValueError (
2228+ "HostingArtifactUri not found in JumpStart hub metadata. "
2229+ "Cannot deploy LORA adapter without base model artifacts."
2230+ )
2231+ container_def = ContainerDefinition (
2232+ image = self .image_uri ,
2233+ environment = self .env_vars ,
2234+ model_data_source = {
2235+ "s3_data_source" : {
2236+ "s3_uri" : hosting_artifact_uri ,
2237+ "s3_data_type" : "S3Prefix" ,
2238+ "compression_type" : "None" ,
2239+ "model_access_config" : {"accept_eula" : True },
2240+ }
2241+ },
2242+ )
2243+ # Store adapter path for use during deploy
2244+ if isinstance (self .model , TrainingJob ):
2245+ self ._adapter_s3_uri = (
2246+ f"{ self .model .model_artifacts .s3_model_artifacts } /checkpoints/hf/"
2247+ )
2248+ elif isinstance (self .model , ModelTrainer ):
2249+ self ._adapter_s3_uri = (
2250+ f"{ self .model ._latest_training_job .model_artifacts .s3_model_artifacts } "
2251+ "/checkpoints/hf/"
2252+ )
2253+ else :
2254+ # Non-LORA: Model points at training output
2255+ self .s3_upload_path = model_package .inference_specification .containers [
2256+ 0
2257+ ].model_data_source .s3_data_source .s3_uri
2258+ container_def = ContainerDefinition (
2259+ image = self .image_uri ,
2260+ model_data_source = {
2261+ "s3_data_source" : {
2262+ "s3_uri" : self .s3_upload_path .rstrip ("/" ) + "/" ,
2263+ "s3_data_type" : "S3Prefix" ,
2264+ "compression_type" : "None" ,
2265+ }
2266+ },
2267+ )
2268+
22292269 model_name = self .model_name or f"model-{ uuid .uuid4 ().hex [:10 ]} "
22302270 # Create model
22312271 self .built_model = Model .create (
@@ -4142,17 +4182,13 @@ def _deploy_model_customization(
41424182 """Deploy a model customization (fine-tuned) model to an endpoint with inference components.
41434183
41444184 This method handles the special deployment flow for fine-tuned models, creating:
4145- 1. Core Model resource
4146- 2. EndpointConfig
4147- 3. Endpoint
4148- 4. InferenceComponent
4185+ 1. EndpointConfig and Endpoint
4186+ 2. Base model InferenceComponent (for LORA: from JumpStart base model)
4187+ 3. Adapter InferenceComponent (for LORA: referencing base IC with adapter weights)
41494188
41504189 Args:
41514190 endpoint_name (str): Name of the endpoint to create or update
4152- instance_type (str): EC2 instance type for deployment
41534191 initial_instance_count (int): Number of instances (default: 1)
4154- wait (bool): Whether to wait for deployment to complete (default: True)
4155- container_timeout_in_seconds (int): Container timeout in seconds (default: 300)
41564192 inference_component_name (Optional[str]): Name for the inference component
41574193 inference_config (Optional[ResourceRequirements]): Inference configuration including
41584194 resource requirements (accelerator count, memory, CPU cores)
@@ -4161,31 +4197,22 @@ def _deploy_model_customization(
41614197 Returns:
41624198 Endpoint: The deployed sagemaker.core.resources.Endpoint
41634199 """
4164- from sagemaker .core .resources import (
4165- Model as CoreModel ,
4166- EndpointConfig as CoreEndpointConfig ,
4167- )
4168- from sagemaker .core .shapes import ContainerDefinition , ProductionVariant
41694200 from sagemaker .core .shapes import (
41704201 InferenceComponentSpecification ,
41714202 InferenceComponentContainerSpecification ,
41724203 InferenceComponentRuntimeConfig ,
41734204 InferenceComponentComputeResourceRequirements ,
4174- ModelDataSource ,
4175- S3ModelDataSource ,
41764205 )
4206+ from sagemaker .core .shapes import ProductionVariant
41774207 from sagemaker .core .resources import InferenceComponent
4178- from sagemaker .core .utils . utils import Unassigned
4208+ from sagemaker .core .resources import Tag as CoreTag
41794209
41804210 # Fetch model package
41814211 model_package = self ._fetch_model_package ()
41824212
41834213 # Check if endpoint exists
41844214 is_existing_endpoint = self ._does_endpoint_exist (endpoint_name )
41854215
4186- # Generate model name if not set
4187- model_name = self .model_name or f"model-{ uuid .uuid4 ().hex [:10 ]} "
4188-
41894216 if not is_existing_endpoint :
41904217 EndpointConfig .create (
41914218 endpoint_config_name = endpoint_name ,
@@ -4206,114 +4233,145 @@ def _deploy_model_customization(
42064233 else :
42074234 endpoint = Endpoint .get (endpoint_name = endpoint_name )
42084235
4209- # Set inference component name
4210- if not inference_component_name :
4211- if not is_existing_endpoint :
4212- inference_component_name = f"{ endpoint_name } -inference-component"
4213- else :
4214- inference_component_name = f"{ endpoint_name } -inference-component-adapter"
4215-
4216- # Get PEFT type and base model recipe name
42174236 peft_type = self ._fetch_peft ()
42184237 base_model_recipe_name = model_package .inference_specification .containers [
42194238 0
42204239 ].base_model .recipe_name
4221- base_inference_component_name = None
4222- tag = None
4223-
4224- # Resolve the correct model artifact URI based on deployment type
4225- artifact_url = self ._resolve_model_artifact_uri ()
4226-
4227- # Determine if this is a base model deployment
4228- # A base model deployment uses HostingArtifactUri from JumpStart (not from model package)
4229- is_base_model_deployment = False
4230- if artifact_url and not peft_type :
4231- # Check if artifact_url comes from JumpStart (not from model package)
4232- # If model package has model_data_source, it's a full fine-tuned model
4233- if (
4234- hasattr (model_package .inference_specification .containers [0 ], "model_data_source" )
4235- and model_package .inference_specification .containers [0 ].model_data_source
4236- ):
4237- is_base_model_deployment = False # Full fine-tuned model
4238- else :
4239- is_base_model_deployment = True # Base model from JumpStart
4240-
4241- # Handle tagging and base component lookup
4242- if not is_existing_endpoint and is_base_model_deployment :
4243- # Only tag as "Base" if we're actually deploying a base model
4244- from sagemaker .core .resources import Tag as CoreTag
42454240
4246- tag = CoreTag (key = "Base" , value = base_model_recipe_name )
4247- elif peft_type == "LORA" :
4248- # For LORA adapters, look up the existing base component
4249- from sagemaker .core .resources import Tag as CoreTag
4241+ if peft_type == "LORA" :
4242+ # LORA deployment: base IC + adapter IC
42504243
4244+ # Find or create base IC
4245+ base_ic_name = None
42514246 for component in InferenceComponent .get_all (
42524247 endpoint_name_equals = endpoint_name , status_equals = "InService"
42534248 ):
42544249 component_tags = CoreTag .get_all (resource_arn = component .inference_component_arn )
42554250 if any (
42564251 t .key == "Base" and t .value == base_model_recipe_name for t in component_tags
42574252 ):
4258- base_inference_component_name = component .inference_component_name
4253+ base_ic_name = component .inference_component_name
42594254 break
42604255
4261- ic_spec = InferenceComponentSpecification (
4262- container = InferenceComponentContainerSpecification (
4263- image = self .image_uri , artifact_url = artifact_url , environment = self .env_vars
4256+ if not base_ic_name :
4257+ # Deploy base model IC
4258+ base_ic_name = f"{ endpoint_name } -inference-component"
4259+
4260+ base_ic_spec = InferenceComponentSpecification (
4261+ model_name = self .built_model .model_name ,
4262+ )
4263+ if inference_config is not None :
4264+ base_ic_spec .compute_resource_requirements = (
4265+ InferenceComponentComputeResourceRequirements (
4266+ min_memory_required_in_mb = inference_config .min_memory ,
4267+ max_memory_required_in_mb = inference_config .max_memory ,
4268+ number_of_cpu_cores_required = inference_config .num_cpus ,
4269+ number_of_accelerator_devices_required = inference_config .num_accelerators ,
4270+ )
4271+ )
4272+ else :
4273+ base_ic_spec .compute_resource_requirements = self ._cached_compute_requirements
4274+
4275+ InferenceComponent .create (
4276+ inference_component_name = base_ic_name ,
4277+ endpoint_name = endpoint_name ,
4278+ variant_name = endpoint_name ,
4279+ specification = base_ic_spec ,
4280+ runtime_config = InferenceComponentRuntimeConfig (copy_count = 1 ),
4281+ tags = [{"key" : "Base" , "value" : base_model_recipe_name }],
4282+ )
4283+ logger .info ("Created base model InferenceComponent: '%s'" , base_ic_name )
4284+
4285+ # Wait for base IC to be InService before creating adapter
4286+ base_ic = InferenceComponent .get (inference_component_name = base_ic_name )
4287+ base_ic .wait_for_status ("InService" )
4288+
4289+ # Deploy adapter IC
4290+ adapter_ic_name = inference_component_name or f"{ endpoint_name } -adapter"
4291+ adapter_s3_uri = getattr (self , "_adapter_s3_uri" , None )
4292+
4293+ adapter_ic_spec = InferenceComponentSpecification (
4294+ base_inference_component_name = base_ic_name ,
4295+ container = InferenceComponentContainerSpecification (
4296+ artifact_url = adapter_s3_uri ,
4297+ ),
42644298 )
4265- )
42664299
4267- if peft_type == "LORA" :
4268- ic_spec .base_inference_component_name = base_inference_component_name
4269-
4270- # Use inference_config if provided, otherwise fall back to cached requirements
4271- if inference_config is not None :
4272- # Extract compute requirements from inference_config (ResourceRequirements)
4273- ic_spec .compute_resource_requirements = InferenceComponentComputeResourceRequirements (
4274- min_memory_required_in_mb = inference_config .min_memory ,
4275- max_memory_required_in_mb = inference_config .max_memory ,
4276- number_of_cpu_cores_required = inference_config .num_cpus ,
4277- number_of_accelerator_devices_required = inference_config .num_accelerators ,
4300+ InferenceComponent .create (
4301+ inference_component_name = adapter_ic_name ,
4302+ endpoint_name = endpoint_name ,
4303+ specification = adapter_ic_spec ,
42784304 )
4305+ logger .info ("Created adapter InferenceComponent: '%s'" , adapter_ic_name )
4306+
42794307 else :
4280- # Fall back to resolved compute requirements from build()
4281- ic_spec .compute_resource_requirements = self ._cached_compute_requirements
4308+ # Non-LORA deployment: single IC
4309+ if not inference_component_name :
4310+ inference_component_name = f"{ endpoint_name } -inference-component"
42824311
4283- InferenceComponent .create (
4284- inference_component_name = inference_component_name ,
4285- endpoint_name = endpoint_name ,
4286- variant_name = endpoint_name ,
4287- specification = ic_spec ,
4288- runtime_config = InferenceComponentRuntimeConfig (copy_count = 1 ),
4289- tags = [{"key" : tag .key , "value" : tag .value }] if tag else [],
4290- )
4312+ artifact_url = self ._resolve_model_artifact_uri ()
4313+
4314+ ic_spec = InferenceComponentSpecification (
4315+ container = InferenceComponentContainerSpecification (
4316+ image = self .image_uri , artifact_url = artifact_url , environment = self .env_vars
4317+ )
4318+ )
4319+
4320+ if inference_config is not None :
4321+ ic_spec .compute_resource_requirements = (
4322+ InferenceComponentComputeResourceRequirements (
4323+ min_memory_required_in_mb = inference_config .min_memory ,
4324+ max_memory_required_in_mb = inference_config .max_memory ,
4325+ number_of_cpu_cores_required = inference_config .num_cpus ,
4326+ number_of_accelerator_devices_required = inference_config .num_accelerators ,
4327+ )
4328+ )
4329+ else :
4330+ ic_spec .compute_resource_requirements = self ._cached_compute_requirements
4331+
4332+ InferenceComponent .create (
4333+ inference_component_name = inference_component_name ,
4334+ endpoint_name = endpoint_name ,
4335+ variant_name = endpoint_name ,
4336+ specification = ic_spec ,
4337+ runtime_config = InferenceComponentRuntimeConfig (copy_count = 1 ),
4338+ )
42914339
42924340 # Create lineage tracking for new endpoints
42934341 if not is_existing_endpoint :
4294- from sagemaker .core .resources import Action , Association , Artifact
4295- from sagemaker .core .shapes import ActionSource , MetadataProperties
4342+ try :
4343+ from sagemaker .core .resources import Action , Association , Artifact
4344+ from sagemaker .core .shapes import ActionSource , MetadataProperties
42964345
4297- inference_component = InferenceComponent .get (
4298- inference_component_name = inference_component_name
4299- )
4346+ ic_name = (
4347+ inference_component_name
4348+ if not peft_type == "LORA"
4349+ else adapter_ic_name
4350+ )
4351+ inference_component = InferenceComponent .get (
4352+ inference_component_name = ic_name
4353+ )
43004354
4301- action = Action .create (
4302- source = ActionSource (
4303- source_uri = self ._fetch_model_package_arn (), source_type = "SageMaker"
4304- ),
4305- action_name = f"{ endpoint_name } -action" ,
4306- action_type = "ModelDeployment" ,
4307- properties = {"EndpointConfigName" : endpoint_name },
4308- metadata_properties = MetadataProperties (
4309- generated_by = inference_component .inference_component_arn
4310- ),
4311- )
4355+ action = Action .create (
4356+ source = ActionSource (
4357+ source_uri = self ._fetch_model_package_arn (), source_type = "SageMaker"
4358+ ),
4359+ action_name = f"{ endpoint_name } -action" ,
4360+ action_type = "ModelDeployment" ,
4361+ properties = {"EndpointConfigName" : endpoint_name },
4362+ metadata_properties = MetadataProperties (
4363+ generated_by = inference_component .inference_component_arn
4364+ ),
4365+ )
43124366
4313- artifacts = Artifact .get_all (source_uri = model_package .model_package_arn )
4314- for artifact in artifacts :
4315- Association .add (source_arn = artifact .artifact_arn , destination_arn = action .action_arn )
4316- break
4367+ artifacts = Artifact .get_all (source_uri = model_package .model_package_arn )
4368+ for artifact in artifacts :
4369+ Association .add (
4370+ source_arn = artifact .artifact_arn , destination_arn = action .action_arn
4371+ )
4372+ break
4373+ except Exception as e :
4374+ logger .warning (f"Failed to create lineage tracking: { e } " )
43174375
43184376 logger .info ("✅ Model customization deployment successful: Endpoint '%s'" , endpoint_name )
43194377 return endpoint
@@ -4329,11 +4387,10 @@ def _fetch_peft(self) -> Optional[str]:
43294387
43304388 from sagemaker .core .utils .utils import Unassigned
43314389
4332- if (
4333- training_job .serverless_job_config != Unassigned ()
4334- and training_job .serverless_job_config .job_spec != Unassigned ()
4335- ):
4336- return training_job .serverless_job_config .job_spec .get ("PEFT" )
4390+ if training_job .serverless_job_config != Unassigned ():
4391+ peft = getattr (training_job .serverless_job_config , "peft" , None )
4392+ if peft and not isinstance (peft , Unassigned ):
4393+ return peft
43374394 return None
43384395
43394396 def _does_endpoint_exist (self , endpoint_name : str ) -> bool :
0 commit comments