diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py index 5125dfdbbdae10..4ba4a3938fbad0 100644 --- a/python/paddle/utils/download.py +++ b/python/paddle/utils/download.py @@ -360,7 +360,7 @@ def _safe_extract_zip(zip, path, members=None): for member in members_to_check: if not _safe_extract_member(member, path, 'zip'): raise ValueError( - f"Attempted path traversal in tar file: {member.name}" + f"Attempted path traversal in zip file: {member.filename}" ) zip.extractall(path, members=members_to_check) diff --git a/test/legacy_test/test_download.py b/test/legacy_test/test_download.py index da25a3021a31e0..e309427b130a42 100644 --- a/test/legacy_test/test_download.py +++ b/test/legacy_test/test_download.py @@ -13,9 +13,15 @@ # limitations under the License. import os +import tempfile import unittest +import zipfile -from paddle.utils.download import get_path_from_url, get_weights_path_from_url +from paddle.utils.download import ( + _safe_extract_zip, + get_path_from_url, + get_weights_path_from_url, +) class TestDownload(unittest.TestCase): @@ -141,5 +147,24 @@ def test_download_methods( ) +class TestSafeExtractZip(unittest.TestCase): + def test_path_traversal_error_message(self): + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = os.path.join(tmpdir, "unsafe.zip") + extract_dir = os.path.join(tmpdir, "extract") + + with zipfile.ZipFile(zip_path, "w") as archive: + archive.writestr("../evil.txt", "unsafe") + + with zipfile.ZipFile(zip_path) as archive: + with self.assertRaises(ValueError) as cm: + _safe_extract_zip(archive, extract_dir) + + self.assertEqual( + str(cm.exception), + "Attempted path traversal in zip file: ../evil.txt", + ) + + if __name__ == '__main__': unittest.main()