forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
213 lines (171 loc) · 6.82 KB
/
utils.py
File metadata and controls
213 lines (171 loc) · 6.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides utility functions for the container drivers."""
from __future__ import absolute_import
import os
import logging
import sys
import subprocess
import traceback
import json
from typing import List, Dict, Any, Tuple, IO, Optional
# Initialize logger
# The level is read from the SM_LOG_LEVEL environment variable and falls back
# to 20, the numeric value of logging.INFO. The value is passed through int()
# below, so a level *name* such as "DEBUG" would raise ValueError at import.
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
logger = logging.getLogger(__name__)
# Records from this module's logger are written to stdout.
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
logger.setLevel(int(SM_LOG_LEVEL))
# File that write_failure_file() creates to record a failure message.
FAILURE_FILE = "/opt/ml/output/failure"
# Template used by write_failure_file() when no message is supplied; the
# {training_job_name} placeholder is filled in with str.format().
DEFAULT_FAILURE_MESSAGE = """
Training Execution failed.
For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'.
TrainingJob - {training_job_name}
"""
# Not referenced by the functions visible in this module — presumably the
# directory holding the user's code inside the container (confirm with callers).
USER_CODE_PATH = "/opt/ml/input/data/code"
# Default file paths for the read_*_json() helpers in this module.
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
# Instance type names grouped under an EFA/NCCL label. Not used in the visible
# part of this module — presumably consumed by the distributed drivers (confirm).
SM_EFA_NCCL_INSTANCES = [
    "ml.g4dn.8xlarge",
    "ml.g4dn.12xlarge",
    "ml.g5.48xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.p5.48xlarge",
    "ml.trn1.32xlarge",
]
# Instance type names grouped under an EFA/RDMA label; every entry also
# appears in SM_EFA_NCCL_INSTANCES above. Usage not visible here (confirm).
SM_EFA_RDMA_INSTANCES = [
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.trn1.32xlarge",
]
def write_failure_file(message: Optional[str] = None):
    """Write a failure file with the message, unless one already exists.

    Args:
        message (Optional[str]): Text to record. When ``None``, the default
            failure message is rendered with the value of the
            ``TRAINING_JOB_NAME`` environment variable, or ``"unknown"`` if
            that variable is not set.
    """
    if message is None:
        # This runs while a failure is being reported, so a missing variable
        # must not raise KeyError and hide the error being recorded.
        job_name = os.environ.get("TRAINING_JOB_NAME", "unknown")
        message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=job_name)
    try:
        # Mode "x" creates the file only if it is absent, so the first
        # recorded failure is kept and the exists-then-open race is avoided.
        with open(FAILURE_FILE, "x") as f:
            f.write(message)
    except FileExistsError:
        # A failure was already recorded; leave it untouched.
        pass
def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON) -> Dict[str, Any]:
    """Read the source code config json file.

    Args:
        source_code_json (str): Path of the JSON file to read. Defaults to
            ``SOURCE_CODE_JSON``.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file
            does not exist or its content is falsy (e.g. ``null`` or ``{}``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(source_code_json, "r", encoding="utf-8") as f:
            return json.load(f) or {}
    except FileNotFoundError:
        return {}
def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON) -> Dict[str, Any]:
    """Read the distribution config json file.

    Args:
        distributed_json (str): Path of the JSON file to read. Defaults to
            ``DISTRIBUTED_JSON``.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file
            does not exist or its content is falsy (e.g. ``null`` or ``{}``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(distributed_json, "r", encoding="utf-8") as f:
            return json.load(f) or {}
    except FileNotFoundError:
        return {}
def read_hyperparameters_json(
    hyperparameters_json: str = HYPERPARAMETERS_JSON,
) -> Dict[str, Any]:
    """Read the hyperparameters config json file.

    Args:
        hyperparameters_json (str): Path of the JSON file to read. Defaults to
            ``HYPERPARAMETERS_JSON``.

    Returns:
        Dict[str, Any]: The parsed content, or an empty dict when the file
            does not exist or its content is falsy (e.g. ``null`` or ``{}``).
    """
    try:
        # JSON text is UTF-8; do not depend on the container's locale encoding.
        with open(hyperparameters_json, "r", encoding="utf-8") as f:
            return json.load(f) or {}
    except FileNotFoundError:
        return {}
def get_process_count(process_count: Optional[int] = None) -> int:
    """Resolve how many processes to launch on each node of the training job.

    The first truthy value wins, in this order: the explicit ``process_count``
    argument, the ``SM_NUM_GPUS`` environment variable, the ``SM_NUM_NEURONS``
    environment variable, and finally ``1``.
    """
    if process_count:
        return process_count
    # Later variables are only read (and parsed) when earlier ones are zero/unset.
    for env_name in ("SM_NUM_GPUS", "SM_NUM_NEURONS"):
        device_count = int(os.environ.get(env_name, 0))
        if device_count:
            return device_count
    return 1
def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    """Flatten a hyperparameter mapping into a ``--key value`` argument list.

    Each value is first passed through ``safe_deserialize`` and the result is
    rendered with ``safe_serialize``; the output keeps the mapping's iteration
    order, with every ``--<key>`` flag immediately followed by its value.
    """
    flags = []
    for name, raw_value in hyperparameters.items():
        rendered = safe_serialize(safe_deserialize(raw_value))
        flags.append(f"--{name}")
        flags.append(rendered)
    return flags
def safe_deserialize(data: Any) -> Any:
    """Decode a JSON string when possible, otherwise hand the input back.

    Behaviour by input:
    1. Non-string input is returned unchanged.
    2. The strings "true" and "false" (compared case-insensitively) become
       ``True`` and ``False``.
    3. Any other string is passed to `json.loads()` and the decoded value is returned.
    4. A string that `json.loads()` rejects is returned as the original string.

    Returns:
        Any: The decoded value, or the original input when it is not decodable JSON.
    """
    if isinstance(data, str):
        normalized = data.lower()
        if normalized == "true":
            return True
        if normalized == "false":
            return False
        try:
            decoded = json.loads(data)
        except json.JSONDecodeError:
            decoded = data
        return decoded
    return data
def safe_serialize(data: Any) -> str:
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container that
       references itself), it returns the string representation of the data
       using `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported type. ValueError: json.dumps() detected a
        # circular reference. Both fall back to str() as documented above.
        return str(data)
def get_python_executable() -> str:
    """Return the interpreter path reported by ``sys.executable``."""
    interpreter_path = sys.executable
    return interpreter_path
def log_subprocess_output(pipe: IO[bytes]):
    """Log the output from the subprocess, one INFO record per line.

    Args:
        pipe (IO[bytes]): Binary stream to drain. Reading stops when
            ``readline()`` returns ``b""`` (end of stream).
    """
    for line in iter(pipe.readline, b""):
        # errors="replace": child output is not guaranteed to be valid UTF-8;
        # a stray byte must not raise UnicodeDecodeError and stop log draining.
        logger.info(line.decode("utf-8", errors="replace").strip())
def execute_commands(commands: List[str]) -> Tuple[int, str]:
    """Run a command, stream its output to the log, and report how it ended.

    stderr is merged into stdout and every line is forwarded through
    ``log_subprocess_output`` before the exit status is collected.

    Args:
        commands (List[str]): The argument vector handed to ``subprocess.Popen``.

    Returns:
        Tuple[int, str]: ``(0, "")`` on success; otherwise the non-zero exit
            code together with the formatted traceback of the
            ``CalledProcessError`` raised for it.
    """
    try:
        child = subprocess.Popen(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # Drain the pipe completely before waiting so the child never blocks
        # on a full pipe; the stream is closed when the `with` block exits.
        with child.stdout as merged_output:
            log_subprocess_output(merged_output)

        return_code = child.wait()
        if return_code == 0:
            return return_code, ""
        # Raised only so that format_exc() below has an active exception to render.
        raise subprocess.CalledProcessError(return_code, commands)
    except subprocess.CalledProcessError as failure:
        # Capture the traceback in case of failure
        failure_trace = traceback.format_exc()
        print(f"Command failed with exit code {failure.returncode}. Traceback: {failure_trace}")
        return failure.returncode, failure_trace
def is_worker_node() -> bool:
    """Report whether this node is a worker.

    True when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR`` environment
    variables differ; an unset variable is compared as ``None``.
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_addr = os.environ.get("SM_MASTER_ADDR")
    return current_host != master_addr
def is_master_node() -> bool:
    """Report whether this node is the master.

    True when the ``SM_CURRENT_HOST`` and ``SM_MASTER_ADDR`` environment
    variables are equal; an unset variable is compared as ``None``.
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_addr = os.environ.get("SM_MASTER_ADDR")
    return current_host == master_addr