MLX enables efficient tensor parallelism (TP) through its implementation of distributed layers. In this example, we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.
:class:`mlx.nn.AllToShardedLinear` implements all-to-sharded (column-wise) tensor parallelism. This layer replicates a common input and shards the weight matrix along the output dimension across all devices in the :class:`mlx.core.distributed.Group`. The layer produces a sharded output.
For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix along the output dimension across the two devices, where each device receives the full input and computes a partial output.
.. note::

   Check out the `huggingface ultrascale-playbook
   <https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block>`_
   to learn more about tensor parallelism in LLMs.
This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice <useful_design_choices>`.
:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be included in any gradient computation.
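The arithmetic behind this layer can be checked with a small sketch in plain numpy (a stand-in for mlx, with two shard variables playing the role of two devices), mirroring the example above:

.. code-block:: python

   import numpy as np

   rng = np.random.default_rng(0)
   x = rng.standard_normal((4, 2))  # batched input, replicated on every device
   W = rng.standard_normal((2, 2))  # full weight: input_dims=2, output_dims=2

   # Shard the weight along the output dimension across 2 "devices".
   W0, W1 = np.split(W, 2, axis=1)  # each shard has shape (2, 1)

   # Every device sees the full input and computes a partial output.
   partial0 = x @ W0                # shape (4, 1) on device 0
   partial1 = x @ W1                # shape (4, 1) on device 1

   # Concatenating the partial (sharded) outputs recovers the full layer output.
   assert np.allclose(np.concatenate([partial0, partial1], axis=1), x @ W)
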
:class:`mlx.nn.ShardedToAllLinear` implements sharded-to-all (row-wise) tensor parallelism. This layer expects inputs that are sharded along the feature dimension and shards the weight matrix along the input dimension across all devices in the :class:`mlx.core.distributed.Group`. The layer automatically aggregates the results using :func:`mlx.core.distributed.all_sum`, so all devices in the group will have the same result.
For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix along the input dimension and the input along the feature dimension across the two devices. Each device computes a ``(4, 2)`` output, which is then aggregated with all other device outputs to get the layer output.
This layer does not automatically shard the inputs along the feature dimension for you. It is necessary to create a "partial" input structure to feed into the layer. This is an intended and :ref:`useful design choice <useful_design_choices>`.

We can create partial inputs based on rank. For example, for an input ``x`` with 1024 features, we can create sharded inputs in the following manner:

.. code-block:: python

   world = mx.distributed.init()
   part = (
       slice(None),  # batch dimension: keep the full batch
       slice(
           world.rank() * 1024 // world.size(),        # start
           (world.rank() + 1) * 1024 // world.size(),  # end
       ),
   )

   layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
   y = layer(x[part])  # process the sharded input

This code splits the 1024 input features into ``world.size()`` different groups which are assigned contiguously based on ``world.rank()``. More information about distributed communication can be found in the :ref:`Distributed Communication <usage_distributed>` page.
:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be included in any gradient computation.
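The aggregation this layer performs can likewise be sketched in plain numpy (a stand-in for mlx, with an ordinary sum playing the role of :func:`mlx.core.distributed.all_sum`):

.. code-block:: python

   import numpy as np

   rng = np.random.default_rng(0)
   x = rng.standard_normal((4, 2))  # full input of shape (4, 2)
   W = rng.standard_normal((2, 2))  # full weight: input_dims=2, output_dims=2

   # The input is sharded along the feature dimension and the weight
   # along the input dimension across 2 "devices".
   x0, x1 = np.split(x, 2, axis=1)  # input shards of shape (4, 1)
   W0, W1 = np.split(W, 2, axis=0)  # weight shards of shape (1, 2)

   # Each device computes a (4, 2) partial result; summing the partial
   # results (the all_sum step) reproduces the full layer output.
   assert np.allclose(x0 @ W0 + x1 @ W1, x @ W)
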
:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. It takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The original layer is not modified.
:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed.
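The parameter splitting itself can be pictured with a small numpy sketch (``shard_rows`` is a hypothetical helper written for illustration, not part of mlx):

.. code-block:: python

   import numpy as np

   def shard_rows(W, rank, size):
       # Keep only this rank's contiguous slice of rows, mirroring how an
       # in-place shard leaves each process with a fraction of the weights.
       step = W.shape[0] // size
       return W[rank * step:(rank + 1) * step]

   W = np.arange(8.0).reshape(4, 2)
   local0 = shard_rows(W, rank=0, size=2)  # rows 0-1 kept on process 0
   local1 = shard_rows(W, rank=1, size=2)  # rows 2-3 kept on process 1

   # Together the local shards cover the original parameters exactly once.
   assert np.allclose(np.concatenate([local0, local1]), W)
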
.. _useful_design_choices:
Useful Design Choices
---------------------
The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.
All-to-sharded and sharded-to-all layers naturally go together because the output of the former layer is exactly the input needed for the latter. This removes the need for an intermediate gather step between the layers, reducing communication overhead.
This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs automatically. It is so that they can be placed in successive order and work together easily.
We can demonstrate this through a simple model using our two types of distributed layers.
.. code-block:: python

   # Abridged sketch of the simple model: an all-to-sharded layer feeding
   # a sharded-to-all layer, with no communication needed in between.
   import mlx.core as mx
   import mlx.nn as nn

   class SimpleModel(nn.Module):
       def __init__(self, dims: int):
           super().__init__()
           self.up = nn.AllToShardedLinear(dims, dims)    # produces sharded output
           self.down = nn.ShardedToAllLinear(dims, dims)  # expects sharded input

       def __call__(self, x):
           # The sharded output of `up` is exactly the input `down` expects.
           return self.down(self.up(x))
*A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.*
LLM Inference with Tensor Parallelism
-------------------------------------
We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.
To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <llama-inference>` example. We will use the same inference script as that example, which can be found in `mlx-examples`_.
Our first edit is to initialize the distributed communication group and get the current process rank:
.. code-block:: python

   # Initialize the distributed group and get this process's rank (abridged).
   world = mx.distributed.init()
   rank = world.rank()
Next, let's look at the current architecture of the transformer block and see how tensor parallelism can be applied.
This architecture has two natural places where tensor parallelism can be applied: the attention block and the FFN block. Both follow the same pattern: multiple parallel linear layers operating on the same input, followed by a single output linear layer. In the attention block, the Q, K, and V projections become all-to-sharded layers, and the output projection becomes a sharded-to-all layer. Similarly, in the FFN block, the gate and up projections become all-to-sharded layers, and the down projection becomes a sharded-to-all layer.
The intermediate operations between the linear layers (RoPE, softmax, scaled dot-product attention in the attention block, and element-wise multiplication in the FFN block) do not impede the use of our TP paradigm. These operations are either:
- **Element-wise operations** (RoPE, element-wise multiplication): These operate independently on each element or position, preserving the sharding pattern without requiring cross-device communication.

- **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or head dimensions), so they can be computed independently on each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, V tensors because the matrix multiplications are performed along the sharded feature dimension, and the results remain properly sharded for the subsequent sharded-to-all layer.
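The second point can be checked numerically: because attention is computed independently per head, splitting the heads across devices and concatenating afterwards matches the unsharded result. A small numpy experiment (a stand-in for the sharded mlx computation):

.. code-block:: python

   import numpy as np

   def softmax(a):
       e = np.exp(a - a.max(axis=-1, keepdims=True))
       return e / e.sum(axis=-1, keepdims=True)

   def attention(q, k, v):
       # q, k, v have shape (n_heads, seq_len, head_dim)
       scores = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))
       return scores @ v

   rng = np.random.default_rng(0)
   q, k, v = (rng.standard_normal((4, 5, 8)) for _ in range(3))

   full = attention(q, k, v)              # all 4 heads on one device
   dev0 = attention(q[:2], k[:2], v[:2])  # heads 0-1 on device 0
   dev1 = attention(q[2:], k[2:], v[2:])  # heads 2-3 on device 1

   # Concatenating the per-device head outputs matches the full computation.
   assert np.allclose(np.concatenate([dev0, dev1]), full)
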
To implement sharding in our Llama inference, we use :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually in the :code:`__call__` function.
The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to all-to-sharded layers, while the output projection is converted to a sharded-to-all layer. The number of heads and repeats are also adjusted to account for the sharding:
.. code-block:: python

   # Abridged sketch of the Attention.shard method; the shard_linear call
   # signature ("all-to-sharded" / "sharded-to-all" type strings) is assumed
   # from mlx.nn.layers.distributed.
   def shard(self, group: mx.distributed.Group):
       n = group.size()
       # Q, K, and V projections become all-to-sharded layers ...
       self.wq = shard_linear(self.wq, "all-to-sharded", group=group)
       self.wk = shard_linear(self.wk, "all-to-sharded", group=group)
       self.wv = shard_linear(self.wv, "all-to-sharded", group=group)
       # ... while the output projection becomes a sharded-to-all layer.
       self.wo = shard_linear(self.wo, "sharded-to-all", group=group)
       # Each device now computes only its own fraction of the heads.
       self.n_heads //= n
       self.n_kv_heads //= n
Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to all-to-sharded layers, and the down projection (w2) to a sharded-to-all layer:
.. code-block:: python

   # Abridged sketch of the FeedForward.shard method.
   def shard(self, group: mx.distributed.Group):
       # Gate (w1) and up (w3) projections become all-to-sharded layers ...
       self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
       self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)
       # ... while the down projection (w2) becomes a sharded-to-all layer.
       self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)
Finally, in our :code:`load_model` function, we need to apply our sharding functions to all transformer layers when using multiple devices:
.. code-block:: python

   # Abridged: shard every transformer layer when more than one device is
   # used (`world` is the group initialized earlier).
   if world.size() > 1:
       for layer in model.layers:
           layer.attention.shard(group=world)
           layer.feed_forward.shard(group=world)
This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.