Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,12 +999,26 @@ def restore_unflattened_state_dict(
unflattened_local_shape
)
direct_reshape_metas[key] = unflattened_meta
if (
len(unflattened_local_shape) >= 2
and unflattened_local_shape[-1] == numel_in_slice
):
reshard_needed_tensors[key] = local_tensor.reshape(
(numel_in_slice,)
)
reshard_target_infos[key] = (
numel_in_slice,
slices,
unflattened_meta,
False,
)
else:
reshard_needed_tensors[key] = local_tensor
reshard_target_infos[key] = (
numel_in_slice,
slices,
unflattened_meta,
True,
)

resharded_tensors = {}
Expand All @@ -1019,7 +1033,9 @@ def restore_unflattened_state_dict(
for key, local_tensor in reshard_needed_tensors.items():
tensor_name, file_name = key
meta = _metadata_manager.local_tensor_metadata[key]
numel, slices, unflattened_meta = reshard_target_infos[key]
numel, slices, unflattened_meta, need_resharding = reshard_target_infos[
key
]
tensor_name_expand = f"{tensor_name}.global_offset.{meta.global_offset}"

flat_start, flat_end = meta.flattened_range
Expand Down Expand Up @@ -1051,18 +1067,18 @@ def restore_unflattened_state_dict(
global_offset_1d = (
ravel_index(tuple(s[0] for s in slices), meta.local_shape),
)

destination_sharded_state_dict[
(tensor_name_expand, global_offset_1d)
] = ShardedWeight(
key=tensor_name_expand,
local_tensor=tmp_target_tensor,
local_shape=(numel,),
global_shape=(math.prod(meta.local_shape),),
global_offset=global_offset_1d,
)
name_mapping[key] = (tensor_name_expand, global_offset_1d)
force_gc.append(local_tensor)
if need_resharding:
destination_sharded_state_dict[
(tensor_name_expand, global_offset_1d)
] = ShardedWeight(
key=tensor_name_expand,
local_tensor=tmp_target_tensor,
local_shape=(numel,),
global_shape=(math.prod(meta.local_shape),),
global_offset=global_offset_1d,
)
name_mapping[key] = (tensor_name_expand, global_offset_1d)
force_gc.append(local_tensor)

global_state_dict_metadata, global_storage_metadata = [], []
if use_dist:
Expand All @@ -1083,6 +1099,7 @@ def restore_unflattened_state_dict(
tmp_metadata.storage_metadata = {
k: v for d in global_storage_metadata for k, v in d.items()
}

_load_state_dict(
target_state_dict=destination_sharded_state_dict,
source_state_dict=source_state_dict_for_reshard,
Expand All @@ -1093,12 +1110,15 @@ def restore_unflattened_state_dict(
)

for key in reshard_needed_tensors:
target_key = name_mapping[key]
unflattened_meta = reshard_target_infos[key][2]

final_tensor = destination_sharded_state_dict[target_key].local_tensor
final_tensor.reshape_(unflattened_meta.local_shape)
resharded_tensors[key] = final_tensor
need_resharding = reshard_target_infos[key][3]
if need_resharding:
target_key = name_mapping[key]
unflattened_meta = reshard_target_infos[key][2]
final_tensor = destination_sharded_state_dict[
target_key
].local_tensor
final_tensor.reshape_(unflattened_meta.local_shape)
resharded_tensors[key] = final_tensor

final_unflattened_state_dict = defaultdict(dict)
final_local_tensor_meta = defaultdict(list)
Expand Down
85 changes: 25 additions & 60 deletions test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,13 @@ def test_save_state_dict_with_one_device(self):
save_state_dict(state_dict, self._ckpt_path)
check_structure_name_mapping(self._ckpt_path, state_dict)

def test_save_state_dict_with_four_devices(self):
def test_save_state_dict_with_two_devices(self):
global_state_dict = get_global_state_dict()
keys = list(global_state_dict.keys())
w1, w2 = list(global_state_dict.values())
mesh = dist.ProcessMesh([0, 1])
mesh2 = dist.ProcessMesh([2, 3])
sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0)])
sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)])
state_dict = dict(zip(keys, [sharded_w1, sharded_w2]))
save_state_dict(state_dict, self._ckpt_path)
paddle.distributed.barrier()
Expand All @@ -86,8 +85,8 @@ def run_test_case(self):
device_num = int(os.getenv("device_num"))
if device_num == 1:
self.test_save_state_dict_with_one_device()
elif device_num == 4:
self.test_save_state_dict_with_four_devices()
elif device_num == 2:
self.test_save_state_dict_with_two_devices()


class TestSaveShardedStateDict:
Expand All @@ -110,20 +109,20 @@ def test_save_state_dict_with_one_device(self):
)
save_state_dict(sharded_state_dict, self._ckpt_path)

