MLX enables efficient tensor parallelism (TP) through its distributed layers.
In this example, we will start with an overview of the distributed layers in
``mlx.nn`` and then use them to build a small tensor parallel inference script
for Llama-style transformer models.

:class:`mlx.nn.AllToShardedLinear` implements column-wise tensor parallelism.
It replicates a common input and shards the weight matrix along the output
dimension across all devices in the :class:`mlx.core.distributed.Group`. The
layer produces a sharded output.

For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along
the output dimension across the two devices, where each device receives the
full input and computes a partial output.

Check out the `huggingface ultrascale-playbook
<https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block>`_
to learn more about tensor parallelism in LLMs.
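
As a minimal sketch of this example (assuming two processes started with
``mlx.launch -n 2``; the layer dimensions and shapes mirror the description
above), each process passes the full input to the layer and receives its own
shard of the output:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    world = mx.distributed.init()                    # group of 2 devices
    layer = nn.AllToShardedLinear(2, 2, bias=False)

    x = mx.ones((4, 2))   # every device receives the full input
    y = layer(x)          # per-device output shard: (4, 2 // world.size())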

This layer does not automatically gather all outputs from each device. This is
an intended and :ref:`useful design choice <useful_design_choices>`.

:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.

:class:`mlx.nn.ShardedToAllLinear` implements row-wise tensor parallelism. This
layer expects inputs that are sharded along the feature dimension and shards
the weight matrix along the input dimension across all devices in the
:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
results using :func:`mlx.core.distributed.all_sum`, so all devices in the group
will have the same result.

For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along
the input dimension across the two devices. Each device computes a ``(4, 2)``
output, which is then aggregated with the outputs of all other devices to get
the layer output.

This layer does not automatically shard the inputs along the feature dimension
for you. It is necessary to create a "partial" input structure to feed into
the layer. This is an intended and :ref:`useful design choice
<useful_design_choices>`.

We can create partial inputs based on rank. For example, for an input with
1024 features, we can create sharded inputs in the following manner:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    world = mx.distributed.init()
    x = mx.ones((4, 1024))  # example batched input with 1024 features

    part = (
        slice(None),  # batch dimension: keep the full batch on every device
        slice(
            world.rank() * 1024 // world.size(),        # start of this rank's feature shard
            (world.rank() + 1) * 1024 // world.size(),  # end of this rank's feature shard
        ),
    )

    layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
    y = layer(x[part])  # process the sharded input

This code splits the 1024 input features into ``world.size()`` groups which
are assigned contiguously based on ``world.rank()``. More information about
distributed communication can be found in the :ref:`Distributed Communication
<usage_distributed>` page.

:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.

:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` converts a
regular linear layer into a tensor parallel layer that distributes computation
across multiple devices. It takes an existing :class:`mlx.nn.Linear` or
:class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
(either :class:`mlx.nn.AllToShardedLinear` or
:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
original layer is not modified.
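
For instance, a minimal sketch of converting an existing projection. This
assumes the sharding type is selected with the strings ``"all-to-sharded"``
and ``"sharded-to-all"`` and passed a ``group``; check the function's
documentation for the exact argument names:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.layers.distributed import shard_linear

    group = mx.distributed.init()
    proj = nn.Linear(1024, 4096, bias=False)
    # Returns a new AllToShardedLinear holding this rank's shard of the
    # weights; `proj` itself is left unmodified.
    tp_proj = shard_linear(proj, "all-to-sharded", group=group)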

:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` splits the
parameters of an existing layer across multiple devices by modifying the layer
in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`,
this function does not create a new layer or add distributed communication. The
layer itself must handle distributed communication if needed.

.. _useful_design_choices:

Useful Design Choices
---------------------

The design choices above regarding which operations are performed
automatically are intentional and make model training and inference easier.

All-to-sharded and sharded-to-all layers naturally go together because the
output of the former is exactly the input needed by the latter. This removes
the need for an intermediate gather step between the layers, reducing
communication overhead.

This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
automatically. It is so that they can be placed in successive order and work
together easily.

We can demonstrate this through a simple model using our two types of
distributed layers.

.. code-block:: python
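
    # A minimal sketch of such a model (names and dimensions are
    # illustrative): an all-to-sharded layer feeding a sharded-to-all layer,
    # with no gather or re-shard needed in between.
    import mlx.core as mx
    import mlx.nn as nn

    class SimpleModel(nn.Module):
        def __init__(self, dims: int = 1024, hidden_dims: int = 4096):
            super().__init__()
            # Full input in, sharded activation out.
            self.up_proj = nn.AllToShardedLinear(dims, hidden_dims, bias=False)
            # Sharded activation in, aggregated (all_sum) output out.
            self.down_proj = nn.ShardedToAllLinear(hidden_dims, dims, bias=False)

        def __call__(self, x):
            return self.down_proj(self.up_proj(x))

    model = SimpleModel()
    y = model(mx.ones((8, 1024)))  # identical result on every device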

(Figure: a visualization of the simple MLX model using all-to-sharded then
sharded-to-all tensor parallelism across 2 devices.)

LLM Inference with Tensor Parallelism
-------------------------------------

We can apply these TP techniques to LLMs in order to enable inference for much
larger models by sharding parameters from huge layers across multiple devices.

To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
Inference <llama-inference>` example. We will use the same inference script as
that example, which can be found in `mlx-examples`_.

Our first edit is to initialize the distributed communication group and get the
current process rank:

.. code-block:: python
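
    # A sketch of this step (the exact variable names in llama.py may differ):
    # initialize the global group and record this process' rank.
    import mlx.core as mx

    world = mx.distributed.init()
    rank = world.rank()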

Next, let's look at the current architecture of the transformer block and see
how tensor parallelism can be applied.

This architecture has two natural places where tensor parallelism can be
applied: the attention block and the FFN block. Both follow the same pattern:
multiple parallel linear layers operating on the same input, followed by a
single output linear layer. In the attention block, the Q, K, and V projections
are sharded along the output dimension (all-to-sharded), and the output
projection is sharded along the input dimension (sharded-to-all). Similarly, in
the FFN block, the gate and up projections become all-to-sharded layers, and
the down projection becomes a sharded-to-all layer.

The intermediate operations between the linear layers (RoPE, softmax, scaled
dot-product attention in the attention block, and element-wise multiplication
in the FFN block) do not impede the use of our TP paradigm. These operations
are either:

- **Element-wise operations** (RoPE, element-wise multiplication): These
  operate independently on each element or position, preserving the sharding
  pattern without requiring cross-device communication.

- **Operations on non-sharded dimensions** (softmax, scaled dot-product
  attention): These operate along dimensions that are not sharded (such as the
  sequence length or head dimensions), so they can be computed independently on
  each device. The attention computations ``Q @ K^T`` and ``scores @ V`` work
  correctly with sharded Q, K, and V tensors because the sharded feature
  dimension corresponds to whole attention heads: each device computes
  attention only for its own heads, and the results remain properly sharded for
  the subsequent sharded-to-all layer.

To implement sharding in our Llama inference, we use :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
distributed communication. This is easier than using :func:`shard_inplace
<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually
in the :code:`__call__` function.

The following code shows how to shard the Attention block. The Q, K, and V
projection layers are converted to all-to-sharded layers, while the output
projection is converted to a sharded-to-all layer. The head and repeat counts
are also adjusted to account for the sharding.
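
A sketch of what this edit might look like is shown below. The attribute names
(``wq``, ``wk``, ``wv``, ``wo``, ``n_heads``, ``n_kv_heads``) follow the
llama.py example, and the ``"all-to-sharded"``/``"sharded-to-all"`` sharding
strings are assumptions; check the actual example code for the exact details.

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.layers.distributed import shard_linear

    class Attention(nn.Module):
        ...

        def shard(self, group: mx.distributed.Group):
            N = group.size()
            # Q, K, V projections: shard along the output dimension.
            self.wq = shard_linear(self.wq, "all-to-sharded", group=group)
            self.wk = shard_linear(self.wk, "all-to-sharded", group=group)
            self.wv = shard_linear(self.wv, "all-to-sharded", group=group)
            # Output projection: shard along the input dimension and all_sum.
            self.wo = shard_linear(self.wo, "sharded-to-all", group=group)
            # Each device now only computes its own subset of the heads.
            self.n_heads = self.n_heads // N
            self.n_kv_heads = self.n_kv_heads // N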

Similarly, the FeedForward block is sharded by converting the gate (``w1``) and
up (``w3``) projections to all-to-sharded layers, and the down projection
(``w2``) to a sharded-to-all layer:

.. code-block:: python
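
    # Sketch, assuming llama.py's attribute names (w1 = gate, w2 = down,
    # w3 = up) and the sharding strings used in the Attention example above.
    class FeedForward(nn.Module):
        ...

        def shard(self, group: mx.distributed.Group):
            self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
            self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)
            self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)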

Finally, in our :code:`load_model` function, we need to apply our sharding
functions to all transformer layers when using multiple devices:

.. code-block:: python
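
    # Sketch of the surrounding context (the exact guard and loop in the
    # example may differ): shard every transformer block when more than one
    # device is in the group.
    if world.size() > 1:
        for layer in model.layers: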
            layer.attention.shard(group=world)
            layer.feed_forward.shard(group=world)

This allows us to use the llama inference file as normal when running
:code:`python llama.py`, but now we can also run it across two (or more)
devices via :code:`mlx.launch -n 2 llama.py`.