|
179 | 179 | fi |
180 | 180 |
|
181 | 181 | printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" |
182 | | - printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" |
183 | | - $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" |
| 182 | + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n" |
| 183 | + $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@" |
184 | 184 | else |
185 | 185 | printf "INFO: No conda env provided. Invoking remote function\\n" |
186 | | - printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" |
187 | | - python -m sagemaker.train.remote_function.invoke_function "$@" |
| 186 | + printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n" |
| 187 | + python -m sagemaker.core.remote_function.invoke_function "$@" |
188 | 188 | fi |
189 | 189 | """ |
190 | 190 |
|
|
238 | 238 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
239 | 239 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
240 | 240 |
|
241 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
| 241 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
242 | 242 | $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
243 | 243 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
244 | 244 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
245 | 245 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
246 | 246 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
247 | 247 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
248 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
| 248 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
249 | 249 |
|
250 | 250 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
251 | 251 | else |
|
263 | 263 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
264 | 264 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
265 | 265 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
266 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" |
| 266 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" |
267 | 267 |
|
268 | 268 | mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ |
269 | 269 | --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ |
270 | 270 | -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ |
271 | 271 | -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ |
272 | 272 | -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ |
273 | 273 | $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ |
274 | | - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" |
| 274 | + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" |
275 | 275 |
|
276 | 276 | python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 |
277 | 277 | else |
|
324 | 324 | printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" |
325 | 325 | printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
326 | 326 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
327 | | - -m sagemaker.train.remote_function.invoke_function \\n" |
| 327 | + -m sagemaker.core.remote_function.invoke_function \\n" |
328 | 328 |
|
329 | 329 | $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ |
330 | 330 | --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ |
331 | | - -m sagemaker.train.remote_function.invoke_function "$@" |
| 331 | + -m sagemaker.core.remote_function.invoke_function "$@" |
332 | 332 | else |
333 | 333 | printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" |
334 | 334 | printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
335 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" |
| 335 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n" |
336 | 336 |
|
337 | 337 | torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ |
338 | | - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" |
| 338 | + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@" |
339 | 339 | fi |
340 | 340 | """ |
341 | 341 |
|
@@ -728,7 +728,7 @@ def __init__( |
728 | 728 | sagemaker_session=self.sagemaker_session, |
729 | 729 | ) |
730 | 730 | if _role: |
731 | | - self.role = expand_role(self.sagemaker_session.boto_session, _role) |
| 731 | + self.role = expand_role(self.sagemaker_session, _role) |
732 | 732 | else: |
733 | 733 | self.role = get_execution_role(self.sagemaker_session) |
734 | 734 |
|
@@ -941,16 +941,24 @@ def compile( |
941 | 941 | # generate asymmetric key pair for integrity check |
942 | 942 | if step_compilation_context is None: |
943 | 943 | private_key = ec.generate_private_key(ec.SECP256R1()) |
944 | | - public_key_pem = private_key.public_key().public_bytes( |
945 | | - crypto_serialization.Encoding.PEM, |
946 | | - crypto_serialization.PublicFormat.SubjectPublicKeyInfo, |
947 | | - ).decode("utf-8") |
| 944 | + public_key_pem = ( |
| 945 | + private_key.public_key() |
| 946 | + .public_bytes( |
| 947 | + crypto_serialization.Encoding.PEM, |
| 948 | + crypto_serialization.PublicFormat.SubjectPublicKeyInfo, |
| 949 | + ) |
| 950 | + .decode("utf-8") |
| 951 | + ) |
948 | 952 | else: |
949 | 953 | private_key = step_compilation_context.function_step_secret_token |
950 | | - public_key_pem = private_key.public_key().public_bytes( |
951 | | - crypto_serialization.Encoding.PEM, |
952 | | - crypto_serialization.PublicFormat.SubjectPublicKeyInfo, |
953 | | - ).decode("utf-8") |
| 954 | + public_key_pem = ( |
| 955 | + private_key.public_key() |
| 956 | + .public_bytes( |
| 957 | + crypto_serialization.Encoding.PEM, |
| 958 | + crypto_serialization.PublicFormat.SubjectPublicKeyInfo, |
| 959 | + ) |
| 960 | + .decode("utf-8") |
| 961 | + ) |
954 | 962 |
|
955 | 963 | # serialize function and arguments |
956 | 964 | if step_compilation_context is None: |
|
0 commit comments