 # - Unified Attention
 # - More dispatcher attention backends
 # - CFG/Data Parallel
-# - Tensor Parallel
 
 
 @dataclass
@@ -142,6 +141,63 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
         self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
 
 
+@dataclass
+class TensorParallelConfig:
+    """
+    Configuration for tensor parallelism.
+
+    Tensor parallelism shards weight matrices (column-wise and row-wise) across devices.
+    Each device computes a partial result; an AllReduce/AllGather at layer boundaries
+    reconstructs the full output. Uses ``torch.distributed.tensor.parallel.parallelize_module``
+    with ``ColwiseParallel`` / ``RowwiseParallel`` sharding styles.
+
+    On Neuron, use the ``_pre_shard_and_tp`` workaround from
+    ``transformer_flux2_neuron_tp`` to avoid the NRT consecutive-reduce-scatter bug
+    on large tensors (>= 5120x5120).
+
+    Args:
+        tp_degree (`int`, defaults to `1`):
+            Number of devices to shard across. Must be a divisor of the number of
+            attention heads (and FFN hidden dimensions) of the model being parallelized.
+        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
+            A custom device mesh to use. If provided, ``tp_degree`` is inferred from
+            ``mesh.size()`` and the argument is ignored. Useful when combining TP with
+            other parallelism strategies (e.g. CP) that share the same mesh.
+    """
+
+    tp_degree: int = 1
+    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
+
+    _rank: int = None
+    _world_size: int = None
+    _device: torch.device = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+    def __post_init__(self):
+        if self.tp_degree < 1:
+            raise ValueError("`tp_degree` must be >= 1.")
+
+    def setup(
+        self,
+        rank: int,
+        world_size: int,
+        device: torch.device,
+        mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+    ):
+        self._rank = rank
+        self._world_size = world_size
+        self._device = device
+        if mesh is not None:
+            self._mesh = mesh
+        elif self.mesh is not None:
+            self._mesh = self.mesh
+        else:
+            from torch.distributed.device_mesh import init_device_mesh
+
+            device_type = str(device).split(":")[0]
+            self._mesh = init_device_mesh(device_type, (self.tp_degree,), mesh_dim_names=("tp",))
+
+
 @dataclass
 class ParallelConfig:
     """
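Not part of the diff, but for context: a minimal sketch of the `torch.distributed.tensor.parallel` API that the `TensorParallelConfig` docstring references, applied to a toy FFN. The module, tensor sizes, and world size of 2 are illustrative, not taken from this PR:

```python
# Minimal sketch (not this PR's wiring); run under `torchrun --nproc_per_node=2`.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))  # same 1-D "tp" mesh the config builds

# Toy FFN: shard the up-projection column-wise and the down-projection row-wise,
# so the intermediate activation stays sharded and a single all-reduce at the
# row-wise layer reconstructs the full output.
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)).cuda()
parallelize_module(
    model,
    mesh,
    {"0": ColwiseParallel(), "2": RowwiseParallel()},  # keys are module names inside the Sequential
)

out = model(torch.randn(4, 512, device="cuda"))  # full-size output on every rank
```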
@@ -150,9 +206,12 @@ class ParallelConfig:
     Args:
         context_parallel_config (`ContextParallelConfig`, *optional*):
             Configuration for context parallelism.
+        tensor_parallel_config (`TensorParallelConfig`, *optional*):
+            Configuration for tensor parallelism.
     """
 
     context_parallel_config: ContextParallelConfig | None = None
+    tensor_parallel_config: TensorParallelConfig | None = None
 
     _rank: int = None
     _world_size: int = None
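The `mesh` argument exists so TP can share one global mesh with other strategies. A sketch of that pattern, assuming an 8-GPU 2x4 layout (the dim names and sizes are illustrative, and how the CP side consumes its slice is not shown in this hunk):

```python
from torch.distributed.device_mesh import init_device_mesh

# One global 2x4 mesh; 1-D sub-meshes can be sliced out of it by dim name.
world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cp", "tp"))

# tp_degree is taken from the mesh (world_mesh["tp"].size() == 4 here), per the docstring.
tp_config = TensorParallelConfig(mesh=world_mesh["tp"])
```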
@@ -173,6 +232,8 @@ def setup(
         self._mesh = mesh
         if self.context_parallel_config is not None:
             self.context_parallel_config.setup(rank, world_size, device, mesh)
+        if self.tensor_parallel_config is not None:
+            self.tensor_parallel_config.setup(rank, world_size, device, mesh)
 
 
 @dataclass(frozen=True)
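End to end, the delegation added above means a caller only touches the top-level config. A hedged sketch of that flow (the real call site for `setup()` is framework-internal; the process-group bootstrap here assumes a plain `torchrun` launch):

```python
import torch
import torch.distributed as dist

dist.init_process_group("nccl")  # under torchrun, env:// rendezvous
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

config = ParallelConfig(tensor_parallel_config=TensorParallelConfig(tp_degree=world_size))
config.setup(rank, world_size, device, mesh=None)  # fans out to tensor_parallel_config.setup()
```

With `mesh=None`, `TensorParallelConfig.setup` falls back to building its own 1-D `"tp"` mesh from `tp_degree`, as in the hunk above.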