Skip to content

Commit 4d922b2

Browse files
authored
Add input validation and resource management improvements (aws#5417)
* Add input validation and resource management improvements
* Fix failing unit test
* Fix codestyle issues
* More codestyle fixes
* Allowing for sym-links, better refactoring
* Adding additional validation and removing home as sensitive path
* Adding root directory validation to other helpers
* Fixing codestyle changes
* Fixes for missed codestyle changes
* Fixing codestyle issue
1 parent b7c21a5 commit 4d922b2

7 files changed

Lines changed: 306 additions & 7 deletions

File tree

src/sagemaker/iterators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io
1818

1919
from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
20+
from sagemaker.utils import _MAX_BUFFER_SIZE
2021

2122

2223
def handle_stream_errors(chunk):
@@ -182,5 +183,15 @@ def __next__(self):
182183
# print and move on to next response byte
183184
print("Unknown event type:" + chunk)
184185
continue
186+
187+
# Check buffer size before writing to prevent unbounded memory consumption
188+
chunk_size = len(chunk["PayloadPart"]["Bytes"])
189+
current_size = self.buffer.getbuffer().nbytes
190+
if current_size + chunk_size > _MAX_BUFFER_SIZE:
191+
raise RuntimeError(
192+
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
193+
f"No newline found in stream."
194+
)
195+
185196
self.buffer.seek(0, io.SEEK_END)
186197
self.buffer.write(chunk["PayloadPart"]["Bytes"])

src/sagemaker/local/data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import sagemaker.amazon.common
2727
import sagemaker.local.utils
2828
import sagemaker.utils
29+
from sagemaker.utils import _SENSITIVE_SYSTEM_PATHS
2930

3031

3132
def get_data_source_instance(data_source, sagemaker_session):
@@ -122,6 +123,15 @@ def __init__(self, root_path):
122123
super(LocalFileDataSource, self).__init__()
123124

124125
self.root_path = os.path.abspath(root_path)
126+
127+
# Validate that the path is not in restricted locations
128+
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
129+
if self.root_path != "/" and self.root_path.startswith(restricted_path):
130+
raise ValueError(
131+
f"Local Mode does not support mounting from restricted system paths. "
132+
f"Got: {root_path}"
133+
)
134+
125135
if not os.path.exists(self.root_path):
126136
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)
127137

src/sagemaker/local/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
4848
destination_directory
4949
"""
5050
full_path = os.path.join(destination_directory, relative_path)
51-
if os.path.exists(full_path):
52-
return
53-
54-
os.makedirs(destination_directory, relative_path)
51+
os.makedirs(full_path, exist_ok=True)
5552

5653

5754
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):

src/sagemaker/modules/configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class OutputDataConfig(shapes.OutputDataConfig):
248248
"""OutputDataConfig.
249249
250250
The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig``
251-
and allows the user to specify the output data configuration for the training job
251+
and allows the user to specify the output data configuration for the training job
252252
(will not be carried over to any model repository or deployment).
253253
254254
Parameters:

src/sagemaker/utils.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,20 @@
7676
WAITING_DOT_NUMBER = 10
7777
MAX_ITEMS = 100
7878
PAGE_SIZE = 10
79+
_MAX_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB - Maximum buffer size for streaming iterators
80+
81+
_SENSITIVE_SYSTEM_PATHS = [
82+
abspath(os.path.expanduser("~/.aws")),
83+
abspath(os.path.expanduser("~/.ssh")),
84+
abspath(os.path.expanduser("~/.kube")),
85+
abspath(os.path.expanduser("~/.docker")),
86+
abspath(os.path.expanduser("~/.config")),
87+
abspath(os.path.expanduser("~/.credentials")),
88+
abspath(realpath("/etc")),
89+
abspath(realpath("/root")),
90+
abspath(realpath("/var/lib")),
91+
abspath(realpath("/opt/ml/metadata")),
92+
]
7993

8094
logger = logging.getLogger(__name__)
8195

@@ -601,11 +615,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
601615
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
602616

603617

