@@ -269,27 +269,44 @@ def test_get_training_code_hash_with_source_dir(self):
269269 with tempfile .TemporaryDirectory () as temp_dir :
270270 entry_file = Path (temp_dir , "train.py" )
271271 entry_file .write_text ("print('training')" )
272+ requirements_file = Path (temp_dir , "requirements.txt" )
273+ requirements_file .write_text ("numpy==1.21.0" )
272274
273- result = get_training_code_hash (
274- entry_point = str (entry_file ), source_dir = temp_dir , dependencies = []
275+ result_no_deps = get_training_code_hash (
276+ entry_point = str (entry_file ), source_dir = temp_dir , dependencies = None
277+ )
278+ result_with_deps = get_training_code_hash (
279+ entry_point = str (entry_file ), source_dir = temp_dir , dependencies = str (requirements_file )
275280 )
276281
277- assert result is not None
278- assert len (result ) == 64
282+ assert result_no_deps is not None
283+ assert result_with_deps is not None
284+ assert len (result_no_deps ) == 64
285+ assert len (result_with_deps ) == 64
286+ assert result_no_deps != result_with_deps
279287
280288 def test_get_training_code_hash_entry_point_only (self ):
281289 """Test get_training_code_hash with entry_point only"""
282- with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" , delete = False ) as f :
283- f .write ("print('training')" )
284- temp_file = f .name
290+ with tempfile .TemporaryDirectory () as temp_dir :
291+ entry_file = Path (temp_dir , "train.py" )
292+ entry_file .write_text ("print('training')" )
293+ requirements_file = Path (temp_dir , "requirements.txt" )
294+ requirements_file .write_text ("numpy==1.21.0" )
285295
286- try :
287- result = get_training_code_hash (entry_point = temp_file , source_dir = None , dependencies = [])
296+ # Without dependencies
297+ result_no_deps = get_training_code_hash (
298+ entry_point = str (entry_file ), source_dir = None , dependencies = None
299+ )
300+ # With dependencies
301+ result_with_deps = get_training_code_hash (
302+ entry_point = str (entry_file ), source_dir = None , dependencies = str (requirements_file )
303+ )
288304
289- assert result is not None
290- assert len (result ) == 64
291- finally :
292- os .unlink (temp_file )
305+ assert result_no_deps is not None
306+ assert result_with_deps is not None
307+ assert len (result_no_deps ) == 64
308+ assert len (result_with_deps ) == 64
309+ assert result_no_deps != result_with_deps
293310
294311 def test_get_training_code_hash_s3_uri (self ):
295312 """Test get_training_code_hash with S3 URI returns None"""
0 commit comments