Skip to content

Commit 0f4ccb2

Browse files
committed
fix: Enhancements Needed for Secure Tar Extraction (5560)
1 parent ee420cc commit 0f4ccb2

File tree

2 files changed

+233
-3
lines changed

2 files changed

+233
-3
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def _validate_source_directory(source_directory):
647647

648648
# Check if the source path is under any sensitive directory
649649
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
650-
if abs_source != "/" and abs_source.startswith(sensitive_path):
650+
if abs_source != "/" and os.path.commonpath([abs_source, sensitive_path]) == sensitive_path:
651651
raise ValueError(
652652
f"source_directory cannot access sensitive system paths. "
653653
f"Got: {source_directory} (resolved to {abs_source})"
@@ -673,7 +673,7 @@ def _validate_dependency_path(dependency):
673673

674674
# Check if the dependency path is under any sensitive directory
675675
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
676-
if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
676+
if abs_dependency != "/" and os.path.commonpath([abs_dependency, sensitive_path]) == sensitive_path:
677677
raise ValueError(
678678
f"dependency path cannot access sensitive system paths. "
679679
f"Got: {dependency} (resolved to {abs_dependency})"
@@ -689,7 +689,7 @@ def _create_or_update_code_dir(
689689

690690
# Validate that code_dir does not resolve to a sensitive system path
691691
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
692-
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
692+
if resolved_code_dir != "/" and os.path.commonpath([resolved_code_dir, sensitive_path]) == sensitive_path:
693693
raise ValueError(
694694
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
695695
)

sagemaker-core/tests/unit/test_common_utils.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,236 @@ def test_custom_extractall_tarfile_basic(self, tmp_path):
11391139

11401140
assert (extract_path / "file.txt").exists()
11411141

1142+
def test_custom_extractall_tarfile_without_data_filter(self, tmp_path):
1143+
"""Test custom_extractall_tarfile uses safe members with extract_path as base when data_filter unavailable."""
1144+
from sagemaker.core.common_utils import custom_extractall_tarfile
1145+
1146+
# Create tar file
1147+
source = tmp_path / "source"
1148+
source.mkdir()
1149+
(source / "file.txt").write_text("content")
1150+
1151+
tar_path = tmp_path / "test.tar.gz"
1152+
with tarfile.open(tar_path, "w:gz") as tar:
1153+
tar.add(source / "file.txt", arcname="file.txt")
1154+
1155+
extract_path = tmp_path / "extract"
1156+
extract_path.mkdir()
1157+
1158+
with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile_module:
1159+
# Remove data_filter to force fallback path
1160+
if hasattr(mock_tarfile_module, 'data_filter'):
1161+
delattr(mock_tarfile_module, 'data_filter')
1162+
1163+
with tarfile.open(tar_path, "r:gz") as tar:
1164+
with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe:
1165+
mock_safe.return_value = tar.getmembers()
1166+
custom_extractall_tarfile(tar, str(extract_path))
1167+
# Verify _get_safe_members was called with members list and base path
1168+
mock_safe.assert_called_once()
1169+
call_args = mock_safe.call_args
1170+
# First arg should be a list of TarInfo members
1171+
assert isinstance(call_args[0][0], list)
1172+
# Second arg should be the resolved extract path (not CWD)
1173+
from sagemaker.core.common_utils import _get_resolved_path
1174+
expected_base = _get_resolved_path(str(extract_path))
1175+
assert call_args[0][1] == expected_base
1176+
1177+
1178+
class TestIsBadPath:
1179+
"""Test _is_bad_path function for secure tar extraction."""
1180+
1181+
def test_is_bad_path_safe_relative(self):
1182+
"""Test _is_bad_path returns False for safe relative paths."""
1183+
from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path
1184+
1185+
base = _get_resolved_path("/tmp/safe")
1186+
assert _is_bad_path("safe/file.txt", base) is False
1187+
1188+
def test_is_bad_path_actual_escape(self):
1189+
"""Test _is_bad_path returns True for paths escaping base."""
1190+
from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path
1191+
1192+
base = _get_resolved_path("/tmp/safe")
1193+
assert _is_bad_path("/etc/passwd", base) is True
1194+
1195+
def test_is_bad_path_traversal(self):
1196+
"""Test _is_bad_path detects parent directory traversal."""
1197+
from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path
1198+
1199+
base = _get_resolved_path("/tmp/safe")
1200+
assert _is_bad_path("../../etc/passwd", base) is True
1201+
1202+
def test_is_bad_path_prefix_collision(self):
1203+
"""Test _is_bad_path does NOT flag /tmp/safe2 when base is /tmp/safe.
1204+
1205+
This is the key test for the startswith() bug fix - /tmp/safe2 starts with
1206+
/tmp/safe but is NOT actually under /tmp/safe.
1207+
"""
1208+
from sagemaker.core.common_utils import _is_bad_path, _get_resolved_path
1209+
1210+
base = _get_resolved_path("/tmp/safe")
1211+
# /tmp/safe2 is NOT under /tmp/safe, but startswith would incorrectly say it is
1212+
# With the fix using commonpath, this should correctly identify it as outside base
1213+
assert _is_bad_path("/tmp/safe2/file.txt", base) is True
1214+
1215+
1216+
class TestGetSafeMembers:
1217+
"""Test _get_safe_members function for secure tar extraction."""
1218+
1219+
def test_get_safe_members_accepts_member_list_and_base(self):
1220+
"""Test _get_safe_members works with a list of TarInfo mocks and a base path."""
1221+
from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path
1222+
1223+
base = _get_resolved_path("/tmp/extract")
1224+
1225+
mock_member = Mock()
1226+
mock_member.name = "safe/file.txt"
1227+
mock_member.issym = Mock(return_value=False)
1228+
mock_member.islnk = Mock(return_value=False)
1229+
1230+
members = [mock_member]
1231+
safe = list(_get_safe_members(members, base))
1232+
assert len(safe) == 1
1233+
assert mock_member in safe
1234+
1235+
def test_get_safe_members_filters_bad_paths(self):
1236+
"""Test _get_safe_members filters out members with bad paths."""
1237+
from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path
1238+
1239+
base = _get_resolved_path("/tmp/extract")
1240+
1241+
mock_safe = Mock()
1242+
mock_safe.name = "safe/file.txt"
1243+
mock_safe.issym = Mock(return_value=False)
1244+
mock_safe.islnk = Mock(return_value=False)
1245+
1246+
mock_bad = Mock()
1247+
mock_bad.name = "/etc/passwd"
1248+
mock_bad.issym = Mock(return_value=False)
1249+
mock_bad.islnk = Mock(return_value=False)
1250+
1251+
with patch('sagemaker.core.common_utils._is_bad_path') as mock_is_bad:
1252+
mock_is_bad.side_effect = lambda name, base: name == "/etc/passwd"
1253+
safe = list(_get_safe_members([mock_safe, mock_bad], base))
1254+
assert len(safe) == 1
1255+
assert mock_safe in safe
1256+
1257+
def test_get_safe_members_filters_bad_symlinks(self):
1258+
"""Test _get_safe_members filters out bad symlinks."""
1259+
from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path
1260+
1261+
base = _get_resolved_path("/tmp/extract")
1262+
1263+
mock_safe = Mock()
1264+
mock_safe.name = "safe/file.txt"
1265+
mock_safe.issym = Mock(return_value=False)
1266+
mock_safe.islnk = Mock(return_value=False)
1267+
1268+
mock_symlink = Mock()
1269+
mock_symlink.name = "bad/symlink"
1270+
mock_symlink.issym = Mock(return_value=True)
1271+
mock_symlink.islnk = Mock(return_value=False)
1272+
mock_symlink.linkname = "/etc/passwd"
1273+
1274+
with patch('sagemaker.core.common_utils._is_bad_path', return_value=False):
1275+
with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link:
1276+
mock_is_bad_link.return_value = True
1277+
safe = list(_get_safe_members([mock_safe, mock_symlink], base))
1278+
assert len(safe) == 1
1279+
assert mock_safe in safe
1280+
1281+
def test_get_safe_members_filters_bad_hardlinks(self):
1282+
"""Test _get_safe_members filters out bad hardlinks."""
1283+
from sagemaker.core.common_utils import _get_safe_members, _get_resolved_path
1284+
1285+
base = _get_resolved_path("/tmp/extract")
1286+
1287+
mock_safe = Mock()
1288+
mock_safe.name = "safe/file.txt"
1289+
mock_safe.issym = Mock(return_value=False)
1290+
mock_safe.islnk = Mock(return_value=False)
1291+
1292+
mock_hardlink = Mock()
1293+
mock_hardlink.name = "bad/hardlink"
1294+
mock_hardlink.issym = Mock(return_value=False)
1295+
mock_hardlink.islnk = Mock(return_value=True)
1296+
mock_hardlink.linkname = "/etc/passwd"
1297+
1298+
with patch('sagemaker.core.common_utils._is_bad_path', return_value=False):
1299+
with patch('sagemaker.core.common_utils._is_bad_link') as mock_is_bad_link:
1300+
mock_is_bad_link.return_value = True
1301+
safe = list(_get_safe_members([mock_safe, mock_hardlink], base))
1302+
assert len(safe) == 1
1303+
assert mock_safe in safe
1304+
1305+
1306+
class TestValidateSourceDirectorySecurity:
1307+
"""Test _validate_source_directory prefix collision fix."""
1308+
1309+
def test_validate_source_directory_blocks_sensitive_path(self):
1310+
"""Test that actual sensitive paths are blocked."""
1311+
from sagemaker.core.common_utils import _validate_source_directory
1312+
1313+
with pytest.raises(ValueError, match="sensitive system paths"):
1314+
_validate_source_directory("/etc/secrets")
1315+
1316+
def test_validate_source_directory_prefix_collision(self):
1317+
"""Test that /etcetera is NOT blocked when /etc is in sensitive paths.
1318+
1319+
This tests the fix for the startswith() prefix collision vulnerability.
1320+
"""
1321+
from sagemaker.core.common_utils import _validate_source_directory
1322+
1323+
# /etcetera should NOT be blocked - it's not under /etc
1324+
# With the old startswith() check, this would incorrectly raise ValueError
1325+
try:
1326+
_validate_source_directory("/etcetera")
1327+
except ValueError:
1328+
pytest.fail("_validate_source_directory incorrectly blocked /etcetera due to prefix collision with /etc")
1329+
1330+
def test_validate_source_directory_s3_path(self):
1331+
"""Test that S3 paths are allowed."""
1332+
from sagemaker.core.common_utils import _validate_source_directory
1333+
1334+
_validate_source_directory("s3://my-bucket/my-prefix")
1335+
1336+
def test_validate_source_directory_none(self):
1337+
"""Test that None is allowed."""
1338+
from sagemaker.core.common_utils import _validate_source_directory
1339+
1340+
_validate_source_directory(None)
1341+
1342+
1343+
class TestValidateDependencyPathSecurity:
1344+
"""Test _validate_dependency_path prefix collision fix."""
1345+
1346+
def test_validate_dependency_path_blocks_sensitive_path(self):
1347+
"""Test that actual sensitive paths are blocked."""
1348+
from sagemaker.core.common_utils import _validate_dependency_path
1349+
1350+
with pytest.raises(ValueError, match="sensitive system paths"):
1351+
_validate_dependency_path("/root/.bashrc")
1352+
1353+
def test_validate_dependency_path_prefix_collision(self):
1354+
"""Test that /rootkit is NOT blocked when /root is in sensitive paths.
1355+
1356+
This tests the fix for the startswith() prefix collision vulnerability.
1357+
"""
1358+
from sagemaker.core.common_utils import _validate_dependency_path
1359+
1360+
# /rootkit should NOT be blocked - it's not under /root
1361+
try:
1362+
_validate_dependency_path("/rootkit")
1363+
except ValueError:
1364+
pytest.fail("_validate_dependency_path incorrectly blocked /rootkit due to prefix collision with /root")
1365+
1366+
def test_validate_dependency_path_none(self):
1367+
"""Test that None is allowed."""
1368+
from sagemaker.core.common_utils import _validate_dependency_path
1369+
1370+
_validate_dependency_path(None)
1371+
11421372

11431373
class TestCanModelPackageSourceUriAutopopulate:
11441374
"""Test can_model_package_source_uri_autopopulate function."""

0 commit comments

Comments
 (0)