1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313
14- import json
1514import os
1615import pytest
17- from unittest .mock import Mock , patch , MagicMock
16+ from unittest .mock import Mock , patch
1817
1918from sagemaker .ai_registry .dataset_hub_factory import DataSetHubFactory
20- from sagemaker .ai_registry .dataset import DataSet
2119from sagemaker .ai_registry .dataset_transformation import DatasetFormat
2220from sagemaker .ai_registry .dataset_utils import CustomizationTechnique
2321from sagemaker .ai_registry .air_constants import HubContentStatus
2422
2523
2624def _make_dataset (** overrides ):
27- """Helper to create a DataSet instance with sensible defaults."""
28- defaults = dict (
29- name = "test-ds" ,
30- arn = "arn:aws:sagemaker:us-east-1:123456789012:hub-content/test-ds" ,
31- version = "1.0.0" ,
32- status = HubContentStatus .AVAILABLE ,
33- source = "s3://bucket/datasets/test-ds/data.jsonl" ,
34- description = "test" ,
35- customization_technique = CustomizationTechnique .SFT ,
36- )
37- defaults .update (overrides )
38- return DataSet (** defaults )
25+ """Helper to create a mock DataSet with sensible defaults."""
26+ ds = Mock ()
27+ ds .name = overrides .get ("name" , "test-ds" )
28+ ds .arn = overrides .get ("arn" , "arn:aws:sagemaker:us-east-1:123456789012:hub-content/test-ds" )
29+ ds .version = overrides .get ("version" , "1.0.0" )
30+ ds .status = overrides .get ("status" , HubContentStatus .AVAILABLE )
31+ ds .source = overrides .get ("source" , "s3://bucket/datasets/test-ds/data.jsonl" )
32+ ds .description = overrides .get ("description" , "test" )
33+ ds .customization_technique = overrides .get ("customization_technique" , CustomizationTechnique .SFT )
34+ return ds
3935
4036
4137class TestResolveDataset :
@@ -44,14 +40,14 @@ def test_resolve_with_dataset_instance(self):
4440 result = DataSetHubFactory ._resolve_dataset (ds )
4541 assert result is ds
4642
47- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.get" )
43+ @patch ("sagemaker.ai_registry.dataset .DataSet.get" )
4844 def test_resolve_with_name_string (self , mock_get ):
4945 mock_get .return_value = _make_dataset ()
5046 result = DataSetHubFactory ._resolve_dataset ("my-dataset" )
5147 mock_get .assert_called_once_with (name = "my-dataset" , sagemaker_session = None )
5248 assert result .name == "test-ds"
5349
54- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.get" )
50+ @patch ("sagemaker.ai_registry.dataset .DataSet.get" )
5551 def test_resolve_with_session (self , mock_get ):
5652 mock_session = Mock ()
5753 mock_get .return_value = _make_dataset ()
@@ -60,7 +56,7 @@ def test_resolve_with_session(self, mock_get):
6056
6157
6258class TestDownloadToLocal :
63- @patch ("sagemaker.ai_registry.dataset_hub_factory .AIRHub.download_from_s3" )
59+ @patch ("sagemaker.ai_registry.air_hub .AIRHub.download_from_s3" )
6460 def test_download_creates_temp_file (self , mock_download ):
6561 path = DataSetHubFactory ._download_to_local ("s3://bucket/data.jsonl" )
6662 try :
@@ -70,7 +66,7 @@ def test_download_creates_temp_file(self, mock_download):
7066 if os .path .exists (path ):
7167 os .unlink (path )
7268
73- @patch ("sagemaker.ai_registry.dataset_hub_factory .AIRHub.download_from_s3" )
69+ @patch ("sagemaker.ai_registry.air_hub .AIRHub.download_from_s3" )
7470 def test_download_default_suffix (self , mock_download ):
7571 path = DataSetHubFactory ._download_to_local ("s3://bucket/data" )
7672 try :
@@ -81,11 +77,15 @@ def test_download_default_suffix(self, mock_download):
8177
8278
8379class TestTransformDatasetValidation :
84- def test_neither_source_nor_dataset_raises (self ):
80+ @patch ("sagemaker.ai_registry.dataset_transformation.DatasetTransformation" )
81+ @patch ("sagemaker.ai_registry.dataset.DataSet" )
82+ def test_neither_source_nor_dataset_raises (self , mock_ds , mock_transform ):
8583 with pytest .raises (ValueError , match = "Exactly one of" ):
8684 DataSetHubFactory .transform_dataset (name = "test" , target_format = DatasetFormat .GENQA )
8785
88- def test_both_source_and_dataset_raises (self ):
86+ @patch ("sagemaker.ai_registry.dataset_transformation.DatasetTransformation" )
87+ @patch ("sagemaker.ai_registry.dataset.DataSet" )
88+ def test_both_source_and_dataset_raises (self , mock_ds , mock_transform ):
8989 with pytest .raises (ValueError , match = "Exactly one of" ):
9090 DataSetHubFactory .transform_dataset (
9191 name = "test" ,
@@ -96,12 +96,12 @@ def test_both_source_and_dataset_raises(self):
9696
9797
9898class TestTransformDatasetFromSource :
99- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.create " )
100- @patch ("sagemaker.ai_registry.dataset_hub_factory .DatasetTransformation" )
101- def test_local_file_source (self , mock_transformation_cls , mock_create ):
99+ @patch ("sagemaker.ai_registry.dataset .DataSet" )
100+ @patch ("sagemaker.ai_registry.dataset_transformation .DatasetTransformation" )
101+ def test_local_file_source (self , mock_transformation_cls , mock_dataset_cls ):
102102 mock_transformation_cls .detect_format .return_value = DatasetFormat .OPENAI_CHAT
103103 mock_transformation_cls .transform_file .return_value = "/tmp/transformed.jsonl"
104- mock_create .return_value = _make_dataset ()
104+ mock_dataset_cls . create .return_value = _make_dataset ()
105105
106106 result = DataSetHubFactory .transform_dataset (
107107 name = "new-ds" ,
@@ -116,26 +116,26 @@ def test_local_file_source(self, mock_transformation_cls, mock_create):
116116 source_format = DatasetFormat .OPENAI_CHAT ,
117117 target_format = DatasetFormat .GENQA ,
118118 )
119- mock_create .assert_called_once_with (
119+ mock_dataset_cls . create .assert_called_once_with (
120120 name = "new-ds" ,
121121 source = "/tmp/transformed.jsonl" ,
122122 customization_technique = CustomizationTechnique .SFT ,
123123 sagemaker_session = None ,
124124 )
125125 assert result .name == "test-ds"
126126
127- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.create " )
128- @patch ("sagemaker.ai_registry.dataset_hub_factory .DatasetTransformation" )
129- @patch ( "sagemaker.ai_registry.dataset_hub_factory. DataSetHubFactory. _download_to_local" )
127+ @patch ("sagemaker.ai_registry.dataset .DataSet" )
128+ @patch ("sagemaker.ai_registry.dataset_transformation .DatasetTransformation" )
129+ @patch . object ( DataSetHubFactory , " _download_to_local" )
130130 @patch ("os.path.exists" , return_value = True )
131131 @patch ("os.remove" )
132132 def test_s3_source_downloads_and_cleans_up (
133- self , mock_remove , mock_exists , mock_download , mock_transformation_cls , mock_create
133+ self , mock_remove , mock_exists , mock_download , mock_transformation_cls , mock_dataset_cls
134134 ):
135135 mock_download .return_value = "/tmp/downloaded.jsonl"
136136 mock_transformation_cls .detect_format .return_value = DatasetFormat .HF_PROMPT_COMPLETION
137137 mock_transformation_cls .transform_file .return_value = "/tmp/transformed.jsonl"
138- mock_create .return_value = _make_dataset ()
138+ mock_dataset_cls . create .return_value = _make_dataset ()
139139
140140 result = DataSetHubFactory .transform_dataset (
141141 name = "new-ds" ,
@@ -147,12 +147,13 @@ def test_s3_source_downloads_and_cleans_up(
147147 mock_transformation_cls .detect_format .assert_called_once_with ("/tmp/downloaded.jsonl" )
148148 mock_remove .assert_called_once_with ("/tmp/downloaded.jsonl" )
149149
150- @patch ("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation" )
151- @patch ("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local" )
150+ @patch ("sagemaker.ai_registry.dataset_transformation.DatasetTransformation" )
151+ @patch ("sagemaker.ai_registry.dataset.DataSet" )
152+ @patch .object (DataSetHubFactory , "_download_to_local" )
152153 @patch ("os.path.exists" , return_value = True )
153154 @patch ("os.remove" )
154155 def test_s3_source_cleans_up_on_error (
155- self , mock_remove , mock_exists , mock_download , mock_transformation_cls
156+ self , mock_remove , mock_exists , mock_download , mock_ds , mock_transformation_cls
156157 ):
157158 mock_download .return_value = "/tmp/downloaded.jsonl"
158159 mock_transformation_cls .detect_format .side_effect = ValueError ("bad format" )
@@ -168,19 +169,19 @@ def test_s3_source_cleans_up_on_error(
168169
169170
170171class TestTransformDatasetFromExisting :
171- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.get " )
172- @patch ("sagemaker.ai_registry.dataset_hub_factory .DatasetTransformation" )
173- @patch ( "sagemaker.ai_registry.dataset_hub_factory. DataSetHubFactory. _download_to_local" )
172+ @patch ("sagemaker.ai_registry.dataset .DataSet" )
173+ @patch ("sagemaker.ai_registry.dataset_transformation .DatasetTransformation" )
174+ @patch . object ( DataSetHubFactory , " _download_to_local" )
174175 @patch ("os.path.exists" , return_value = True )
175176 @patch ("os.remove" )
176177 def test_existing_dataset_instance (
177- self , mock_remove , mock_exists , mock_download , mock_transformation_cls , mock_get
178+ self , mock_remove , mock_exists , mock_download , mock_transformation_cls , mock_dataset_cls
178179 ):
179180 ds = _make_dataset ()
180181 mock_download .return_value = "/tmp/downloaded.jsonl"
181182 mock_transformation_cls .detect_format .return_value = DatasetFormat .VERL
182183 mock_transformation_cls .transform_file .return_value = "/tmp/transformed.jsonl"
183- mock_get .return_value = _make_dataset (version = "2.0.0" )
184+ mock_dataset_cls . get .return_value = _make_dataset (version = "2.0.0" )
184185
185186 result = DataSetHubFactory .transform_dataset (
186187 name = "test-ds" ,
@@ -197,10 +198,10 @@ def test_existing_dataset_instance(
197198 )
198199 assert result .version == "2.0.0"
199200
200- @patch ("sagemaker.ai_registry.dataset_hub_factory .DataSet.get " )
201- @patch ("sagemaker.ai_registry.dataset_hub_factory .DatasetTransformation" )
202- @patch ( "sagemaker.ai_registry.dataset_hub_factory. DataSetHubFactory. _download_to_local" )
203- @patch ( "sagemaker.ai_registry.dataset_hub_factory. DataSetHubFactory. _resolve_dataset" )
201+ @patch ("sagemaker.ai_registry.dataset .DataSet" )
202+ @patch ("sagemaker.ai_registry.dataset_transformation .DatasetTransformation" )
203+ @patch . object ( DataSetHubFactory , " _download_to_local" )
204+ @patch . object ( DataSetHubFactory , " _resolve_dataset" )
204205 @patch ("os.path.exists" , return_value = True )
205206 @patch ("os.remove" )
206207 def test_existing_dataset_by_name (
@@ -210,14 +211,14 @@ def test_existing_dataset_by_name(
210211 mock_resolve ,
211212 mock_download ,
212213 mock_transformation_cls ,
213- mock_get ,
214+ mock_dataset_cls ,
214215 ):
215216 resolved_ds = _make_dataset ()
216217 mock_resolve .return_value = resolved_ds
217218 mock_download .return_value = "/tmp/downloaded.jsonl"
218219 mock_transformation_cls .detect_format .return_value = DatasetFormat .OPENAI_CHAT
219220 mock_transformation_cls .transform_file .return_value = "/tmp/transformed.jsonl"
220- mock_get .return_value = _make_dataset (version = "2.0.0" )
221+ mock_dataset_cls . get .return_value = _make_dataset (version = "2.0.0" )
221222
222223 result = DataSetHubFactory .transform_dataset (
223224 name = "test-ds" ,
@@ -228,12 +229,13 @@ def test_existing_dataset_by_name(
228229 mock_resolve .assert_called_once_with ("test-ds" , None )
229230 mock_download .assert_called_once_with (resolved_ds .source )
230231
231- @patch ("sagemaker.ai_registry.dataset_hub_factory.DatasetTransformation" )
232- @patch ("sagemaker.ai_registry.dataset_hub_factory.DataSetHubFactory._download_to_local" )
232+ @patch ("sagemaker.ai_registry.dataset_transformation.DatasetTransformation" )
233+ @patch ("sagemaker.ai_registry.dataset.DataSet" )
234+ @patch .object (DataSetHubFactory , "_download_to_local" )
233235 @patch ("os.path.exists" , return_value = True )
234236 @patch ("os.remove" )
235237 def test_existing_dataset_cleans_up_on_error (
236- self , mock_remove , mock_exists , mock_download , mock_transformation_cls
238+ self , mock_remove , mock_exists , mock_download , mock_ds , mock_transformation_cls
237239 ):
238240 ds = _make_dataset ()
239241 mock_download .return_value = "/tmp/downloaded.jsonl"
0 commit comments