Skip to content

Commit 89bf0d2

Browse files
nathon-lee, Copilot, and sfc-gh-truwase
authored
Refactor consolidate transpose (#7934)
refactor(module_inject): consolidate duplicate transpose functions

- Extract the duplicated `transpose` function into `deepspeed/module_inject/utils.py`.
- Remove redundant `transpose` definitions from `policy.py` and `load_checkpoint.py`.
- This resolves an existing `TODO (lekurile)` to consolidate the function across containers.

---------

Signed-off-by: nathon-lee <leejianwoo@gmail.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
1 parent 607b55f commit 89bf0d2

3 files changed

Lines changed: 12 additions & 18 deletions

File tree

deepspeed/module_inject/load_checkpoint.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import gc
1818
from deepspeed.accelerator import get_accelerator
1919
import re
20+
from .utils import transpose
2021

2122

2223
def load_model_with_checkpoint(r_module,
@@ -42,14 +43,6 @@ def prefix_check():
4243

4344
skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix
4445

45-
def transpose(data):
46-
with torch.no_grad():
47-
data = data.contiguous()
48-
data1 = data.transpose(-1, -2).reshape(-1)
49-
data.reshape(-1).copy_(data1)
50-
data1 = None
51-
return data.reshape(data.shape[-1], data.shape[-2])
52-
5346
def load(module, prefix):
5447
args = (sd[0], prefix, {}, True, [], [], error_msgs)
5548

deepspeed/module_inject/policy.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from deepspeed.utils.types import ActivationFuncType, NormType
88
import torch
99
from deepspeed.accelerator import get_accelerator
10+
from .utils import transpose
1011

1112
transformer_param_names = (
1213
'attn_qkvw', \
@@ -109,16 +110,6 @@ def layernorm(self):
109110
raise NotImplementedError
110111

111112

112-
# TODO (lekurile): This function exists in base container as well, consolidate as some point
113-
def transpose(data):
114-
with torch.no_grad():
115-
data = data.contiguous()
116-
data1 = data.transpose(-1, -2).reshape(-1)
117-
data.reshape(-1).copy_(data1)
118-
data1 = None
119-
return data.reshape(data.shape[-1], data.shape[-2])
120-
121-
122113
# TODO (lekurile): This function exists in megatron feature container as well, consolidate as some point
123114
def _transpose(x, heads=1, mp_replace=None):
124115
heads = heads // mp_replace.mp_size # type: ignore

deepspeed/module_inject/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,19 @@
33

44
# DeepSpeed Team
55

6+
import torch
67
from deepspeed.utils import log_dist
78

89

10+
def transpose(data):
11+
with torch.no_grad():
12+
data = data.contiguous()
13+
data1 = data.transpose(-1, -2).reshape(-1)
14+
data.reshape(-1).copy_(data1)
15+
data1 = None
16+
return data.reshape(data.shape[-1], data.shape[-2])
17+
18+
919
# helper function to map between DS policies and DS containers
1020
def policy_to_ds_container(**kwargs):
1121
from .containers import HFGPT2LayerPolicy, DS_GPT2Container

0 commit comments

Comments (0)