forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdistributed.py
More file actions
181 lines (146 loc) · 6.66 KB
/
distributed.py
File metadata and controls
181 lines (146 loc) · 6.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Distributed module."""
from __future__ import absolute_import
import os
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
from sagemaker.modules.utils import safe_serialize
from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH
from sagemaker.modules.configs import BaseConfig
class SMP(BaseConfig):
    """SMP.

    This class is used for configuring the SageMaker Model Parallelism v2 parameters.
    For more information on the model parallelism parameters, see:
    https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config

    Every field defaults to ``None``; fields left as ``None`` are omitted from the
    hyperparameters produced by ``_to_mp_hyperparameters``.

    Parameters:
        hybrid_shard_degree (Optional[int]):
            Specifies a sharded parallelism degree for the model.
        sm_activation_offloading (Optional[bool]):
            Specifies whether to enable the SMP activation offloading implementation.
        activation_loading_horizon (Optional[int]):
            An integer specifying the activation offloading horizon type for FSDP. This is the
            maximum number of checkpointed or offloaded layers whose inputs can be in the GPU
            memory simultaneously.
        fsdp_cache_flush_warnings (Optional[bool]):
            Detects and warns if cache flushes happen in the PyTorch memory manager, because they
            can degrade computational performance.
        allow_empty_shards (Optional[bool]):
            Whether to allow empty shards when sharding tensors if the tensor is not divisible.
            This is an experimental fix for a crash during checkpointing in certain scenarios.
            Disabling this falls back to the original PyTorch behavior.
        tensor_parallel_degree (Optional[int]):
            Specifies a tensor parallelism degree. The value must be between 1 and world_size.
        context_parallel_degree (Optional[int]):
            Specifies the context parallelism degree. The value must be between 1 and world_size,
            and must be <= hybrid_shard_degree.
        expert_parallel_degree (Optional[int]):
            Specifies an expert parallelism degree. The value must be between 1 and world_size.
        random_seed (Optional[int]):
            A seed number for the random operations in distributed modules by SMP tensor
            parallelism or expert parallelism.
    """

    hybrid_shard_degree: Optional[int] = None
    sm_activation_offloading: Optional[bool] = None
    activation_loading_horizon: Optional[int] = None
    fsdp_cache_flush_warnings: Optional[bool] = None
    allow_empty_shards: Optional[bool] = None
    tensor_parallel_degree: Optional[int] = None
    context_parallel_degree: Optional[int] = None
    expert_parallel_degree: Optional[int] = None
    random_seed: Optional[int] = None

    def _to_mp_hyperparameters(self) -> Dict[str, Any]:
        """Converts to the hyperparameters format for the SageMaker Model Parallelism v2.

        Returns:
            Dict[str, Any]: A single-entry dict whose ``"mp_parameters"`` key holds the
            ``safe_serialize``-d form of every field that is not ``None``.
        """
        # exclude_none drops unset options so only explicitly configured values are sent.
        mp_parameters = self.model_dump(exclude_none=True)
        hyperparameters = {
            "mp_parameters": safe_serialize(mp_parameters),
        }
        return hyperparameters
class DistributedConfig(BaseConfig, ABC):
    """Abstract base for every distributed-training configuration.

    Concrete subclasses describe which driver script launches the distributed
    job and where that script lives. Both pieces of information are exposed
    through the two abstract properties below, which every subclass must
    implement.
    """

    @property
    @abstractmethod
    def driver_script(self) -> str:
        """File name of the Python driver script.

        Implementations return the name of the script that carries the
        distributed-training launch logic.

        Returns:
            str: Name of the driver script file
        """
        ...

    @property
    @abstractmethod
    def driver_dir(self) -> str:
        """Directory that holds the driver script.

        Implementations return the path of the directory in which the file
        named by ``driver_script`` is located.

        Returns:
            str: Path to directory containing the driver script
        """
        ...
class Torchrun(DistributedConfig):
    """Torchrun.

    The Torchrun class configures a job that uses ``torchrun`` or
    ``torch.distributed.launch`` in the backend to launch distributed training.

    Parameters:
        process_count_per_node (Optional[int]):
            The number of processes to run on each node in the training job.
            Will default to the number of GPUs available in the container.
        smp (Optional[SMP]):
            The SageMaker Model Parallelism v2 parameters.
    """

    process_count_per_node: Optional[int] = None
    smp: Optional["SMP"] = None

    @property
    def driver_dir(self) -> str:
        """Directory containing the driver script.

        Returns:
            str: ``SM_DRIVERS_LOCAL_PATH`` joined with ``"distributed_drivers"``.
        """
        return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")

    @property
    def driver_script(self) -> str:
        """Name of the driver script.

        Returns:
            str: Always ``"torchrun_driver.py"``.
        """
        return "torchrun_driver.py"
class MPI(DistributedConfig):
    """MPI.

    The MPI class configures a job that uses ``mpirun`` in the backend to launch
    distributed training.

    Parameters:
        process_count_per_node (Optional[int]):
            The number of processes to run on each node in the training job.
            Will default to the number of GPUs available in the container.
        mpi_additional_options (Optional[List[str]]):
            A list of custom MPI options to use for the training job.
    """

    process_count_per_node: Optional[int] = None
    mpi_additional_options: Optional[List[str]] = None

    @property
    def driver_dir(self) -> str:
        """Directory containing the driver script.

        Returns:
            str: ``SM_DRIVERS_LOCAL_PATH`` joined with ``"distributed_drivers"``.
        """
        return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")

    @property
    def driver_script(self) -> str:
        """Name of the driver script.

        Returns:
            str: Always ``"mpi_driver.py"``.
        """
        return "mpi_driver.py"