Skip to content

Commit f4eee66

Browse files
committed
fix: address review comments (iteration #2)
1 parent b8c6e49 commit f4eee66

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

sagemaker-core/src/sagemaker/core/utils/__init__.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
"""
2020
from __future__ import absolute_import
2121

22-
# Public API surface: only non-private functions are exported via __all__.
22+
# Public API surface.
23+
# Note: _save_model is underscore-prefixed but was already in __all__ (pre-existing).
24+
# custom_extractall_tarfile is the main public entry point for safe tar extraction.
2325
# Private helpers (_get_resolved_path, _is_bad_path, _is_bad_link, _get_safe_members)
24-
# are still importable directly but are not part of the public API.
26+
# are importable directly from sagemaker.core.common_utils but are not re-exported here.
2527
__all__ = [
2628
"_save_model",
2729
"download_file_from_url",
@@ -41,18 +43,10 @@
4143
"get_config_value",
4244
]
4345

44-
# Internal helpers that are importable but not part of the public API
45-
_INTERNAL_NAMES = [
46-
"_get_resolved_path",
47-
"_is_bad_path",
48-
"_is_bad_link",
49-
"_get_safe_members",
50-
]
51-
5246

5347
def __getattr__(name):
5448
"""Lazy import to avoid circular dependencies."""
55-
if name in __all__ or name in _INTERNAL_NAMES:
49+
if name in __all__:
5650
from sagemaker.core import common_utils
5751

5852
return getattr(common_utils, name)

sagemaker-core/tests/unit/test_common_utils_tar_safety.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515

1616
import os
1717
import tempfile
18-
import tarfile
1918

20-
import pytest
21-
from unittest.mock import Mock, patch, MagicMock
19+
from unittest.mock import Mock, patch
2220

2321
from sagemaker.core.common_utils import (
2422
_get_resolved_path,
@@ -204,18 +202,20 @@ def test_get_safe_members_filters_bad_hardlink_member():
204202
def test_custom_extractall_tarfile_with_data_filter_uses_filter_param():
205203
"""Test custom_extractall_tarfile uses data_filter when available.
206204
207-
We set mock_tarfile.data_filter explicitly to ensure hasattr returns True.
208-
The MagicMock would auto-create the attribute anyway, but we set it
209-
explicitly for clarity. The key assertion is that filter="data" is passed.
205+
We patch the module-level `tarfile` import in common_utils (not the `tar` parameter).
206+
Setting mock_tarfile.data_filter explicitly ensures hasattr(tarfile, 'data_filter')
207+
returns True inside custom_extractall_tarfile. The key assertion is that
208+
filter="data" is passed to tar.extractall.
210209
"""
211210
mock_tar = Mock()
212211
mock_tar.extractall = Mock()
213212

214213
with tempfile.TemporaryDirectory() as tmpdir:
215214
extract_path = os.path.join(tmpdir, "extract")
216215

216+
# Patch the module-level tarfile import in common_utils
217217
with patch("sagemaker.core.common_utils.tarfile") as mock_tarfile:
218-
# Explicitly set data_filter to ensure the hasattr check passes
218+
# Explicitly set data_filter so hasattr check passes
219219
mock_tarfile.data_filter = True
220220

221221
custom_extractall_tarfile(mock_tar, extract_path)
@@ -226,6 +226,9 @@ def test_custom_extractall_tarfile_with_data_filter_uses_filter_param():
226226
def test_custom_extractall_tarfile_without_data_filter_uses_safe_members():
227227
"""Test custom_extractall_tarfile uses safe members when data_filter is unavailable.
228228
229+
Uses spec=['TarFile'] to restrict the mock so that
230+
hasattr(mock_tarfile, 'data_filter') returns False, forcing the fallback path.
231+
229232
Verifies that:
230233
1. tar.getmembers() is called (not iterating over tar directly)
231234
2. _get_safe_members is called with the members list and resolved extract_path as base

0 commit comments

Comments
 (0)