Skip to content

Commit e13e817

Browse files
committed
fix: address review comments (iteration #1)
1 parent 0f4ccb2 commit e13e817

File tree

2 files changed

+145
-138
lines changed

2 files changed

+145
-138
lines changed

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -647,7 +647,10 @@ def _validate_source_directory(source_directory):
647647

648648
# Check if the source path is under any sensitive directory
649649
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
650-
if abs_source != "/" and os.path.commonpath([abs_source, sensitive_path]) == sensitive_path:
650+
if abs_source != "/" and (
651+
os.path.commonpath([abs_source, sensitive_path])
652+
== sensitive_path
653+
):
651654
raise ValueError(
652655
f"source_directory cannot access sensitive system paths. "
653656
f"Got: {source_directory} (resolved to {abs_source})"
@@ -673,7 +676,10 @@ def _validate_dependency_path(dependency):
673676

674677
# Check if the dependency path is under any sensitive directory
675678
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
676-
if abs_dependency != "/" and os.path.commonpath([abs_dependency, sensitive_path]) == sensitive_path:
679+
if abs_dependency != "/" and (
680+
os.path.commonpath([abs_dependency, sensitive_path])
681+
== sensitive_path
682+
):
677683
raise ValueError(
678684
f"dependency path cannot access sensitive system paths. "
679685
f"Got: {dependency} (resolved to {abs_dependency})"
@@ -686,10 +692,13 @@ def _create_or_update_code_dir(
686692
"""Placeholder docstring"""
687693
code_dir = os.path.join(model_dir, "code")
688694
resolved_code_dir = _get_resolved_path(code_dir)
689-
695+
690696
# Validate that code_dir does not resolve to a sensitive system path
691697
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
692-
if resolved_code_dir != "/" and os.path.commonpath([resolved_code_dir, sensitive_path]) == sensitive_path:
698+
if resolved_code_dir != "/" and (
699+
os.path.commonpath([resolved_code_dir, sensitive_path])
700+
== sensitive_path
701+
):
693702
raise ValueError(
694703
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
695704
)
@@ -1688,7 +1697,8 @@ def _is_bad_path(path, base):
16881697
bool: True if the path is not rooted under the base directory, False otherwise.
16891698
"""
16901699
# joinpath will ignore base if path is absolute
1691-
return not _get_resolved_path(joinpath(base, path)).startswith(base)
1700+
resolved = _get_resolved_path(joinpath(base, path))
1701+
return os.path.commonpath([resolved, base]) != base
16921702

16931703

16941704
def _is_bad_link(info, base):
@@ -1708,19 +1718,18 @@ def _is_bad_link(info, base):
17081718
return _is_bad_path(info.linkname, base=tip)
17091719

17101720

1711-
def _get_safe_members(members):
1721+
def _get_safe_members(members, base):
17121722
"""A generator that yields members that are safe to extract.
17131723
17141724
It filters out bad paths and bad links.
17151725
17161726
Args:
1717-
members (list): A list of members to check.
1727+
members (list): A list of TarInfo members to check.
1728+
base (str): The resolved base directory for extraction.
17181729
17191730
Yields:
17201731
tarfile.TarInfo: The tar file info.
17211732
"""
1722-
base = _get_resolved_path("")
1723-
17241733
for file_info in members:
17251734
if _is_bad_path(file_info.name, base):
17261735
logger.error("%s is blocked (illegal path)", file_info.name)
@@ -1783,7 +1792,11 @@ def custom_extractall_tarfile(tar, extract_path):
17831792
if hasattr(tarfile, "data_filter"):
17841793
tar.extractall(path=extract_path, filter="data")
17851794
else:
1786-
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1795+
base = _get_resolved_path(extract_path)
1796+
tar.extractall(
1797+
path=extract_path,
1798+
members=_get_safe_members(tar.getmembers(), base),
1799+
)
17871800
# Re-validate extracted paths to catch symlink race conditions
17881801
_validate_extracted_paths(extract_path)
17891802

0 commit comments

Comments
 (0)