7373 },
7474}
7575
76+ SOURCE_CODE = {
77+ "source_dir" : "code" ,
78+ "entry_script" : "train.py" ,
79+ }
80+
81+ DISTRIBUTED_CONFIG = {
82+ "process_count_per_node" : 2 ,
83+ }
84+
7685OUTPUT_FILE = os .path .join (os .path .dirname (__file__ ), "sm_training.env" )
7786
7887# flake8: noqa
8796export SM_LOG_LEVEL='20'
8897export SM_MASTER_ADDR='algo-1'
8998export SM_MASTER_PORT='7777'
99+ export SM_SOURCE_DIR='/opt/ml/input/data/code'
100+ export SM_ENTRY_SCRIPT='train.py'
101+ export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers'
102+ export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}'
90103export SM_CHANNEL_TRAIN='/opt/ml/input/data/train'
91104export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation'
92105export SM_CHANNELS='["train", "validation"]'
110123"""
111124
112125
126+ @patch (
127+ "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json" ,
128+ return_value = SOURCE_CODE ,
129+ )
130+ @patch (
131+ "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json" ,
132+ return_value = DISTRIBUTED_CONFIG ,
133+ )
113134@patch ("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus" , return_value = 8 )
114135@patch ("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus" , return_value = 0 )
115136@patch ("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons" , return_value = 0 )
122143 side_effect = safe_deserialize ,
123144)
124145def test_set_env (
125- mock_safe_deserialize , mock_safe_serialize , mock_num_cpus , mock_num_gpus , mock_num_neurons
146+ mock_safe_deserialize ,
147+ mock_safe_serialize ,
148+ mock_num_neurons ,
149+ mock_num_gpus ,
150+ mock_num_cpus ,
151+ mock_read_distributed_json ,
152+ mock_read_source_code_json ,
126153):
127154 with patch .dict (os .environ , {"TRAINING_JOB_NAME" : "test-job" }):
128155 set_env (
@@ -135,6 +162,8 @@ def test_set_env(
135162 mock_num_cpus .assert_called_once ()
136163 mock_num_gpus .assert_called_once ()
137164 mock_num_neurons .assert_called_once ()
165+ mock_read_distributed_json .assert_called_once ()
166+ mock_read_source_code_json .assert_called_once ()
138167
139168 with open (OUTPUT_FILE , "r" ) as f :
140169 env_file = f .read ().strip ()
0 commit comments