@@ -647,7 +647,10 @@ def _validate_source_directory(source_directory):
647647
648648 # Check if the source path is under any sensitive directory
649649 for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
650- if abs_source != "/" and os .path .commonpath ([abs_source , sensitive_path ]) == sensitive_path :
650+ if abs_source != "/" and (
651+ os .path .commonpath ([abs_source , sensitive_path ])
652+ == sensitive_path
653+ ):
651654 raise ValueError (
652655 f"source_directory cannot access sensitive system paths. "
653656 f"Got: { source_directory } (resolved to { abs_source } )"
@@ -673,7 +676,10 @@ def _validate_dependency_path(dependency):
673676
674677 # Check if the dependency path is under any sensitive directory
675678 for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
676- if abs_dependency != "/" and os .path .commonpath ([abs_dependency , sensitive_path ]) == sensitive_path :
679+ if abs_dependency != "/" and (
680+ os .path .commonpath ([abs_dependency , sensitive_path ])
681+ == sensitive_path
682+ ):
677683 raise ValueError (
678684 f"dependency path cannot access sensitive system paths. "
679685 f"Got: { dependency } (resolved to { abs_dependency } )"
@@ -686,10 +692,13 @@ def _create_or_update_code_dir(
686692 """Placeholder docstring"""
687693 code_dir = os .path .join (model_dir , "code" )
688694 resolved_code_dir = _get_resolved_path (code_dir )
689-
695+
690696 # Validate that code_dir does not resolve to a sensitive system path
691697 for sensitive_path in _SENSITIVE_SYSTEM_PATHS :
692- if resolved_code_dir != "/" and os .path .commonpath ([resolved_code_dir , sensitive_path ]) == sensitive_path :
698+ if resolved_code_dir != "/" and (
699+ os .path .commonpath ([resolved_code_dir , sensitive_path ])
700+ == sensitive_path
701+ ):
693702 raise ValueError (
694703 f"Invalid code_dir path: { code_dir } resolves to sensitive system path { resolved_code_dir } "
695704 )
@@ -1688,7 +1697,8 @@ def _is_bad_path(path, base):
16881697 bool: True if the path is not rooted under the base directory, False otherwise.
16891698 """
16901699 # joinpath will ignore base if path is absolute
1691- return not _get_resolved_path (joinpath (base , path )).startswith (base )
1700+ resolved = _get_resolved_path (joinpath (base , path ))
1701+ return os .path .commonpath ([resolved , base ]) != base
16921702
16931703
16941704def _is_bad_link (info , base ):
@@ -1708,19 +1718,18 @@ def _is_bad_link(info, base):
17081718 return _is_bad_path (info .linkname , base = tip )
17091719
17101720
1711- def _get_safe_members (members ):
1721+ def _get_safe_members (members , base ):
17121722 """A generator that yields members that are safe to extract.
17131723
17141724 It filters out bad paths and bad links.
17151725
17161726 Args:
1717- members (list): A list of members to check.
1727+ members (list): A list of TarInfo members to check.
1728+ base (str): The resolved base directory for extraction.
17181729
17191730 Yields:
17201731 tarfile.TarInfo: The tar file info.
17211732 """
1722- base = _get_resolved_path ("" )
1723-
17241733 for file_info in members :
17251734 if _is_bad_path (file_info .name , base ):
17261735 logger .error ("%s is blocked (illegal path)" , file_info .name )
@@ -1783,7 +1792,11 @@ def custom_extractall_tarfile(tar, extract_path):
17831792 if hasattr (tarfile , "data_filter" ):
17841793 tar .extractall (path = extract_path , filter = "data" )
17851794 else :
1786- tar .extractall (path = extract_path , members = _get_safe_members (tar ))
1795+ base = _get_resolved_path (extract_path )
1796+ tar .extractall (
1797+ path = extract_path ,
1798+ members = _get_safe_members (tar .getmembers (), base ),
1799+ )
17871800 # Re-validate extracted paths to catch symlink race conditions
17881801 _validate_extracted_paths (extract_path )
17891802
0 commit comments