Skip to content

Commit f2a1b5c

Browse files
committed
fix(tests): correct mock paths for lazy imports in hub factory and transformation tests
Fixed 20 failing tests across test_dataset_hub_factory.py and test_dataset_transformation.py. Root cause: lazy imports inside methods meant mock patches on the hub_factory module couldn't find the target attributes. Patched at source modules instead (dataset.DataSet, dataset_transformation.DatasetTransformation, air_hub.AIRHub). Also fixed test_unsupported_format_raises to use DPO (not in inbound converters) instead of converse (which is), and fixed detect_format tests to mock DatasetFormatDetector at its source module. --- X-AI-Prompt: fix failing unit tests for dataset_hub_factory and dataset_transformation X-AI-Tool: Kiro
1 parent 3d8f0e8 commit f2a1b5c

3 files changed

Lines changed: 70 additions & 64 deletions

File tree

sagemaker-train/src/sagemaker/ai_registry/dataset_hub_factory.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
import logging
1818
import os
1919
import tempfile
20-
from typing import Dict, List, Optional, Tuple, Union
20+
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
2121

22-
from sagemaker.ai_registry.dataset import DataSet
2322
from sagemaker.ai_registry.dataset_transformation import DatasetFormat
2423
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
25-
from sagemaker.ai_registry.air_hub import AIRHub
2624
from sagemaker.core.helper.session_helper import Session
2725

26+
if TYPE_CHECKING:
27+
from sagemaker.ai_registry.dataset import DataSet
28+
2829
logger = logging.getLogger(__name__)
2930

3031

