Skip to content

Commit 0747a9d

Browse files
authored
[cherry-pick] fix fc unflatten miss key (#78818)
* fix unflatten miss key * delete useless keys
1 parent 62fa100 commit 0747a9d

3 files changed

Lines changed: 67 additions & 82 deletions

File tree

python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -999,12 +999,26 @@ def restore_unflattened_state_dict(
999999
unflattened_local_shape
10001000
)
10011001
direct_reshape_metas[key] = unflattened_meta
1002+
if (
1003+
len(unflattened_local_shape) >= 2
1004+
and unflattened_local_shape[-1] == numel_in_slice
1005+
):
1006+
reshard_needed_tensors[key] = local_tensor.reshape(
1007+
(numel_in_slice,)
1008+
)
1009+
reshard_target_infos[key] = (
1010+
numel_in_slice,
1011+
slices,
1012+
unflattened_meta,
1013+
False,
1014+
)
10021015
else:
10031016
reshard_needed_tensors[key] = local_tensor
10041017
reshard_target_infos[key] = (
10051018
numel_in_slice,
10061019
slices,
10071020
unflattened_meta,
1021+
True,
10081022
)
10091023

10101024
resharded_tensors = {}
@@ -1019,7 +1033,9 @@ def restore_unflattened_state_dict(
10191033
for key, local_tensor in reshard_needed_tensors.items():
10201034
tensor_name, file_name = key
10211035
meta = _metadata_manager.local_tensor_metadata[key]
1022-
numel, slices, unflattened_meta = reshard_target_infos[key]
1036+
numel, slices, unflattened_meta, need_resharding = reshard_target_infos[
1037+
key
1038+
]
10231039
tensor_name_expand = f"{tensor_name}.global_offset.{meta.global_offset}"
10241040

10251041
flat_start, flat_end = meta.flattened_range
@@ -1051,18 +1067,18 @@ def restore_unflattened_state_dict(
10511067
global_offset_1d = (
10521068
ravel_index(tuple(s[0] for s in slices), meta.local_shape),
10531069
)
1054-
1055-
destination_sharded_state_dict[
1056-
(tensor_name_expand, global_offset_1d)
1057-
] = ShardedWeight(
1058-
key=tensor_name_expand,
1059-
local_tensor=tmp_target_tensor,
1060-
local_shape=(numel,),
1061-
global_shape=(math.prod(meta.local_shape),),
1062-
global_offset=global_offset_1d,
1063-
)
1064-
name_mapping[key] = (tensor_name_expand, global_offset_1d)
1065-
force_gc.append(local_tensor)
1070+
if need_resharding:
1071+
destination_sharded_state_dict[
1072+
(tensor_name_expand, global_offset_1d)
1073+
] = ShardedWeight(
1074+
key=tensor_name_expand,
1075+
local_tensor=tmp_target_tensor,
1076+
local_shape=(numel,),
1077+
global_shape=(math.prod(meta.local_shape),),
1078+
global_offset=global_offset_1d,
1079+
)
1080+
name_mapping[key] = (tensor_name_expand, global_offset_1d)
1081+
force_gc.append(local_tensor)
10661082

10671083
global_state_dict_metadata, global_storage_metadata = [], []
10681084
if use_dist:
@@ -1083,6 +1099,7 @@ def restore_unflattened_state_dict(
10831099
tmp_metadata.storage_metadata = {
10841100
k: v for d in global_storage_metadata for k, v in d.items()
10851101
}
1102+
10861103
_load_state_dict(
10871104
target_state_dict=destination_sharded_state_dict,
10881105
source_state_dict=source_state_dict_for_reshard,
@@ -1093,12 +1110,15 @@ def restore_unflattened_state_dict(
10931110
)
10941111

10951112
for key in reshard_needed_tensors:
1096-
target_key = name_mapping[key]
1097-
unflattened_meta = reshard_target_infos[key][2]
1098-
1099-
final_tensor = destination_sharded_state_dict[target_key].local_tensor
1100-
final_tensor.reshape_(unflattened_meta.local_shape)
1101-
resharded_tensors[key] = final_tensor
1113+
need_resharding = reshard_target_infos[key][3]
1114+
if need_resharding:
1115+
target_key = name_mapping[key]
1116+
unflattened_meta = reshard_target_infos[key][2]
1117+
final_tensor = destination_sharded_state_dict[
1118+
target_key
1119+
].local_tensor
1120+
final_tensor.reshape_(unflattened_meta.local_shape)
1121+
resharded_tensors[key] = final_tensor
11021122

11031123
final_unflattened_state_dict = defaultdict(dict)
11041124
final_local_tensor_meta = defaultdict(list)

test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py

Lines changed: 25 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,13 @@ def test_save_state_dict_with_one_device(self):
6969
save_state_dict(state_dict, self._ckpt_path)
7070
check_structure_name_mapping(self._ckpt_path, state_dict)
7171

72-
def test_save_state_dict_with_four_devices(self):
72+
def test_save_state_dict_with_two_devices(self):
7373
global_state_dict = get_global_state_dict()
7474
keys = list(global_state_dict.keys())
7575
w1, w2 = list(global_state_dict.values())
7676
mesh = dist.ProcessMesh([0, 1])
77-
mesh2 = dist.ProcessMesh([2, 3])
7877
sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
79-
sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0)])
78+
sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)])
8079
state_dict = dict(zip(keys, [sharded_w1, sharded_w2]))
8180
save_state_dict(state_dict, self._ckpt_path)
8281
paddle.distributed.barrier()
@@ -86,8 +85,8 @@ def run_test_case(self):
8685
device_num = int(os.getenv("device_num"))
8786
if device_num == 1:
8887
self.test_save_state_dict_with_one_device()
89-
elif device_num == 4:
90-
self.test_save_state_dict_with_four_devices()
88+
elif device_num == 2:
89+
self.test_save_state_dict_with_two_devices()
9190

