@@ -69,14 +69,13 @@ def test_save_state_dict_with_one_device(self):
6969 save_state_dict (state_dict , self ._ckpt_path )
7070 check_structure_name_mapping (self ._ckpt_path , state_dict )
7171
72- def test_save_state_dict_with_four_devices (self ):
72+ def test_save_state_dict_with_two_devices (self ):
7373 global_state_dict = get_global_state_dict ()
7474 keys = list (global_state_dict .keys ())
7575 w1 , w2 = list (global_state_dict .values ())
7676 mesh = dist .ProcessMesh ([0 , 1 ])
77- mesh2 = dist .ProcessMesh ([2 , 3 ])
7877 sharded_w1 = dist .shard_tensor (w1 , mesh , [dist .Shard (0 )])
79- sharded_w2 = dist .shard_tensor (w2 , mesh2 , [dist .Shard (0 )])
78+ sharded_w2 = dist .shard_tensor (w2 , mesh , [dist .Shard (0 )])
8079 state_dict = dict (zip (keys , [sharded_w1 , sharded_w2 ]))
8180 save_state_dict (state_dict , self ._ckpt_path )
8281 paddle .distributed .barrier ()
@@ -86,8 +85,8 @@ def run_test_case(self):
8685 device_num = int (os .getenv ("device_num" ))
8786 if device_num == 1 :
8887 self .test_save_state_dict_with_one_device ()
89- elif device_num == 4 :
90- self .test_save_state_dict_with_four_devices ()
88+ elif device_num == 2 :
89+ self .test_save_state_dict_with_two_devices ()
9190
9291
9392class TestSaveShardedStateDict :
@@ -110,20 +109,20 @@ def test_save_state_dict_with_one_device(self):
110109 )
111110 save_state_dict (sharded_state_dict , self ._ckpt_path )
112111
113- def test_save_state_dict_with_four_devices (self ):
112+ def test_save_state_dict_with_two_devices (self ):
114113 if dist .get_rank () == 0 :
115114 # On rank 0:
116115 # The global tensor (4x4) is distributed as:
117- # [[ 0, 1, * , *],
118- # [ 4 , *, *, *],
116+ # [[ 0, 1, 2 , *],
117+ # [ * , *, *, *],
119118 # [ *, *, *, *],
120119 # [ *, *, *, *]]
121- # Numbers 0,1,4 are local, '*' means not present on this rank.
122- local_tensor = paddle .to_tensor ([0 , 1 , 4 ], dtype = 'int32' )
120+ # Numbers 0,1,2 are local, '*' means not present on this rank.
121+ local_tensor = paddle .to_tensor ([0 , 1 , 2 ], dtype = 'int32' )
123122 sharded_weight = ShardedWeight (
124123 key = "t" ,
125124 local_tensor = local_tensor ,
126- local_shape = (4 , 2 ),
125+ local_shape = (4 , 4 ),
127126 global_shape = (4 , 4 ),
128127 global_offset = (0 , 0 ),
129128 is_flattened = True ,
@@ -132,56 +131,22 @@ def test_save_state_dict_with_four_devices(self):
132131 elif dist .get_rank () == 1 :
133132 # On rank 1:
134133 # The global tensor (4x4) is distributed as:
135- # [[ *, *, *, *],
136- # [ *, 5, *, *],
137- # [ 8, 9, *, *],
138- # [ 12, 13, *, *]]
139- # Numbers 5,8,9,12,13 are local, '*' means not present on this rank.
140- local_tensor = paddle .to_tensor ([5 , 8 , 9 , 12 , 13 ], dtype = 'int32' )
141- sharded_weight = ShardedWeight (
142- key = "t" ,
143- local_tensor = local_tensor ,
144- local_shape = (4 , 2 ),
145- global_shape = (4 , 4 ),
146- global_offset = (0 , 0 ),
147- is_flattened = True ,
148- flattened_range = slice (3 , 8 ),
149- )
150- elif dist .get_rank () == 2 :
151- # On rank 2:
152- # The global tensor (4x4) is distributed as:
153- # [[ *, *, 2, 3],
154- # [ *, *, 6, 7],
155- # [ *, *, 10, *],
156- # [ *, *, *, *]]
157- # Numbers 2,3,6,7,10 are local, '*' means not present on this rank.
158- local_tensor = paddle .to_tensor ([2 , 3 , 6 , 7 , 10 ], dtype = 'int32' )
159- sharded_weight = ShardedWeight (
160- key = "t" ,
161- local_tensor = local_tensor ,
162- local_shape = (4 , 2 ),
163- global_shape = (4 , 4 ),
164- global_offset = (0 , 2 ),
165- is_flattened = True ,
166- flattened_range = slice (0 , 5 ),
134+ # [[ *, *, *, 3],
135+ # [ 4, 5, 5, 6],
136+ # [ 8, 9, 10, 11],
137+ # [ 12, 13, 14, 15]]
138+ # Numbers 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 are local, '*' means not present on this rank.
139+ local_tensor = paddle .to_tensor (
140+ [3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 ], dtype = 'int32'
167141 )
168- else :
169- # On rank 3:
170- # The global tensor (4x4) is distributed as:
171- # [[ *, *, *, *],
172- # [ *, *, *, *],
173- # [ *, *, *, 11],
174- # [ *, *, 14, 15]]
175- # Numbers 11,14,15 are local, '*' means not present on this rank.
176- local_tensor = paddle .to_tensor ([11 , 14 , 15 ], dtype = 'int32' )
177142 sharded_weight = ShardedWeight (
178143 key = "t" ,
179144 local_tensor = local_tensor ,
180- local_shape = (4 , 2 ),
145+ local_shape = (4 , 4 ),
181146 global_shape = (4 , 4 ),
182- global_offset = (0 , 2 ),
147+ global_offset = (0 , 0 ),
183148 is_flattened = True ,
184- flattened_range = slice (5 , 8 ),
149+ flattened_range = slice (3 , 16 ),
185150 )
186151
187152 sharded_state_dict = {"t" : sharded_weight }
@@ -192,8 +157,8 @@ def run_test_case(self):
192157 device_num = int (os .getenv ("device_num" ))
193158 if device_num == 1 :
194159 self .test_save_state_dict_with_one_device ()
195- elif device_num == 4 :
196- self .test_save_state_dict_with_four_devices ()
160+ elif device_num == 2 :
161+ self .test_save_state_dict_with_two_devices ()
197162
198163
199164class TestSaveShardedStateDictWithReplica :
@@ -216,7 +181,7 @@ def test_save_state_dict_with_one_device(self):
216181 )
217182 save_state_dict (sharded_state_dict , self ._ckpt_path , save_replicas = True )
218183
219- def test_save_state_dict_with_four_devices (self ):
184+ def test_save_state_dict_with_two_devices (self ):
220185 # Construct a 4x4 integer tensor as expected result:
221186 # [[ 0, 1, 2, 3],
222187 # [ 4, 5, 6, 7],
@@ -237,8 +202,8 @@ def run_test_case(self):
237202 device_num = int (os .getenv ("device_num" ))
238203 if device_num == 1 :
239204 self .test_save_state_dict_with_one_device ()
240- elif device_num == 4 :
241- self .test_save_state_dict_with_four_devices ()
205+ elif device_num == 2 :
206+ self .test_save_state_dict_with_two_devices ()
242207
243208
244209if __name__ == "__main__" :
0 commit comments