Skip to content
Merged
11 changes: 11 additions & 0 deletions src/sagemaker/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io

from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
from sagemaker.utils import _MAX_BUFFER_SIZE


def handle_stream_errors(chunk):
Expand Down Expand Up @@ -182,5 +183,15 @@ def __next__(self):
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue

# Check buffer size before writing to prevent unbounded memory consumption
chunk_size = len(chunk["PayloadPart"]["Bytes"])
current_size = self.buffer.getbuffer().nbytes
if current_size + chunk_size > _MAX_BUFFER_SIZE:
raise RuntimeError(
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
f"No newline found in stream."
)

self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
10 changes: 10 additions & 0 deletions src/sagemaker/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import sagemaker.amazon.common
import sagemaker.local.utils
import sagemaker.utils
from sagemaker.utils import _SENSITIVE_SYSTEM_PATHS


def get_data_source_instance(data_source, sagemaker_session):
Expand Down Expand Up @@ -122,6 +123,15 @@ def __init__(self, root_path):
super(LocalFileDataSource, self).__init__()

self.root_path = os.path.abspath(root_path)

# Validate that the path is not in restricted locations
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
if self.root_path != "/" and self.root_path.startswith(restricted_path):
raise ValueError(
f"Local Mode does not support mounting from restricted system paths. "
f"Got: {root_path}"
)

if not os.path.exists(self.root_path):
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)

