@@ -57,7 +57,8 @@ def _is_bad_path(path, base):
5757 bool: True if the path is not rooted under the base directory, False otherwise.
5858 """
5959 # joinpath will ignore base if path is absolute
60- return not _get_resolved_path (joinpath (base , path )).startswith (base )
60+ resolved = _get_resolved_path (joinpath (base , path ))
61+ return os .path .commonpath ([resolved , base ]) != base
6162
6263
6364def _is_bad_link (info , base ):
@@ -77,19 +78,18 @@ def _is_bad_link(info, base):
7778 return _is_bad_path (info .linkname , base = tip )
7879
7980
80- def _get_safe_members (members ):
81+ def _get_safe_members (members , base ):
8182 """A generator that yields members that are safe to extract.
8283
8384 It filters out bad paths and bad links.
8485
8586 Args:
8687 members (list): A list of members to check.
88+ base (str): The base directory for extraction.
8789
8890 Yields:
8991 tarfile.TarInfo: The tar file info.
9092 """
91- base = _get_resolved_path ("" )
92-
9393 for file_info in members :
9494 if _is_bad_path (file_info .name , base ):
9595 logger .error ("%s is blocked (illegal path)" , file_info .name )
@@ -120,7 +120,8 @@ def custom_extractall_tarfile(tar, extract_path):
120120 if hasattr (tarfile , "data_filter" ):
121121 tar .extractall (path = extract_path , filter = "data" )
122122 else :
123- tar .extractall (path = extract_path , members = _get_safe_members (tar ))
123+ base = _get_resolved_path (extract_path )
124+ tar .extractall (path = extract_path , members = _get_safe_members (tar .getmembers (), base ))
124125
125126
126127def repack (inference_script , model_archive , source_dir = None ): # pragma: no cover
0 commit comments