Skip to content

Commit c480aa6

Browse files
authored
Merge branch 'master' into master-mc-post-launch-fixes
2 parents 1049363 + 0bec208 commit c480aa6

6 files changed

Lines changed: 196 additions & 41 deletions

File tree

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,11 @@
99

1010
import logging
1111
import re
12-
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
12+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1313

1414
from pydantic import BaseModel, validator
1515

16+
from sagemaker.core.common_utils import TagsDict
1617
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
1718
from sagemaker.core.shapes import VpcConfig
1819

@@ -413,6 +414,13 @@ def _source_model_package_arn(self) -> Optional[str]:
413414
"""Get the resolved source model package ARN (None for JumpStart models)."""
414415
info = self._get_resolved_model_info()
415416
return info.source_model_package_arn if info else None
417+
418+
@property
419+
def _is_jumpstart_model(self) -> bool:
420+
"""Determine if model is a JumpStart model"""
421+
from sagemaker.train.common_utils.model_resolution import _ModelType
422+
info = self._get_resolved_model_info()
423+
return info.model_type == _ModelType.JUMPSTART
416424

417425
def _infer_model_package_group_arn(self) -> Optional[str]:
418426
"""Infer model package group ARN from source model package ARN.
@@ -797,6 +805,12 @@ def _start_execution(
797805
EvaluationPipelineExecution: Started execution object
798806
"""
799807
from .execution import EvaluationPipelineExecution
808+
809+
tags: List[TagsDict] = []
810+
811+
if self._is_jumpstart_model:
812+
from sagemaker.core.jumpstart.utils import add_jumpstart_model_info_tags
813+
tags = add_jumpstart_model_info_tags(tags, self.model, "*")
800814