def test_save_state_dict_with_four_devices(self):
def test_save_state_dict_with_two_devices(self):
if dist.get_rank() == 0:
# On rank 0:
# The global tensor (4x4) is distributed as:
# [[ 0, 1, *, *],
# [ 4, *, *, *],
# [[ 0, 1, 2, *],
# [ *, *, *, *],
# [ *, *, *, *],
# [ *, *, *, *]]
# Numbers 0,1,4 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor([0, 1, 4], dtype='int32')
# Numbers 0,1,2 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor([0, 1, 2], dtype='int32')
sharded_weight = ShardedWeight(
key="t",
local_tensor=local_tensor,
local_shape=(4, 2),
local_shape=(4, 4),
global_shape=(4, 4),
global_offset=(0, 0),
is_flattened=True,
Expand All @@ -132,56 +131,22 @@ def test_save_state_dict_with_four_devices(self):
elif dist.get_rank() == 1:
# On rank 1:
# The global tensor (4x4) is distributed as:
# [[ *, *, *, *],
# [ *, 5, *, *],
# [ 8, 9, *, *],
# [ 12, 13, *, *]]
# Numbers 5,8,9,12,13 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor([5, 8, 9, 12, 13], dtype='int32')
sharded_weight = ShardedWeight(
key="t",
local_tensor=local_tensor,
local_shape=(4, 2),
global_shape=(4, 4),
global_offset=(0, 0),
is_flattened=True,
flattened_range=slice(3, 8),
)
elif dist.get_rank() == 2:
# On rank 2:
# The global tensor (4x4) is distributed as:
# [[ *, *, 2, 3],
# [ *, *, 6, 7],
# [ *, *, 10, *],
# [ *, *, *, *]]
# Numbers 2,3,6,7,10 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor([2, 3, 6, 7, 10], dtype='int32')
sharded_weight = ShardedWeight(
key="t",
local_tensor=local_tensor,
local_shape=(4, 2),
global_shape=(4, 4),
global_offset=(0, 2),
is_flattened=True,
flattened_range=slice(0, 5),
# [[ *, *, *, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [ 12, 13, 14, 15]]
# Numbers 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor(
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype='int32'
)
else:
# On rank 3:
# The global tensor (4x4) is distributed as:
# [[ *, *, *, *],
# [ *, *, *, *],
# [ *, *, *, 11],
# [ *, *, 14, 15]]
# Numbers 11,14,15 are local, '*' means not present on this rank.
local_tensor = paddle.to_tensor([11, 14, 15], dtype='int32')
sharded_weight = ShardedWeight(
key="t",
local_tensor=local_tensor,
local_shape=(4, 2),
local_shape=(4, 4),
global_shape=(4, 4),
global_offset=(0, 2),
global_offset=(0, 0),
is_flattened=True,
flattened_range=slice(5, 8),
flattened_range=slice(3, 16),
)

sharded_state_dict = {"t": sharded_weight}
Expand All @@ -192,8 +157,8 @@ def run_test_case(self):
device_num = int(os.getenv("device_num"))
if device_num == 1:
self.test_save_state_dict_with_one_device()
elif device_num == 4:
self.test_save_state_dict_with_four_devices()
elif device_num == 2:
self.test_save_state_dict_with_two_devices()


class TestSaveShardedStateDictWithReplica:
Expand All @@ -216,7 +181,7 @@ def test_save_state_dict_with_one_device(self):
)
save_state_dict(sharded_state_dict, self._ckpt_path, save_replicas=True)

def test_save_state_dict_with_four_devices(self):
def test_save_state_dict_with_two_devices(self):
# Construct a 4x4 integer tensor as expected result:
# [[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
Expand All @@ -237,8 +202,8 @@ def run_test_case(self):
device_num = int(os.getenv("device_num"))
if device_num == 1:
self.test_save_state_dict_with_one_device()
elif device_num == 4:
self.test_save_state_dict_with_four_devices()
elif device_num == 2:
self.test_save_state_dict_with_two_devices()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def test_reshard(self):
ckpt_path_2.cleanup()
ckpt_path_3.cleanup()

# save with 4 devices
# save with 2 devices
ckpt_path = tempfile.TemporaryDirectory()
ckpt_path_2 = tempfile.TemporaryDirectory()
ckpt_path_3 = tempfile.TemporaryDirectory()
super().setUp(num_of_devices=4, timeout=120, nnode=1)
super().setUp(num_of_devices=2, timeout=120, nnode=1)
self.run_test_case(
"semi_auto_save_state_dict.py",
user_defined_envs={
"device_num": "4",
"device_num": "2",
"ckpt_path": ckpt_path.name,
"ckpt_path_2": ckpt_path_2.name,
"ckpt_path_3": ckpt_path_3.name,
Expand Down
Loading