Skip to content

Commit 5b9e82b

Browse files
author
Amit Modi
committed
fix: Add PipelineVariable support to ModelTrainer fields (fixes #5524)
Extend StrPipeVar type to ModelTrainer's direct fields: - training_image: Optional[str] -> Optional[StrPipeVar] - algorithm_name: Optional[str] -> Optional[StrPipeVar] - training_input_mode: Optional[str] -> Optional[StrPipeVar] - environment: Dict[str, str] -> Dict[str, StrPipeVar] This follows the existing V3 pattern already used by SourceCode, OutputDataConfig, and Compute (for instance_type). The StrPipeVar type alias and PipelineVariable.__get_pydantic_core_schema__() already exist in the codebase. This unblocks V2->V3 migration for SageMaker Pipelines users who need to pass ParameterString to ModelTrainer fields. Fixes #5524
1 parent 55a4ee5 commit 5b9e82b

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
from sagemaker.core.jumpstart.utils import get_eula_url
117117
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
118118
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
119+
from sagemaker.core.helper.pipeline_variable import StrPipeVar
119120

120121
from sagemaker.train.local.local_container import _LocalContainer
121122

@@ -235,14 +236,14 @@ class ModelTrainer(BaseModel):
235236
compute: Optional[Compute] = None
236237
networking: Optional[Networking] = None
237238
stopping_condition: Optional[StoppingCondition] = None
238-
training_image: Optional[str] = None
239+
training_image: Optional[StrPipeVar] = None
239240
training_image_config: Optional[TrainingImageConfig] = None
240-
algorithm_name: Optional[str] = None
241+
algorithm_name: Optional[StrPipeVar] = None
241242
output_data_config: Optional[shapes.OutputDataConfig] = None
242243
input_data_config: Optional[List[Union[Channel, InputData]]] = None
243244
checkpoint_config: Optional[shapes.CheckpointConfig] = None
244-
training_input_mode: Optional[str] = "File"
245-
environment: Optional[Dict[str, str]] = {}
245+
training_input_mode: Optional[StrPipeVar] = "File"
246+
environment: Optional[Dict[str, StrPipeVar]] = {}
246247
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
247248
tags: Optional[List[Tag]] = None
248249
local_container_root: Optional[str] = os.getcwd()

0 commit comments

Comments
 (0)