1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313"""Tests for ModelTrainer dependencies feature."""
14- from __future__ import absolute_import
14+ from __future__ import annotations
1515
1616import os
1717import shutil
5353)
5454
5555
56- @pytest .fixture (scope = "module" , autouse = True )
56+ @pytest .fixture (autouse = True )
5757def modules_session ():
5858 with patch ("sagemaker.train.Session" , spec = Session ) as session_mock :
5959 session_instance = session_mock .return_value
@@ -107,6 +107,27 @@ def test_validate_source_code_with_invalid_dependency_path_raises_value_error(mo
107107 )
108108
109109
110+ def test_validate_source_code_with_file_dependency_raises_value_error (modules_session ):
111+ """Verify _validate_source_code raises ValueError when a dependency is a file, not a directory."""
112+ dep_file = tempfile .NamedTemporaryFile (suffix = ".py" , delete = False )
113+ dep_file .close ()
114+ try :
115+ with pytest .raises (ValueError , match = "Invalid dependency path" ):
116+ ModelTrainer (
117+ training_image = DEFAULT_IMAGE ,
118+ role = DEFAULT_ROLE ,
119+ sagemaker_session = modules_session ,
120+ compute = DEFAULT_COMPUTE_CONFIG ,
121+ source_code = SourceCode (
122+ source_dir = DEFAULT_SOURCE_DIR ,
123+ entry_script = "custom_script.py" ,
124+ dependencies = [dep_file .name ],
125+ ),
126+ )
127+ finally :
128+ os .unlink (dep_file .name )
129+
130+
110131def test_source_code_dependencies_validation_with_valid_dirs (modules_session ):
111132 """Create SourceCode with dependencies pointing to multiple valid directories."""
112133 dep_dir1 = tempfile .mkdtemp ()
@@ -250,23 +271,39 @@ def test_train_with_dependencies_generates_pythonpath_setup_in_train_script(
250271 dependencies = [dep_dir ],
251272 ),
252273 )
253- trainer .train ()
254-
255- # Check the generated train script in the temp directory
256- assert trainer ._temp_code_dir is not None or True # temp dir may be cleaned up
257- # The key assertion is that the template was used - check via the training job call
258- mock_training_job .create .assert_called_once ()
274+ # Call _create_training_job_args to generate the train script without cleanup
275+ trainer ._create_training_job_args ()
276+
277+ # Read the generated train script and verify it contains PYTHONPATH setup
278+ assert trainer ._temp_code_dir is not None
279+ train_script_path = os .path .join (trainer ._temp_code_dir .name , "sm_train.sh" )
280+ with open (train_script_path ) as f :
281+ script_content = f .read ()
282+ assert "sm_dependencies" in script_content
283+ assert "PYTHONPATH" in script_content
284+ assert "Setting up additional dependencies" in script_content
259285 finally :
286+ if trainer ._temp_code_dir is not None :
287+ trainer ._temp_code_dir .cleanup ()
288+ if trainer ._temp_deps_dir is not None :
289+ trainer ._temp_deps_dir .cleanup ()
260290 shutil .rmtree (dep_dir )
261291
262292
263- def test_dependencies_copied_to_temp_dir_preserving_basenames (modules_session ):
293+ @patch ("sagemaker.train.model_trainer.TrainingJob" )
294+ def test_dependencies_copied_to_temp_dir_preserving_basenames (
295+ mock_training_job , modules_session
296+ ):
264297 """Verify that each dependency directory's basename is preserved when copied."""
265298 dep_dir = tempfile .mkdtemp (suffix = "_mylib" )
266299 sub_file = os .path .join (dep_dir , "module.py" )
267300 with open (sub_file , "w" ) as f :
268301 f .write ("# test module" )
269302
303+ modules_session .upload_data .return_value = (
304+ f"s3://{ DEFAULT_BUCKET } /prefix/sm_dependencies"
305+ )
306+
270307 try :
271308 trainer = ModelTrainer (
272309 training_image = DEFAULT_IMAGE ,
@@ -281,9 +318,22 @@ def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session):
281318 dependencies = [dep_dir ],
282319 ),
283320 )
284- # Verify the source_code has the dependencies set
285- assert trainer .source_code .dependencies == [dep_dir ]
321+ # Call _create_training_job_args to trigger the copy
322+ trainer ._create_training_job_args ()
323+
324+ # Verify the dependencies were copied preserving basenames
325+ assert trainer ._temp_deps_dir is not None
286326 dep_basename = os .path .basename (os .path .normpath (dep_dir ))
287- assert dep_basename .endswith ("_mylib" )
327+ copied_dep_path = os .path .join (trainer ._temp_deps_dir .name , dep_basename )
328+ assert os .path .isdir (copied_dep_path ), (
329+ f"Expected dependency directory { copied_dep_path } to exist"
330+ )
331+ assert os .path .isfile (os .path .join (copied_dep_path , "module.py" )), (
332+ "Expected module.py to be copied into the dependency directory"
333+ )
288334 finally :
335+ if trainer ._temp_code_dir is not None :
336+ trainer ._temp_code_dir .cleanup ()
337+ if trainer ._temp_deps_dir is not None :
338+ trainer ._temp_deps_dir .cleanup ()
289339 shutil .rmtree (dep_dir )
0 commit comments