618+
def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not access restricted system locations.
    Symlinks are resolved before the check. The comparison is made on whole path
    components, so a sibling such as ``/etcetera`` is not mistaken for a path
    located under ``/etc``.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    # Resolve symlinks to get the actual path
    abs_source = abspath(realpath(source_directory))

    # The filesystem root itself is never treated as a sensitive location.
    if abs_source == "/":
        return

    # Check if the source path is a sensitive directory or lives under one
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # A separator-terminated prefix prevents "/etc" from matching "/etcetera".
        sensitive_prefix = sensitive_path.rstrip(os.sep) + os.sep
        if abs_source == sensitive_path or abs_source.startswith(sensitive_prefix):
            raise ValueError(
                "source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )
643+
644+
645+
def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not access restricted system locations.
    Symlinks are resolved before the check. The comparison is made on whole path
    components, so a sibling such as ``/rootfs`` is not mistaken for a path
    located under ``/root``.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not dependency:
        return

    # Resolve symlinks to get the actual path
    abs_dependency = abspath(realpath(dependency))

    # The filesystem root itself is never treated as a sensitive location.
    if abs_dependency == "/":
        return

    # Check if the dependency path is a sensitive directory or lives under one
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # A separator-terminated prefix prevents "/root" from matching "/rootfs".
        sensitive_prefix = sensitive_path.rstrip(os.sep) + os.sep
        if abs_dependency == sensitive_path or abs_dependency.startswith(sensitive_prefix):
            raise ValueError(
                "dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )
669+
670+
604671
def _create_or_update_code_dir(
605672
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
606673
):
607674
"""Placeholder docstring"""
608675
code_dir = os.path.join(model_dir, "code")
676+
resolved_code_dir = _get_resolved_path(code_dir)
677+
678+
# Validate that code_dir does not resolve to a sensitive system path
679+
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
680+
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
681+
raise ValueError(
682+
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
683+
)
684+
609685
if source_directory and source_directory.lower().startswith("s3://"):
610686
local_code_path = os.path.join(tmp, "local_code.tar.gz")
611687
download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -614,6 +690,8 @@ def _create_or_update_code_dir(
614690
custom_extractall_tarfile(t, code_dir)
615691

616692
elif source_directory:
693+
# Validate source_directory for security
694+
_validate_source_directory(source_directory)
617695
if os.path.exists(code_dir):
618696
shutil.rmtree(code_dir)
619697
shutil.copytree(source_directory, code_dir)
@@ -646,6 +724,8 @@ def _create_or_update_code_dir(
646724
)
647725

648726
for dependency in dependencies:
727+
# Validate dependency path for security
728+
_validate_dependency_path(dependency)
649729
lib_dir = os.path.join(code_dir, "lib")
650730
if os.path.isdir(dependency):
651731
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
@@ -1620,6 +1700,38 @@ def _get_safe_members(members):
16201700
yield file_info
16211701

16221702

1703+
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation to ensure all extracted files and directories
    are within the intended extraction path. Each entry found by walking
    ``extract_path`` is resolved with ``_get_resolved_path`` and must either equal
    the resolved extraction path or sit beneath it. The comparison uses a
    separator-terminated prefix, so a sibling directory such as ``/tmp/out-evil``
    is not accepted as being inside ``/tmp/out``.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file is outside the expected extraction path.
    """
    # NOTE(review): _get_resolved_path is defined elsewhere in this module; it
    # presumably returns an absolute, symlink-resolved path -- confirm.
    base = _get_resolved_path(extract_path)
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories are checked before files, matching the original order.
        for kind, names in (("directory", dirs), ("file", files)):
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error("Extracted %s escaped extraction path: %s", kind, entry_path)
                    raise ValueError(f"Extracted path outside expected directory: {entry_path}")
1733+
1734+
16231735
def custom_extractall_tarfile(tar, extract_path):
16241736
"""Extract a tarfile, optionally using data_filter if available.
16251737
@@ -1640,6 +1752,8 @@ def custom_extractall_tarfile(tar, extract_path):
16401752
tar.extractall(path=extract_path, filter="data")
16411753
else:
16421754
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1755+
# Re-validate extracted paths to catch symlink race conditions
1756+
_validate_extracted_paths(extract_path)
16431757

16441758

16451759
def can_model_package_source_uri_autopopulate(source_uri: str):

tests/unit/sagemaker/local/test_local_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
@patch("sagemaker.local.utils.os.path")
2626
@patch("sagemaker.local.utils.os")
2727
def test_copy_directory_structure(m_os, m_os_path):
28-
m_os_path.exists.return_value = False
28+
m_os_path.join.return_value = "/tmp/code/"
2929
sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
30-
m_os.makedirs.assert_called_with("/tmp/", "code/")
30+
m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)
3131

3232

3333
@patch("shutil.rmtree", Mock())

0 commit comments

Comments
 (0)