801815
execution = EvaluationPipelineExecution.start(
802816
eval_type=eval_type,
@@ -805,7 +819,8 @@ def _start_execution(
805819
role_arn=role_arn,
806820
s3_output_path=self.s3_output_path,
807821
session=self.sagemaker_session.boto_session if hasattr(self.sagemaker_session, 'boto_session') else None,
808-
region=region
822+
region=region,
823+
tags=tags
809824
)
810825

811826
return execution

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
# Third-party imports
1717
from botocore.exceptions import ClientError
1818
from pydantic import BaseModel, Field
19+
from sagemaker.core.common_utils import TagsDict
1920
from sagemaker.core.helper.session_helper import Session
2021
from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
2122
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
@@ -38,6 +39,7 @@ def _create_evaluation_pipeline(
3839
pipeline_definition: str,
3940
session: Optional[Any] = None,
4041
region: Optional[str] = None,
42+
tags: Optional[List[TagsDict]] = [],
4143
) -> Any:
4244
"""Helper method to create a SageMaker pipeline for evaluation.
4345
@@ -49,6 +51,7 @@ def _create_evaluation_pipeline(
4951
pipeline_definition (str): JSON pipeline definition (Jinja2 template).
5052
session (Optional[Any]): SageMaker session object.
5153
region (Optional[str]): AWS region.
54+
tags (Optional[List[TagsDict]]): List of tags to include in pipeline
5255
5356
Returns:
5457
Any: Created Pipeline instance (ready for execution).
@@ -65,9 +68,9 @@ def _create_evaluation_pipeline(
6568
resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)
6669

6770
# Create tags for the pipeline
68-
tags = [
71+
tags.extend([
6972
{"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
70-
]
73+
])
7174

7275
pipeline = Pipeline.create(
7376
pipeline_name=pipeline_name,
@@ -163,7 +166,8 @@ def _get_or_create_pipeline(
163166
pipeline_definition: str,
164167
role_arn: str,
165168
session: Optional[Session] = None,
166-
region: Optional[str] = None
169+
region: Optional[str] = None,
170+
create_tags: Optional[List[TagsDict]] = [],
167171
) -> Pipeline:
168172
"""Get existing pipeline or create/update it.
169173
@@ -177,6 +181,7 @@ def _get_or_create_pipeline(
177181
role_arn: IAM role ARN for pipeline execution
178182
session: Boto3 session (optional)
179183
region: AWS region (optional)
184+
create_tags (Optional[List[TagsDict]]): List of tags to include in pipeline
180185
181186
Returns:
182187
Pipeline instance (existing updated or newly created)
@@ -225,19 +230,19 @@ def _get_or_create_pipeline(
225230

226231
# No matching pipeline found, create new one
227232
logger.info(f"No existing pipeline found with prefix {pipeline_name_prefix}, creating new one")
228-
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
233+
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
229234

230235
except ClientError as e:
231236
error_code = e.response['Error']['Code']
232237
if "ResourceNotFound" in error_code:
233-
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
238+
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
234239
else:
235240
raise
236241

237242
except Exception as e:
238243
# If search fails for other reasons, try to create
239244
logger.info(f"Error searching for pipeline ({str(e)}), attempting to create new pipeline")
240-
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
245+
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
241246

242247

243248
def _start_pipeline_execution(
@@ -505,7 +510,8 @@ def start(
505510
role_arn: str,
506511
s3_output_path: Optional[str] = None,
507512
session: Optional[Session] = None,
508-
region: Optional[str] = None
513+
region: Optional[str] = None,
514+
tags: Optional[List[TagsDict]] = [],
509515
) -> 'EvaluationPipelineExecution':
510516
"""Create sagemaker pipeline execution. Optionally creates pipeline.
511517
@@ -517,6 +523,7 @@ def start(
517523
s3_output_path (Optional[str]): S3 location where evaluation results are stored.
518524
session (Optional[Session]): Boto3 session for API calls.
519525
region (Optional[str]): AWS region for the pipeline.
526+
tags (Optional[List[TagsDict]]): List of tags to include in pipeline
520527
521528
Returns:
522529
EvaluationPipelineExecution: Started pipeline execution instance.
@@ -547,7 +554,8 @@ def start(
547554
pipeline_definition=pipeline_definition,
548555
role_arn=role_arn,
549556
session=session,
550-
region=region
557+
region=region,
558+
create_tags=tags,
551559
)
552560

553561
# Start pipeline execution via boto3

sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 76 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path):
9494
class RuntimeEnvironmentManager:
9595
"""Runtime Environment Manager class to manage runtime environment."""
9696

97+
def _validate_path(self, path: str) -> str:
98+
"""Validate and sanitize file path to prevent path traversal attacks.
99+
100+
Args:
101+
path (str): The file path to validate
102+
103+
Returns:
104+
str: The validated absolute path
105+
106+
Raises:
107+
ValueError: If the path is invalid or contains suspicious patterns
108+
"""
109+
if not path:
110+
raise ValueError("Path cannot be empty")
111+
112+
# Get absolute path to prevent path traversal
113+
abs_path = os.path.abspath(path)
114+
115+
# Check for null bytes (common in path traversal attacks)
116+
if '\x00' in path:
117+
raise ValueError(f"Invalid path contains null byte: {path}")
118+
119+
return abs_path
120+
121+
def _validate_env_name(self, env_name: str) -> None:
122+
"""Validate conda environment name to prevent command injection.
123+
124+
Args:
125+
env_name (str): The environment name to validate
126+
127+
Raises:
128+
ValueError: If the environment name contains invalid characters
129+
"""
130+
if not env_name:
131+
raise ValueError("Environment name cannot be empty")
132+
133+
# Allow only alphanumeric, underscore, and hyphen
134+
import re
135+
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
136+
raise ValueError(
137+
f"Invalid environment name '{env_name}'. "
138+
"Only alphanumeric characters, underscores, and hyphens are allowed."
139+
)
140+
97141
def snapshot(self, dependencies: str = None) -> str:
98142
"""Creates snapshot of the user's environment
99143
@@ -252,39 +296,50 @@ def _is_file_exists(self, dependencies):
252296

253297
def _install_requirements_txt(self, local_path, python_executable):
    """Install a requirements.txt file with ``<python_executable> -m pip``.

    Args:
        local_path: Path of the requirements file; it is passed through
            ``self._validate_path`` before use.
        python_executable: Interpreter used to invoke pip.
    """
    requirements_file = self._validate_path(local_path)

    # Built as an argument list so no shell parses the path.
    pip_args = ["-m", "pip", "install", "-r", requirements_file, "-U"]
    cmd = [python_executable, *pip_args]
    printable_cmd = " ".join(cmd)

    logger.info("Running command: '%s' in the dir: '%s' ", printable_cmd, os.getcwd())
    _run_shell_cmd(cmd)
    logger.info("Command %s ran successfully", printable_cmd)
259305

260306
def _create_conda_env(self, env_name, local_path):
    """Create a conda environment from a conda yml file.

    Args:
        env_name: Name of the environment to create; checked with
            ``self._validate_env_name``.
        local_path: Path of the conda yml file; passed through
            ``self._validate_path`` before use.
    """
    self._validate_env_name(env_name)
    env_file = self._validate_path(local_path)

    conda_exe = self._get_conda_exe()
    # Argument list: each value reaches conda as a single argument.
    cmd = [conda_exe, "env", "create", "-n", env_name, "--file", env_file]

    logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Conda environment %s created successfully.", env_name)
267316

268317
def _install_req_txt_in_conda_env(self, env_name, local_path):
    """Install a requirements.txt file inside the given conda environment.

    Args:
        env_name: Name of the target conda environment; checked with
            ``self._validate_env_name``.
        local_path: Path of the requirements file; passed through
            ``self._validate_path`` before use.
    """
    self._validate_env_name(env_name)
    requirements_file = self._validate_path(local_path)

    # "conda run -n <env>" followed by the pip invocation to execute in it.
    conda_prefix = [self._get_conda_exe(), "run", "-n", env_name]
    pip_install = ["pip", "install", "-r", requirements_file, "-U"]
    cmd = conda_prefix + pip_install

    logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Requirements installed successfully in conda env %s", env_name)
275327

276328
def _update_conda_env(self, env_name, local_path):
    """Update an existing conda environment from a conda yml file.

    Args:
        env_name: Name of the environment to update; checked with
            ``self._validate_env_name``.
        local_path: Path of the conda yml file; passed through
            ``self._validate_path`` before use.
    """
    self._validate_env_name(env_name)
    env_file = self._validate_path(local_path)

    conda_exe = self._get_conda_exe()
    # Argument list: each value reaches conda as a single argument.
    cmd = [conda_exe, "env", "update", "-n", env_name, "--file", env_file]

    logger.info("Updating conda env: %s", " ".join(cmd))
    _run_shell_cmd(cmd)
    logger.info("Conda env %s updated succesfully", env_name)
283338

284339
def _export_conda_env_from_prefix(self, prefix, local_path):
    """Export the conda env at ``prefix`` to a conda yml file at ``local_path``.

    Args:
        prefix: Prefix of the conda environment, passed to
            ``conda env export -p``.
        local_path: File that receives the exported environment definition;
            passed through ``self._validate_path`` before use.

    Raises:
        ValueError: If ``local_path`` is rejected by ``self._validate_path``.
        RuntimeEnvironmentError: If the conda command exits with a non-zero
            return code.
    """
    validated_path = self._validate_path(local_path)

    # A ">" element in an argument list is handed to conda as a literal
    # argument when no shell is involved, so the redirect is done here by
    # pointing the child's stdout at the destination file.
    cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds"]
    logger.info("Exporting conda environment: %s > %s", " ".join(cmd), validated_path)

    with open(validated_path, "wb") as env_file:
        result = subprocess.run(
            cmd, stdout=env_file, stderr=subprocess.PIPE, shell=False, check=False
        )

    if result.returncode:
        error_logs = result.stderr.decode("utf-8", errors="replace")
        error_message = (
            f"Encountered error while running command '{' '.join(cmd)}'. "
            f"Reason: {error_logs}"
        )
        raise RuntimeEnvironmentError(error_message)

    logger.info("Conda environment %s exported successfully", prefix)
@@ -402,19 +457,26 @@ def _run_pre_execution_command_script(script_path: str):
402457
return return_code, error_logs
403458

404459

405-
def _run_shell_cmd(cmd: str):
460+
def _run_shell_cmd(cmd: list):
406461
"""This method runs a given shell command using subprocess
407462
408-
Raises RuntimeEnvironmentError if the command fails
463+
Args:
464+
cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt'])
465+
466+
Raises:
467+
RuntimeEnvironmentError: If the command fails
468+
ValueError: If cmd is not a list
409469
"""
470+
if not isinstance(cmd, list):
471+
raise ValueError("Command must be a list of arguments for security reasons")
410472

411-
process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
473+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
412474

413475
_log_output(process)
414476
error_logs = _log_error(process)
415477
return_code = process.wait()
416478
if return_code:
417-
error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}"
479+
error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
418480
raise RuntimeEnvironmentError(error_message)
419481

420482

0 commit comments

Comments
 (0)