Skip to content

Commit 6c6b0a6

Browse files
committed
fix: Enhancements Needed for Secure Tar Extraction (5560)
1 parent 272fdbf commit 6c6b0a6

3 files changed

Lines changed: 254 additions & 6 deletions

File tree

sagemaker-core/src/sagemaker/core/common_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,8 @@ def _is_bad_path(path, base):
16881688
bool: True if the path is not rooted under the base directory, False otherwise.
16891689
"""
16901690
# joinpath will ignore base if path is absolute
1691-
return not _get_resolved_path(joinpath(base, path)).startswith(base)
1691+
resolved = _get_resolved_path(joinpath(base, path))
1692+
return os.path.commonpath([resolved, base]) != base
16921693

16931694

16941695
def _is_bad_link(info, base):
@@ -1708,19 +1709,18 @@ def _is_bad_link(info, base):
17081709
return _is_bad_path(info.linkname, base=tip)
17091710

17101711

1711-
def _get_safe_members(members):
1712+
def _get_safe_members(members, base):
17121713
"""A generator that yields members that are safe to extract.
17131714
17141715
It filters out bad paths and bad links.
17151716
17161717
Args:
1717-
members (list): A list of members to check.
1718+
members (list): A list of TarInfo members to check.
1719+
base (str): The base directory for extraction.
17181720
17191721
Yields:
17201722
tarfile.TarInfo: The tar file info.
17211723
"""
1722-
base = _get_resolved_path("")
1723-
17241724
for file_info in members:
17251725
if _is_bad_path(file_info.name, base):
17261726
logger.error("%s is blocked (illegal path)", file_info.name)
@@ -1783,7 +1783,8 @@ def custom_extractall_tarfile(tar, extract_path):
17831783
if hasattr(tarfile, "data_filter"):
17841784
tar.extractall(path=extract_path, filter="data")
17851785
else:
1786-
tar.extractall(path=extract_path, members=_get_safe_members(tar))
1786+
base = _get_resolved_path(extract_path)
1787+
tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base))
17871788
# Re-validate extracted paths to catch symlink race conditions
17881789
_validate_extracted_paths(extract_path)
17891790

sagemaker-core/src/sagemaker/core/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
"sagemaker_timestamp",
3737
"sagemaker_short_timestamp",
3838
"get_config_value",
39+
"_get_resolved_path",
40+
"_is_bad_path",
41+
"_is_bad_link",
42+
"_get_safe_members",
3943
]
4044

4145

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for tar extraction safety functions in common_utils."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
import pytest
18+
import tarfile
19+
import tempfile
20+
from unittest.mock import Mock, patch, MagicMock
21+
22+
from sagemaker.core.common_utils import (
23+
_get_resolved_path,
24+
_is_bad_path,
25+
_is_bad_link,
26+
_get_safe_members,
27+
custom_extractall_tarfile,
28+
)
29+
30+
31+
def test_get_resolved_path_returns_normalized_absolute_path():
32+
"""Test _get_resolved_path returns normalized absolute path."""
33+
path = "./test/path"
34+
result = _get_resolved_path(path)
35+
assert os.path.isabs(result)
36+
assert result == os.path.normpath(os.path.realpath(os.path.abspath(path)))
37+
38+
39+
def test_get_resolved_path_with_absolute_path():
40+
"""Test _get_resolved_path with absolute path."""
41+
path = "/absolute/test/path"
42+
result = _get_resolved_path(path)
43+
assert result == os.path.normpath(os.path.realpath(os.path.abspath(path)))
44+
45+
46+
def test_is_bad_path_returns_false_for_safe_relative_path():
47+
"""Test _is_bad_path returns False for safe relative paths."""
48+
base = _get_resolved_path("/tmp/extract")
49+
safe_path = "safe/path/file.txt"
50+
assert _is_bad_path(safe_path, base) is False
51+
52+
53+
def test_is_bad_path_returns_true_for_absolute_escape_path():
54+
"""Test _is_bad_path returns True for absolute paths that escape base."""
55+
base = _get_resolved_path("/tmp/safe")
56+
unsafe_path = "/etc/passwd"
57+
assert _is_bad_path(unsafe_path, base) is True
58+
59+
60+
def test_is_bad_path_returns_true_for_parent_traversal():
61+
"""Test _is_bad_path detects parent directory traversal."""
62+
base = _get_resolved_path("/tmp/safe/extract")
63+
traversal_path = "../../../etc/passwd"
64+
assert _is_bad_path(traversal_path, base) is True
65+
66+
67+
def test_is_bad_path_with_similar_prefix_does_not_false_positive():
68+
"""Test that /tmp/x2 is correctly identified as bad when base is /tmp/x.
69+
70+
This verifies the commonpath fix: startswith would incorrectly allow
71+
/tmp/x2 when base is /tmp/x, but commonpath correctly rejects it.
72+
"""
73+
base = _get_resolved_path("/tmp/x")
74+
# A path like /tmp/x2/file should NOT be under /tmp/x
75+
# With startswith, "/tmp/x2/file".startswith("/tmp/x") would be True (bug)
76+
# With commonpath, commonpath(["/tmp/x2/file", "/tmp/x"]) == "/tmp" != "/tmp/x" (correct)
77+
escape_path = "/tmp/x2/file"
78+
result = _is_bad_path(escape_path, base)
79+
assert result is True
80+
81+
82+
def test_is_bad_link_returns_false_for_safe_symlink():
83+
"""Test _is_bad_link returns False for safe links."""
84+
base = _get_resolved_path("/tmp/extract")
85+
86+
mock_info = Mock()
87+
mock_info.name = "safe/link"
88+
mock_info.linkname = "safe/target"
89+
90+
assert _is_bad_link(mock_info, base) is False
91+
92+
93+
def test_is_bad_link_returns_true_for_escape_symlink():
94+
"""Test _is_bad_link returns True for links that escape base."""
95+
base = _get_resolved_path("/tmp/safe")
96+
97+
mock_info = Mock()
98+
mock_info.name = "link"
99+
mock_info.linkname = "/etc/passwd"
100+
101+
result = _is_bad_link(mock_info, base)
102+
assert result is True
103+
104+
105+
def test_get_safe_members_yields_all_safe_members():
106+
"""Test _get_safe_members yields all safe members."""
107+
base = _get_resolved_path("/tmp/extract")
108+
109+
mock_member1 = Mock()
110+
mock_member1.name = "safe/file1.txt"
111+
mock_member1.issym = Mock(return_value=False)
112+
mock_member1.islnk = Mock(return_value=False)
113+
114+
mock_member2 = Mock()
115+
mock_member2.name = "safe/file2.txt"
116+
mock_member2.issym = Mock(return_value=False)
117+
mock_member2.islnk = Mock(return_value=False)
118+
119+
members = [mock_member1, mock_member2]
120+
safe_members = list(_get_safe_members(members, base))
121+
122+
assert len(safe_members) == 2
123+
assert mock_member1 in safe_members
124+
assert mock_member2 in safe_members
125+
126+
127+
def test_get_safe_members_filters_bad_path_member():
128+
"""Test _get_safe_members filters out members with bad paths."""
129+
base = _get_resolved_path("/tmp/extract")
130+
131+
mock_member_safe = Mock()
132+
mock_member_safe.name = "safe/file.txt"
133+
mock_member_safe.issym = Mock(return_value=False)
134+
mock_member_safe.islnk = Mock(return_value=False)
135+
136+
mock_member_bad = Mock()
137+
mock_member_bad.name = "/etc/passwd"
138+
mock_member_bad.issym = Mock(return_value=False)
139+
mock_member_bad.islnk = Mock(return_value=False)
140+
141+
members = [mock_member_safe, mock_member_bad]
142+
safe_members = list(_get_safe_members(members, base))
143+
144+
assert len(safe_members) == 1
145+
assert mock_member_safe in safe_members
146+
147+
148+
def test_get_safe_members_filters_bad_symlink_member():
149+
"""Test _get_safe_members filters out bad symlinks."""
150+
base = _get_resolved_path("/tmp/extract")
151+
152+
mock_member_safe = Mock()
153+
mock_member_safe.name = "safe/file.txt"
154+
mock_member_safe.issym = Mock(return_value=False)
155+
mock_member_safe.islnk = Mock(return_value=False)
156+
157+
mock_member_symlink = Mock()
158+
mock_member_symlink.name = "bad/symlink"
159+
mock_member_symlink.issym = Mock(return_value=True)
160+
mock_member_symlink.islnk = Mock(return_value=False)
161+
mock_member_symlink.linkname = "/etc/passwd"
162+
163+
members = [mock_member_safe, mock_member_symlink]
164+
safe_members = list(_get_safe_members(members, base))
165+
166+
assert len(safe_members) == 1
167+
assert mock_member_safe in safe_members
168+
169+
170+
def test_get_safe_members_filters_bad_hardlink_member():
171+
"""Test _get_safe_members filters out bad hardlinks."""
172+
base = _get_resolved_path("/tmp/extract")
173+
174+
mock_member_safe = Mock()
175+
mock_member_safe.name = "safe/file.txt"
176+
mock_member_safe.issym = Mock(return_value=False)
177+
mock_member_safe.islnk = Mock(return_value=False)
178+
179+
mock_member_hardlink = Mock()
180+
mock_member_hardlink.name = "bad/hardlink"
181+
mock_member_hardlink.issym = Mock(return_value=False)
182+
mock_member_hardlink.islnk = Mock(return_value=True)
183+
mock_member_hardlink.linkname = "/etc/passwd"
184+
185+
members = [mock_member_safe, mock_member_hardlink]
186+
safe_members = list(_get_safe_members(members, base))
187+
188+
assert len(safe_members) == 1
189+
assert mock_member_safe in safe_members
190+
191+
192+
def test_custom_extractall_tarfile_with_data_filter_uses_filter_param():
193+
"""Test custom_extractall_tarfile uses data_filter when available."""
194+
mock_tar = Mock()
195+
mock_tar.extractall = Mock()
196+
extract_path = "/tmp/extract"
197+
198+
with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile:
199+
mock_tarfile.data_filter = "data"
200+
201+
custom_extractall_tarfile(mock_tar, extract_path)
202+
203+
mock_tar.extractall.assert_called_once_with(path=extract_path, filter="data")
204+
205+
206+
def test_custom_extractall_tarfile_without_data_filter_uses_safe_members():
207+
"""Test custom_extractall_tarfile uses safe members with getmembers() and resolved extract_path."""
208+
mock_member = Mock()
209+
mock_member.name = "safe/file.txt"
210+
mock_member.issym = Mock(return_value=False)
211+
mock_member.islnk = Mock(return_value=False)
212+
213+
mock_tar = Mock()
214+
mock_tar.extractall = Mock()
215+
mock_tar.getmembers = Mock(return_value=[mock_member])
216+
extract_path = "/tmp/extract"
217+
218+
with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile:
219+
# Remove data_filter to simulate an interpreter without tarfile extraction filters (older than 3.12 and lacking the PEP 706 backport)
220+
if hasattr(mock_tarfile, 'data_filter'):
221+
delattr(mock_tarfile, 'data_filter')
222+
223+
with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe:
224+
mock_safe.return_value = [mock_member]
225+
226+
with patch('sagemaker.core.common_utils._validate_extracted_paths'):
227+
custom_extractall_tarfile(mock_tar, extract_path)
228+
229+
# Verify getmembers() was called (not iterating over tar directly)
230+
mock_tar.getmembers.assert_called_once()
231+
232+
# Verify _get_safe_members was called with the members list and resolved base
233+
mock_safe.assert_called_once()
234+
call_args = mock_safe.call_args
235+
assert call_args[0][0] == [mock_member] # members list
236+
# base should be resolved extract_path, not cwd
237+
expected_base = _get_resolved_path(extract_path)
238+
assert call_args[0][1] == expected_base
239+
240+
mock_tar.extractall.assert_called_once()
241+
call_kwargs = mock_tar.extractall.call_args[1]
242+
assert call_kwargs['path'] == extract_path
243+
assert 'members' in call_kwargs

0 commit comments

Comments
 (0)