@@ -52,9 +53,10 @@ def _resolve_dataset(
5253
Returns:
5354
A hydrated DataSet instance.
5455
"""
55-
if isinstance(dataset, DataSet):
56-
return dataset
57-
return DataSet.get(name=dataset, sagemaker_session=sagemaker_session)
56+
if isinstance(dataset, str):
57+
from sagemaker.ai_registry.dataset import DataSet
58+
return DataSet.get(name=dataset, sagemaker_session=sagemaker_session)
59+
return dataset
5860

5961
@classmethod
6062
def _download_to_local(cls, s3_uri: str) -> str:
@@ -69,6 +71,7 @@ def _download_to_local(cls, s3_uri: str) -> str:
6971
suffix = os.path.splitext(s3_uri)[-1] or ".jsonl"
7072
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
7173
tmp.close()
74+
from sagemaker.ai_registry.air_hub import AIRHub
7275
AIRHub.download_from_s3(s3_uri, tmp.name)
7376
return tmp.name
7477

@@ -106,6 +109,7 @@ def transform_dataset(
106109
ValueError: If neither or both of source and dataset are provided.
107110
"""
108111
from sagemaker.ai_registry.dataset_transformation import DatasetTransformation
112+
from sagemaker.ai_registry.dataset import DataSet
109113

110114
if (source is None) == (dataset is None):
111115
raise ValueError(

sagemaker-train/tests/unit/ai_registry/test_dataset_hub_factory.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,27 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313

14-
import json
1514
import os
1615
import pytest
17-
from unittest.mock import Mock, patch, MagicMock
16+
from unittest.mock import Mock, patch
1817

1918
from sagemaker.ai_registry.dataset_hub_factory import DataSetHubFactory
20-
from sagemaker.ai_registry.dataset import DataSet
2119
from sagemaker.ai_registry.dataset_transformation import DatasetFormat
2220
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
2321
from sagemaker.ai_registry.air_constants import HubContentStatus
2422

2523

2624
def _make_dataset(**overrides):
27-
"""Helper to create a DataSet instance with sensible defaults."""
28-
defaults = dict(
29-
name="test-ds",
30-
arn="arn:aws:sagemaker:us-east-1:123456789012:hub-content/test-ds",
31-
version="1.0.0",
32-
status=HubContentStatus.AVAILABLE,
33-
source="s3://bucket/datasets/test-ds/data.jsonl",
34-
description="test",
35-
customization_technique=CustomizationTechnique.SFT,
36-
)
37-
defaults.update(overrides)
38-
return DataSet(**defaults)
25+
"""Helper to create a mock DataSet with sensible defaults."""
26+
ds = Mock()
27+
ds.name = overrides.get("name", "test-ds")
28+
ds.arn = overrides.get("arn", "arn:aws:sagemaker:us-east-1:123456789012:hub-content/test-ds")
29+
ds.version = overrides.get("version", "1.0.0")
30+
ds.status = overrides.get("status", HubContentStatus.AVAILABLE)
31+
ds.source = overrides.get("source", "s3://bucket/datasets/test-ds/data.jsonl")
32+
ds.description = overrides.get("description", "test")
33+
ds.customization_technique = overrides.get("customization_technique", CustomizationTechnique.SFT)
34+
return ds
3935

4036

4137
class TestResolveDataset:
@@ -44,14 +40,14 @@ def test_resolve_with_dataset_instance(self):
4440
result = DataSetHubFactory._resolve_dataset(ds)
4541
assert result is ds
4642

47-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.get")
43+
@patch("sagemaker.ai_registry.dataset.DataSet.get")
4844
def test_resolve_with_name_string(self, mock_get):
4945
mock_get.return_value = _make_dataset()
5046
result = DataSetHubFactory._resolve_dataset("my-dataset")
5147
mock_get.assert_called_once_with(name="my-dataset", sagemaker_session=None)
5248
assert result.name == "test-ds"
5349

54-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.get")
50+
@patch("sagemaker.ai_registry.dataset.DataSet.get")
5551
def test_resolve_with_session(self, mock_get):
5652
mock_session = Mock()
5753
mock_get.return_value = _make_dataset()
@@ -60,7 +56,7 @@ def test_resolve_with_session(self, mock_get):
6056

6157

6258
class TestDownloadToLocal:
63-
@patch("sagemaker.ai_registry.dataset_hub_factory.AIRHub.download_from_s3")
59+
@patch("sagemaker.ai_registry.air_hub.AIRHub.download_from_s3")
6460
def test_download_creates_temp_file(self, mock_download):
6561
path = DataSetHubFactory._download_to_local("s3://bucket/data.jsonl")
6662
try:
@@ -70,7 +66,7 @@ def test_download_creates_temp_file(self, mock_download):
7066
if os.path.exists(path):
7167
os.unlink(path)
7268

73-
@patch("sagemaker.ai_registry.dataset_hub_factory.AIRHub.download_from_s3")
69+
@patch("sagemaker.ai_registry.air_hub.AIRHub.download_from_s3")
7470
def test_download_default_suffix(self, mock_download):
7571
path = DataSetHubFactory._download_to_local("s3://bucket/data")
7672
try:
@@ -81,11 +77,15 @@ def test_download_default_suffix(self, mock_download):
8177

8278

8379
class TestTransformDatasetValidation:
84-
def test_neither_source_nor_dataset_raises(self):
80+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
81+
@patch("sagemaker.ai_registry.dataset.DataSet")
82+
def test_neither_source_nor_dataset_raises(self, mock_ds, mock_transform):
8583
with pytest.raises(ValueError, match="Exactly one of"):
8684
DataSetHubFactory.transform_dataset(name="test", target_format=DatasetFormat.GENQA)
8785

88-
def test_both_source_and_dataset_raises(self):
86+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
87+
@patch("sagemaker.ai_registry.dataset.DataSet")
88+
def test_both_source_and_dataset_raises(self, mock_ds, mock_transform):
8989
with pytest.raises(ValueError, match="Exactly one of"):
9090
DataSetHubFactory.transform_dataset(
9191
name="test",
@@ -96,12 +96,12 @@ def test_both_source_and_dataset_raises(self):
9696

9797

9898
class TestTransformDatasetFromSource:
99-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.create")
100-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
101-
def test_local_file_source(self, mock_transformation_cls, mock_create):
99+
@patch("sagemaker.ai_registry.dataset.DataSet")
100+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
101+
def test_local_file_source(self, mock_transformation_cls, mock_dataset_cls):
102102
mock_transformation_cls.detect_format.return_value = DatasetFormat.OPENAI_CHAT
103103
mock_transformation_cls.transform_file.return_value = "/tmp/transformed.jsonl"
104-
mock_create.return_value = _make_dataset()
104+
mock_dataset_cls.create.return_value = _make_dataset()
105105

106106
result = DataSetHubFactory.transform_dataset(
107107
name="new-ds",
@@ -116,26 +116,26 @@ def test_local_file_source(self, mock_transformation_cls, mock_create):
116116
source_format=DatasetFormat.OPENAI_CHAT,
117117
target_format=DatasetFormat.GENQA,
118118
)
119-
mock_create.assert_called_once_with(
119+
mock_dataset_cls.create.assert_called_once_with(
120120
name="new-ds",
121121
source="/tmp/transformed.jsonl",
122122
customization_technique=CustomizationTechnique.SFT,
123123
sagemaker_session=None,
124124
)
125125
assert result.name == "test-ds"
126126

127-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.create")
128-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
129-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local")
127+
@patch("sagemaker.ai_registry.dataset.DataSet")
128+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
129+
@patch.object(DataSetHubFactory, "_download_to_local")
130130
@patch("os.path.exists", return_value=True)
131131
@patch("os.remove")
132132
def test_s3_source_downloads_and_cleans_up(
133-
self, mock_remove, mock_exists, mock_download, mock_transformation_cls, mock_create
133+
self, mock_remove, mock_exists, mock_download, mock_transformation_cls, mock_dataset_cls
134134
):
135135
mock_download.return_value = "/tmp/downloaded.jsonl"
136136
mock_transformation_cls.detect_format.return_value = DatasetFormat.HF_PROMPT_COMPLETION
137137
mock_transformation_cls.transform_file.return_value = "/tmp/transformed.jsonl"
138-
mock_create.return_value = _make_dataset()
138+
mock_dataset_cls.create.return_value = _make_dataset()
139139

140140
result = DataSetHubFactory.transform_dataset(
141141
name="new-ds",
@@ -147,12 +147,13 @@ def test_s3_source_downloads_and_cleans_up(
147147
mock_transformation_cls.detect_format.assert_called_once_with("/tmp/downloaded.jsonl")
148148
mock_remove.assert_called_once_with("/tmp/downloaded.jsonl")
149149

150-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
151-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local")
150+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
151+
@patch("sagemaker.ai_registry.dataset.DataSet")
152+
@patch.object(DataSetHubFactory, "_download_to_local")
152153
@patch("os.path.exists", return_value=True)
153154
@patch("os.remove")
154155
def test_s3_source_cleans_up_on_error(
155-
self, mock_remove, mock_exists, mock_download, mock_transformation_cls
156+
self, mock_remove, mock_exists, mock_download, mock_ds, mock_transformation_cls
156157
):
157158
mock_download.return_value = "/tmp/downloaded.jsonl"
158159
mock_transformation_cls.detect_format.side_effect = ValueError("bad format")
@@ -168,19 +169,19 @@ def test_s3_source_cleans_up_on_error(
168169

169170

170171
class TestTransformDatasetFromExisting:
171-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.get")
172-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
173-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local")
172+
@patch("sagemaker.ai_registry.dataset.DataSet")
173+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
174+
@patch.object(DataSetHubFactory, "_download_to_local")
174175
@patch("os.path.exists", return_value=True)
175176
@patch("os.remove")
176177
def test_existing_dataset_instance(
177-
self, mock_remove, mock_exists, mock_download, mock_transformation_cls, mock_get
178+
self, mock_remove, mock_exists, mock_download, mock_transformation_cls, mock_dataset_cls
178179
):
179180
ds = _make_dataset()
180181
mock_download.return_value = "/tmp/downloaded.jsonl"
181182
mock_transformation_cls.detect_format.return_value = DatasetFormat.VERL
182183
mock_transformation_cls.transform_file.return_value = "/tmp/transformed.jsonl"
183-
mock_get.return_value = _make_dataset(version="2.0.0")
184+
mock_dataset_cls.get.return_value = _make_dataset(version="2.0.0")
184185

185186
result = DataSetHubFactory.transform_dataset(
186187
name="test-ds",
@@ -197,10 +198,10 @@ def test_existing_dataset_instance(
197198
)
198199
assert result.version == "2.0.0"
199200

200-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSet.get")
201-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
202-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local")
203-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._resolve_dataset")
201+
@patch("sagemaker.ai_registry.dataset.DataSet")
202+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
203+
@patch.object(DataSetHubFactory, "_download_to_local")
204+
@patch.object(DataSetHubFactory, "_resolve_dataset")
204205
@patch("os.path.exists", return_value=True)
205206
@patch("os.remove")
206207
def test_existing_dataset_by_name(
@@ -210,14 +211,14 @@ def test_existing_dataset_by_name(
210211
mock_resolve,
211212
mock_download,
212213
mock_transformation_cls,
213-
mock_get,
214+
mock_dataset_cls,
214215
):
215216
resolved_ds = _make_dataset()
216217
mock_resolve.return_value = resolved_ds
217218
mock_download.return_value = "/tmp/downloaded.jsonl"
218219
mock_transformation_cls.detect_format.return_value = DatasetFormat.OPENAI_CHAT
219220
mock_transformation_cls.transform_file.return_value = "/tmp/transformed.jsonl"
220-
mock_get.return_value = _make_dataset(version="2.0.0")
221+
mock_dataset_cls.get.return_value = _make_dataset(version="2.0.0")
221222

222223
result = DataSetHubFactory.transform_dataset(
223224
name="test-ds",
@@ -228,12 +229,13 @@ def test_existing_dataset_by_name(
228229
mock_resolve.assert_called_once_with("test-ds", None)
229230
mock_download.assert_called_once_with(resolved_ds.source)
230231

231-
@patch("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation")
232-
@patch("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local")
232+
@patch("sagemaker.ai_registry.dataset_transformation.DatasetTransformation")
233+
@patch("sagemaker.ai_registry.dataset.DataSet")
234+
@patch.object(DataSetHubFactory, "_download_to_local")
233235
@patch("os.path.exists", return_value=True)
234236
@patch("os.remove")
235237
def test_existing_dataset_cleans_up_on_error(
236-
self, mock_remove, mock_exists, mock_download, mock_transformation_cls
238+
self, mock_remove, mock_exists, mock_download, mock_ds, mock_transformation_cls
237239
):
238240
ds = _make_dataset()
239241
mock_download.return_value = "/tmp/downloaded.jsonl"

sagemaker-train/tests/unit/ai_registry/test_dataset_transformation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def test_openai_dispatch(self):
332332

333333
def test_unsupported_format_raises(self):
334334
with pytest.raises(ValueError, match="Unsupported dataset format"):
335-
convert_to_genqa({}, DatasetFormat("converse"))
335+
convert_to_genqa({}, DatasetFormat.DPO)
336336

337337

338338
class TestConvertFromGenqa:
@@ -347,21 +347,21 @@ def test_unimplemented_raises(self):
347347

348348

349349
class TestDatasetTransformationDetectFormat:
350-
@patch("sagemaker.ai_registry.dataset_transformation.DatasetFormatDetector")
351-
def test_detect_known_format(self, mock_detector_cls):
352-
mock_detector_cls.detect_format.return_value = "openai_chat"
350+
@patch("sagemaker.ai_registry.dataset_format_detector.DatasetFormatDetector.detect_format")
351+
def test_detect_known_format(self, mock_detect):
352+
mock_detect.return_value = "openai_chat"
353353
result = DatasetTransformation.detect_format("/path/to/file.jsonl")
354354
assert result == DatasetFormat.OPENAI_CHAT
355355

356-
@patch("sagemaker.ai_registry.dataset_transformation.DatasetFormatDetector")
357-
def test_detect_none_raises(self, mock_detector_cls):
358-
mock_detector_cls.detect_format.return_value = None
356+
@patch("sagemaker.ai_registry.dataset_format_detector.DatasetFormatDetector.detect_format")
357+
def test_detect_none_raises(self, mock_detect):
358+
mock_detect.return_value = None
359359
with pytest.raises(ValueError, match="Unable to detect dataset format"):
360360
DatasetTransformation.detect_format("/path/to/file.jsonl")
361361

362-
@patch("sagemaker.ai_registry.dataset_transformation.DatasetFormatDetector")
363-
def test_detect_unsupported_format_raises(self, mock_detector_cls):
364-
mock_detector_cls.detect_format.return_value = "rft"
362+
@patch("sagemaker.ai_registry.dataset_format_detector.DatasetFormatDetector.detect_format")
363+
def test_detect_unsupported_format_raises(self, mock_detect):
364+
mock_detect.return_value = "rft"
365365
with pytest.raises(ValueError, match="not a supported transformation format"):
366366
DatasetTransformation.detect_format("/path/to/file.jsonl")
367367

0 commit comments

Comments
 (0)