@@ -201,10 +201,10 @@ class MultiheadAttention(torch.nn.Module):
201201 parameters and if the DTensor DeviceMesh includes dimensions that do not
202202 shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
203203 For example:
204- - device_mesh["dp"] for FSDP.
205- - device_mesh["dp_cp"] if using CP ranks in FSDP.
206- - device_mesh["tp"] if using TP.
207- - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
204+ - device_mesh["dp"] for FSDP.
205+ - device_mesh["dp_cp"] if using CP ranks in FSDP.
206+ - device_mesh["tp"] if using TP.
207+ - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
208208
209209 Optimization parameters
210210 -----------------------
@@ -641,12 +641,9 @@ def set_device_mesh(
641641 """
642642 if tp_mesh is not None :
643643 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
644- assert (
645- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
646- (
647- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
648- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
649- ),
644+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
645+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
646+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
650647 )
651648 # Set the tensor parallel group from the mesh.
652649 self .set_tensor_parallel_group (tp_mesh .get_group ())
@@ -655,7 +652,7 @@ def set_device_mesh(
655652 # Iterate over direct child sub-modules only (no deep recursion).
656653 # Any child that exposes `set_device_mesh` is handed the same meshes,
657654 # which allows this method to be called at any level of the module tree.
658- for name , child in self .named_children ():
655+ for child in self .children ():
659656 if hasattr (child , "set_device_mesh" ):
660657 child .set_device_mesh (tp_mesh , weight_mesh )
661658
0 commit comments