9291

9392
class TestSaveShardedStateDict:
@@ -110,20 +109,20 @@ def test_save_state_dict_with_one_device(self):
110109
)
111110
save_state_dict(sharded_state_dict, self._ckpt_path)
112111

113-
def test_save_state_dict_with_four_devices(self):
112+
def test_save_state_dict_with_two_devices(self):
114113
if dist.get_rank() == 0:
115114
# On rank 0:
116115
# The global tensor (4x4) is distributed as:
117-
# [[ 0, 1, *, *],
118-
# [ 4, *, *, *],
116+
# [[ 0, 1, 2, *],
117+
# [ *, *, *, *],
119118
# [ *, *, *, *],
120119
# [ *, *, *, *]]
121-
# Numbers 0,1,4 are local, '*' means not present on this rank.
122-
local_tensor = paddle.to_tensor([0, 1, 4], dtype='int32')
120+
# Numbers 0,1,2 are local, '*' means not present on this rank.
121+
local_tensor = paddle.to_tensor([0, 1, 2], dtype='int32')
123122
sharded_weight = ShardedWeight(
124123
key="t",
125124
local_tensor=local_tensor,
126-
local_shape=(4, 2),
125+
local_shape=(4, 4),
127126
global_shape=(4, 4),
128127
global_offset=(0, 0),
129128
is_flattened=True,
@@ -132,56 +131,22 @@ def test_save_state_dict_with_four_devices(self):
132131
elif dist.get_rank() == 1:
133132
# On rank 1:
134133
# The global tensor (4x4) is distributed as:
135-
# [[ *, *, *, *],
136-
# [ *, 5, *, *],
137-
# [ 8, 9, *, *],
138-
# [ 12, 13, *, *]]
139-
# Numbers 5,8,9,12,13 are local, '*' means not present on this rank.
140-
local_tensor = paddle.to_tensor([5, 8, 9, 12, 13], dtype='int32')
141-
sharded_weight = ShardedWeight(
142-
key="t",
143-
local_tensor=local_tensor,
144-
local_shape=(4, 2),
145-
global_shape=(4, 4),
146-
global_offset=(0, 0),
147-
is_flattened=True,
148-
flattened_range=slice(3, 8),
149-
)
150-
elif dist.get_rank() == 2:
151-
# On rank 2:
152-
# The global tensor (4x4) is distributed as:
153-
# [[ *, *, 2, 3],
154-
# [ *, *, 6, 7],
155-
# [ *, *, 10, *],
156-
# [ *, *, *, *]]
157-
# Numbers 2,3,6,7,10 are local, '*' means not present on this rank.
158-
local_tensor = paddle.to_tensor([2, 3, 6, 7, 10], dtype='int32')
159-
sharded_weight = ShardedWeight(
160-
key="t",
161-
local_tensor=local_tensor,
162-
local_shape=(4, 2),
163-
global_shape=(4, 4),
164-
global_offset=(0, 2),
165-
is_flattened=True,
166-
flattened_range=slice(0, 5),
134+
# [[ *, *, *, 3],
135+
# [ 4, 5, 6, 7],
136+
# [ 8, 9, 10, 11],
137+
# [ 12, 13, 14, 15]]
138+
# Numbers 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 are local, '*' means not present on this rank.
139+
local_tensor = paddle.to_tensor(
140+
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype='int32'
167141
)
168-
else:
169-
# On rank 3:
170-
# The global tensor (4x4) is distributed as:
171-
# [[ *, *, *, *],
172-
# [ *, *, *, *],
173-
# [ *, *, *, 11],
174-
# [ *, *, 14, 15]]
175-
# Numbers 11,14,15 are local, '*' means not present on this rank.
176-
local_tensor = paddle.to_tensor([11, 14, 15], dtype='int32')
177142
sharded_weight = ShardedWeight(
178143
key="t",
179144
local_tensor=local_tensor,
180-
local_shape=(4, 2),
145+
local_shape=(4, 4),
181146
global_shape=(4, 4),
182-
global_offset=(0, 2),
147+
global_offset=(0, 0),
183148
is_flattened=True,
184-
flattened_range=slice(5, 8),
149+
flattened_range=slice(3, 16),
185150
)
186151

187152
sharded_state_dict = {"t": sharded_weight}
@@ -192,8 +157,8 @@ def run_test_case(self):
192157
device_num = int(os.getenv("device_num"))
193158
if device_num == 1:
194159
self.test_save_state_dict_with_one_device()
195-
elif device_num == 4:
196-
self.test_save_state_dict_with_four_devices()
160+
elif device_num == 2:
161+
self.test_save_state_dict_with_two_devices()
197162

198163

199164
class TestSaveShardedStateDictWithReplica:
@@ -216,7 +181,7 @@ def test_save_state_dict_with_one_device(self):
216181
)
217182
save_state_dict(sharded_state_dict, self._ckpt_path, save_replicas=True)
218183

219-
def test_save_state_dict_with_four_devices(self):
184+
def test_save_state_dict_with_two_devices(self):
220185
# Construct a 4x4 integer tensor as expected result:
221186
# [[ 0, 1, 2, 3],
222187
# [ 4, 5, 6, 7],
@@ -237,8 +202,8 @@ def run_test_case(self):
237202
device_num = int(os.getenv("device_num"))
238203
if device_num == 1:
239204
self.test_save_state_dict_with_one_device()
240-
elif device_num == 4:
241-
self.test_save_state_dict_with_four_devices()
205+
elif device_num == 2:
206+
self.test_save_state_dict_with_two_devices()
242207

243208

244209
if __name__ == "__main__":

test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ def test_reshard(self):
6464
ckpt_path_2.cleanup()
6565
ckpt_path_3.cleanup()
6666

67-
# save with 4 devices
67+
# save with 2 devices
6868
ckpt_path = tempfile.TemporaryDirectory()
6969
ckpt_path_2 = tempfile.TemporaryDirectory()
7070
ckpt_path_3 = tempfile.TemporaryDirectory()
71-
super().setUp(num_of_devices=4, timeout=120, nnode=1)
71+
super().setUp(num_of_devices=2, timeout=120, nnode=1)
7272
self.run_test_case(
7373
"semi_auto_save_state_dict.py",
7474
user_defined_envs={
75-
"device_num": "4",
75+
"device_num": "2",
7676
"ckpt_path": ckpt_path.name,
7777
"ckpt_path_2": ckpt_path_2.name,
7878
"ckpt_path_3": ckpt_path_3.name,

0 commit comments

Comments
 (0)