Expand Down
5 changes: 1 addition & 4 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def copy_directory_structure(destination_directory, relative_path):
destination_directory
"""
full_path = os.path.join(destination_directory, relative_path)
if os.path.exists(full_path):
return

os.makedirs(destination_directory, relative_path)
os.makedirs(full_path, exist_ok=True)


def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
Expand Down
114 changes: 114 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@
WAITING_DOT_NUMBER = 10
MAX_ITEMS = 100
PAGE_SIZE = 10

# 100 MB - Maximum buffer size for streaming iterators
_MAX_BUFFER_SIZE = 100 * 1024 * 1024

# Absolute paths that local-path validation refuses to read from or write into.
# Entries under the user's home are expanded with ``expanduser`` and made absolute;
# the system-wide entries additionally have symlinks resolved via ``realpath``.
_SENSITIVE_SYSTEM_PATHS = [
    *(
        abspath(os.path.expanduser(f"~/{name}"))
        for name in (".aws", ".ssh", ".kube", ".docker", ".config", ".credentials")
    ),
    *(
        abspath(realpath(system_dir))
        for system_dir in ("/etc", "/root", "/var/lib", "/opt/ml/metadata")
    ),
]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -601,11 +615,73 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))


def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory does not point at, or inside, any of the
    locations listed in ``_SENSITIVE_SYSTEM_PATHS``. The comparison is made on
    whole path components after symlinks have been resolved.

    Args:
        source_directory (str): The source directory path to validate. ``None``,
            empty values and S3 URIs are accepted without further checks.

    Raises:
        ValueError: If the resolved path is a sensitive path or lies beneath one.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return

    # Resolve symlinks to get the actual path
    abs_source = abspath(realpath(source_directory))

    if abs_source == "/":
        # The filesystem root itself is never treated as lying inside a
        # sensitive directory.
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Anchor the match on a path separator. A bare ``startswith`` would also
        # reject unrelated siblings such as "/etcetera" or "~/.config-backup".
        sensitive_root = sensitive_path.rstrip(os.sep)
        is_same = abs_source == sensitive_root
        is_inside = abs_source.startswith(sensitive_root + os.sep)
        if is_same or is_inside:
            raise ValueError(
                "source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )


def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency does not point at, or inside, any of the locations
    listed in ``_SENSITIVE_SYSTEM_PATHS``. The comparison is made on whole path
    components after symlinks have been resolved.

    Args:
        dependency (str): The dependency path to validate. ``None`` and empty
            values are accepted without further checks.

    Raises:
        ValueError: If the resolved path is a sensitive path or lies beneath one.
    """
    if not dependency:
        return

    # Resolve symlinks to get the actual path
    abs_dependency = abspath(realpath(dependency))

    if abs_dependency == "/":
        # The filesystem root itself is never treated as lying inside a
        # sensitive directory.
        return

    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        # Anchor the match on a path separator. A bare ``startswith`` would also
        # reject unrelated siblings such as "/rootfs" or "/var/library".
        sensitive_root = sensitive_path.rstrip(os.sep)
        is_same = abs_dependency == sensitive_root
        is_inside = abs_dependency.startswith(sensitive_root + os.sep)
        if is_same or is_inside:
            raise ValueError(
                "dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )


def _create_or_update_code_dir(
model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
"""Placeholder docstring"""
code_dir = os.path.join(model_dir, "code")
resolved_code_dir = _get_resolved_path(code_dir)

# Validate that code_dir does not resolve to a sensitive system path
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
raise ValueError(
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
)

if source_directory and source_directory.lower().startswith("s3://"):
local_code_path = os.path.join(tmp, "local_code.tar.gz")
download_file_from_url(source_directory, local_code_path, sagemaker_session)
Expand All @@ -614,6 +690,8 @@ def _create_or_update_code_dir(
custom_extractall_tarfile(t, code_dir)

elif source_directory:
# Validate source_directory for security
_validate_source_directory(source_directory)
if os.path.exists(code_dir):
shutil.rmtree(code_dir)
shutil.copytree(source_directory, code_dir)
Expand Down Expand Up @@ -646,6 +724,8 @@ def _create_or_update_code_dir(
)

for dependency in dependencies:
# Validate dependency path for security
_validate_dependency_path(dependency)
lib_dir = os.path.join(code_dir, "lib")
if os.path.isdir(dependency):
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
Expand Down Expand Up @@ -1620,6 +1700,38 @@ def _get_safe_members(members):
yield file_info


def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Walks ``extract_path`` after extraction and resolves every directory and file
    entry with ``_get_resolved_path``. An entry is accepted only when its resolved
    path is the extraction root itself or lies beneath it.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted file or directory resolves to a location
            outside the expected extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Anchor the containment test on a separator. Without it a sibling such as
    # "/tmp/out-evil" would pass a plain ``startswith("/tmp/out")`` check.
    base_prefix = base.rstrip(os.sep) + os.sep

    for root, dirs, files in os.walk(extract_path):
        # Directories are checked before files, each with its own log message.
        checks = (
            ("Extracted directory escaped extraction path: %s", dirs),
            ("Extracted file escaped extraction path: %s", files),
        )
        for log_format, names in checks:
            for name in names:
                entry_path = os.path.join(root, name)
                resolved = _get_resolved_path(entry_path)
                if resolved != base and not resolved.startswith(base_prefix):
                    logger.error(log_format, entry_path)
                    raise ValueError(
                        f"Extracted path outside expected directory: {entry_path}"
                    )


def custom_extractall_tarfile(tar, extract_path):
"""Extract a tarfile, optionally using data_filter if available.

Expand All @@ -1640,6 +1752,8 @@ def custom_extractall_tarfile(tar, extract_path):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
# Re-validate extracted paths to catch symlink race conditions
_validate_extracted_paths(extract_path)


def can_model_package_source_uri_autopopulate(source_uri: str):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
@patch("sagemaker.local.utils.os.path")
@patch("sagemaker.local.utils.os")
def test_copy_directory_structure(m_os, m_os_path):
    """copy_directory_structure creates the joined path with one makedirs call."""
    # The function under test joins its two arguments and hands the result to
    # makedirs, so only the join result needs to be stubbed.
    m_os_path.join.return_value = "/tmp/code/"
    sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
    # Only the current call shape is asserted. The earlier two-positional-argument
    # expectation contradicted it, since assert_called_with checks the last call.
    m_os.makedirs.assert_called_with("/tmp/code/", exist_ok=True)


@patch("shutil.rmtree", Mock())
Expand Down
157 changes: 157 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,3 +2245,160 @@ def test_get_domain_for_region(self):
self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")



class TestValidateSourceDirectory(TestCase):
    """Unit tests covering ``_validate_source_directory``."""

    @staticmethod
    def _validate(path):
        """Import the function lazily and run it against ``path``."""
        from sagemaker.utils import _validate_source_directory

        return _validate_source_directory(path)

    def test_validate_source_directory_with_s3_path(self):
        """An S3 URI passes validation without raising."""
        self._validate("s3://my-bucket/my-prefix")

    def test_validate_source_directory_with_none(self):
        """``None`` passes validation without raising."""
        self._validate(None)

    def test_validate_source_directory_with_safe_local_path(self):
        """Ordinary absolute and relative local paths pass validation."""
        for safe_path in ("/tmp/my_code", "./my_code", "../my_code"):
            self._validate(safe_path)

    def test_validate_source_directory_with_sensitive_path_aws(self):
        """A path below ~/.aws is rejected with ValueError."""
        target = os.path.expanduser("~/.aws/credentials")
        with self.assertRaisesRegex(ValueError, "cannot access sensitive system paths"):
            self._validate(target)

    def test_validate_source_directory_with_sensitive_path_ssh(self):
        """A path below ~/.ssh is rejected with ValueError."""
        target = os.path.expanduser("~/.ssh/id_rsa")
        with self.assertRaisesRegex(ValueError, "cannot access sensitive system paths"):
            self._validate(target)

    def test_validate_source_directory_with_root_directory(self):
        """The root directory itself is explicitly allowed, so no error is raised."""
        self._validate("/")


class TestValidateDependencyPath(TestCase):
    """Unit tests covering ``_validate_dependency_path``."""

    @staticmethod
    def _validate(path):
        """Import the function lazily and run it against ``path``."""
        from sagemaker.utils import _validate_dependency_path

        return _validate_dependency_path(path)

    def test_validate_dependency_path_with_none(self):
        """``None`` passes validation without raising."""
        self._validate(None)

    def test_validate_dependency_path_with_safe_local_path(self):
        """Ordinary absolute and relative local paths pass validation."""
        for safe_path in ("/tmp/my_lib", "./my_lib", "../my_lib"):
            self._validate(safe_path)

    def test_validate_dependency_path_with_sensitive_path_aws(self):
        """The ~/.aws directory itself is rejected with ValueError."""
        target = os.path.expanduser("~/.aws")
        with self.assertRaisesRegex(ValueError, "cannot access sensitive system paths"):
            self._validate(target)

    def test_validate_dependency_path_with_sensitive_path_docker(self):
        """A path below ~/.docker is rejected with ValueError."""
        target = os.path.expanduser("~/.docker/config.json")
        with self.assertRaisesRegex(ValueError, "cannot access sensitive system paths"):
            self._validate(target)

    def test_validate_dependency_path_with_root_directory(self):
        """The root directory itself is explicitly allowed, so no error is raised."""
        self._validate("/")


class TestCreateOrUpdateCodeDir(TestCase):
    """Unit tests covering ``_create_or_update_code_dir``."""

    @patch("sagemaker.utils._validate_source_directory")
    @patch("sagemaker.utils._validate_dependency_path")
    @patch("sagemaker.utils.os.path.exists")
    @patch("sagemaker.utils.os.mkdir")
    @patch("sagemaker.utils.shutil.copy2")
    def test_create_or_update_code_dir_with_inference_script(
        self, copy2_mock, mkdir_mock, exists_mock, dep_validator_mock, src_validator_mock
    ):
        """With only an inference script, the code dir is created and the script copied."""
        from sagemaker.utils import _create_or_update_code_dir

        exists_mock.return_value = False
        call_kwargs = {
            "model_dir": "/tmp/model",
            "inference_script": "inference.py",
            "source_directory": None,
            "dependencies": [],
            "sagemaker_session": None,
            "tmp": "/tmp",
        }

        # Pin the resolved code dir to a harmless location so the sensitive-path
        # guard at the top of the function lets the call through.
        with patch("sagemaker.utils._get_resolved_path", return_value="/tmp/model/code"):
            _create_or_update_code_dir(**call_kwargs)

        mkdir_mock.assert_called()
        copy2_mock.assert_called_once()

    @patch("sagemaker.utils._validate_source_directory")
    @patch("sagemaker.utils.os.path.exists")
    @patch("sagemaker.utils.shutil.rmtree")
    @patch("sagemaker.utils.shutil.copytree")
    def test_create_or_update_code_dir_with_source_directory(
        self, copytree_mock, rmtree_mock, exists_mock, src_validator_mock
    ):
        """With a local source directory, it is validated and copied over the code dir."""
        from sagemaker.utils import _create_or_update_code_dir

        exists_mock.return_value = True
        call_kwargs = {
            "model_dir": "/tmp/model",
            "inference_script": None,
            "source_directory": "/tmp/my_code",
            "dependencies": [],
            "sagemaker_session": None,
            "tmp": "/tmp",
        }

        with patch("sagemaker.utils._get_resolved_path", return_value="/tmp/model/code"):
            _create_or_update_code_dir(**call_kwargs)

        src_validator_mock.assert_called_once_with("/tmp/my_code")
        rmtree_mock.assert_called_once()
        copytree_mock.assert_called_once()

    def test_create_or_update_code_dir_with_sensitive_code_dir(self):
        """A code dir that resolves to a sensitive path is rejected with ValueError."""
        from sagemaker.utils import _create_or_update_code_dir

        # Simulate code_dir resolving to a sensitive path
        sensitive_target = os.path.abspath(os.path.expanduser("~/.aws"))

        with patch("sagemaker.utils._get_resolved_path", return_value=sensitive_target):
            with self.assertRaisesRegex(ValueError, "Invalid code_dir path"):
                _create_or_update_code_dir(
                    model_dir="/tmp/model",
                    inference_script="inference.py",
                    source_directory=None,
                    dependencies=[],
                    sagemaker_session=None,
                    tmp="/tmp",
                )
Loading