Commit 520595f
Author: Awni Hannun
Commit message: Nits
1 parent: 33d4211

4 files changed: +122 -72 lines changed

docs/src/examples/data_parallelism.rst

Lines changed: 5 additions & 4 deletions
@@ -3,7 +3,8 @@
 Data Parallelism
 ================

-MLX enables efficient data parallel distributed training through its distributed communication primitives.
+MLX enables efficient data parallel distributed training through its
+distributed communication primitives.

 .. _training_example:

@@ -15,7 +16,7 @@ distributed training. Namely, we will average the gradients across a set of
 hosts before applying them to the model.

 Our training loop looks like the following code snippet if we omit the model,
-dataset and optimizer initialization.
+dataset, and optimizer initialization.

 .. code:: python
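The averaging of gradients across hosts that this hunk refers to can be sketched in plain Python. This is an illustrative stand-in, not the MLX API: `average_gradients` below mimics what `mlx.nn.average_gradients` achieves via distributed communication, for gradients stored as a dict of flat lists.

```python
# Plain-Python stand-in (not the MLX API) for averaging gradients across
# workers: each worker holds gradients keyed by parameter name, and every
# value is averaged elementwise across workers.

def average_gradients(worker_grads):
    """worker_grads: list of {param_name: [floats]} dicts, one per worker."""
    n = len(worker_grads)
    return {
        k: [sum(vals) / n for vals in zip(*(g[k] for g in worker_grads))]
        for k in worker_grads[0]
    }

# Two workers, each with its own gradients for parameters "w" and "b".
g0 = {"w": [1.0, 2.0], "b": [0.0]}
g1 = {"w": [3.0, 4.0], "b": [2.0]}
print(average_gradients([g0, g1]))  # {'w': [2.0, 3.0], 'b': [1.0]}
```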
@@ -63,7 +64,7 @@ everything else remaining the same.
         optimizer.update(model, grads)
         return loss

-Utilizing ``nn.average_gradients``
+Using ``nn.average_gradients``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 Although the code example above works correctly; it performs one communication
@@ -87,4 +88,4 @@ almost identical to the example above:

     for x, y in dataset:
         loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
+        mx.eval(loss, model.parameters())

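A quick numerical check of why averaging per-worker gradients is sound: when the batch is split evenly across workers, the average of the per-worker gradients equals the full-batch gradient. The sketch below uses a toy 1-D least-squares model in plain Python (no MLX); two simulated workers stand in for two hosts.

```python
# Pure-Python simulation of data-parallel training (no MLX): averaging the
# gradients of two workers that each hold half the batch reproduces the
# gradient of the full batch.

def grad_mse(w, xs, ys):
    # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Full-batch gradient computed on a single device.
g_full = grad_mse(w, xs, ys)

# Data-parallel: each worker computes the gradient on its half of the batch,
# then the gradients are averaged (the role of all_sum / average_gradients).
g0 = grad_mse(w, xs[:2], ys[:2])
g1 = grad_mse(w, xs[2:], ys[2:])
g_avg = (g0 + g1) / 2

assert abs(g_full - g_avg) < 1e-12
```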
docs/src/examples/tensor_parallelism.rst

Lines changed: 112 additions & 61 deletions
@@ -3,70 +3,73 @@
 Tensor Parallelism
 ==================

-MLX enables efficient implementation of tensor parallelism *(TP)* through its implementation of distributed layers. In this example, we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.
-
+MLX enables efficient implementation of tensor parallelism (TP) through its
+implementation of distributed layers. In this example, we will explore what
+these layers are and create a small inference script for Llama family
+transformer models using MLX tensor parallelism.

 Sharded Layers
 --------------

-
 :class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Column-wise tensor parallelism. This layer replicates a common input and shards the weight matrix along the output dimension (column-wise) across all devices in the :class:`mlx.core.distributed.Group`. The layer produces a sharded output.
+This layer replicates a common input and shards the weight matrix along the
+output dimension across all devices in the :class:`mlx.core.distributed.Group`.
+The layer produces a sharded output.

-For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix column-wise across the two devices, where each device receives the full input and computes a partial output.
+For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
+and a device group with 2 devices. The layer shards the weight matrix along the
+output dimension across the two devices, where each device receives the full
+input and computes a partial output.

 .. raw:: html

    <div>
      <img src="../_static/tp_inference/all-to-sharded-linear.png" alt="column-wise tensor parallelism" style="width: 100%">
-     <p style="font-size: 0.85em; margin-top: 0.5em; color:gray;"><small>Check out <a href="https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block">huggingface ultrascale-playbook</a> to learn more about tensor parallelism in LLMs.</small></p>
    </div>

-This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice <useful_design_choices>`.
+This layer does not automatically gather all outputs from each device. This is
+an intended and :ref:`useful design choice <useful_design_choices>`.

-:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
-will not be included in any gradient computation.
+:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
+the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.


 :class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Row-wise tensor parallelism. This layer expects inputs that are sharded along the feature dimension (column-wise) and shards the weight matrix along the input dimension (row-wise) across all devices in the :class:`mlx.core.distributed.Group`. The layer automatically aggregates the results using :class:`mlx.core.distributed.all_sum`, so all devices in the group will have the same result.
+This layer expects inputs that are sharded along the feature dimension and
+shards the weight matrix along the input dimension across all devices in the
+:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
+results using :class:`mlx.core.distributed.all_sum`, so all devices in the
+group will have the same result.

-For example, consider an :class:`mlx.nn.ShardedToAllLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix row-wise and the input column-wise across the two devices. Each device computes a ``(4, 2)`` output, which is then aggregated with all other device outputs to get layer output.
+For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
+and a device group with 2 devices. The layer shards the weight matrix along the
+input dimension across the two devices. Each device computes a ``(4, 2)``
+output, which is then aggregated with all other device outputs to get the
+layer output.

-.. raw:: html
+.. raw:: html

    <div>
      <img src="../_static/tp_inference/sharded-to-all-linear.png" alt="row-wise tensor parallelism" style="width: 100%">
    </div>

-This layer does not automatically shard the inputs along the feature dimension for you. It is necessary to create a "partial" input structure to feed into the layer. This is an intended and :ref:`useful design choice <useful_design_choices>`.
-
-We can create partial inputs based on rank. For example, for an input with 1024 features, we can create sharded inputs in the following manner:
-
-.. code-block:: python
-
-    world = mx.distributed.init()
-    part = (
-        slice(None),  # batch dimension: keep all batches per feature
-        slice(
-            world.rank() * 1024 // world.size(),  # start
-            (world.rank() + 1) * 1024 // world.size(),  # end
-        ),
-    )
-
-    layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
-    y = layer(x[part])  # process sharded input
-
-This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :ref:`Distributed Communication <usage_distributed>` page.
+This layer does not automatically shard the inputs along the feature dimension
+for you. It is necessary to create a "partial" input structure to feed into the
+layer. This is an intended and :ref:`useful design choice
+<useful_design_choices>`.

-:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
-will not be included in any gradient computation.
+:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
+the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.


 Shard Utility Functions
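The arithmetic behind the two sharded layers in the hunk above can be checked in plain Python (no MLX; the helper below is an illustrative stand-in for a linear layer): sharding the weight along the output dimension yields partial outputs that concatenate to the full result, while sharding along the input dimension (with matching input shards) yields partial outputs that sum to it.

```python
# Pure-Python sketch (not the MLX API) of the math behind AllToShardedLinear
# and ShardedToAllLinear, for a weight W of shape (input_dims, output_dims)
# and a single input row x.

def linear(x, W):
    """Compute x @ W for x a flat list and W a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

x = [1.0, 2.0]
W = [[1.0, 2.0],
     [3.0, 4.0]]  # shape (2, 2)

# Reference: the un-sharded linear layer.
full = linear(x, W)

# All-to-sharded: every device sees the full x but only a slice of W's output
# dimension; concatenating the partial outputs recovers the full result.
dev0 = linear(x, [[row[0]] for row in W])  # first output column
dev1 = linear(x, [[row[1]] for row in W])  # second output column
assert dev0 + dev1 == full

# Sharded-to-all: each device sees a slice of x's features and the matching
# slice of W's input dimension; summing the partial outputs (what all_sum
# does across devices) recovers the full result.
part0 = linear(x[:1], W[:1])
part1 = linear(x[1:], W[1:])
assert [a + b for a, b in zip(part0, part1)] == full
```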
@@ -75,12 +78,21 @@ Shard Utility Functions

 :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. Takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The original layer is not modified.
+Converts a regular linear layer into a tensor parallel layer that distributes
+computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
+or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
+(either :class:`mlx.nn.AllToShardedLinear` or
+:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
+original layer is not modified.

 :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed.
+Splits the parameters of an existing layer across multiple devices by modifying
+the layer in-place. Unlike :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
+layer or add distributed communication. The layer itself must handle
+distributed communication if needed.


 .. _useful_design_choices:
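The parameter split that both utilities perform can be sketched with a hypothetical helper (`shard_weight` below is not an MLX function): each rank keeps an equal contiguous slice of the weight, along the input dimension (rows) for sharded-to-all or the output dimension (columns) for all-to-sharded.

```python
# Hypothetical helper (not the MLX API) illustrating the parameter split
# performed by shard_linear / shard_inplace: slicing a 2-D weight into equal
# per-rank pieces along one axis.

def shard_weight(W, rank, size, axis):
    """Return rank's shard of a 2-D weight (list of rows)."""
    if axis == 0:  # shard rows (input dimension, sharded-to-all)
        step = len(W) // size
        return W[rank * step:(rank + 1) * step]
    else:          # shard columns (output dimension, all-to-sharded)
        step = len(W[0]) // size
        return [row[rank * step:(rank + 1) * step] for row in W]

W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]  # shape (2, 4)

# Rank 0 of 2 keeps the first two output columns...
assert shard_weight(W, 0, 2, axis=1) == [[1, 2], [5, 6]]
# ...and rank 1 of 2 keeps the second input row.
assert shard_weight(W, 1, 2, axis=0) == [[5, 6, 7, 8]]
```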
@@ -90,11 +102,18 @@ Useful Design Choices

 The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.

-Column-wise and row-wise tensor parallel layers naturally go together because the output of the column-wise TP layer is exactly the size needed for the sharded input of a row-wise TP layer. This removes the need for an intermediate gather step between the layers, reducing communication overhead.
+All-to-sharded and sharded-to-all layers naturally go together because the
+output of the former layer is exactly the input needed for the latter.
+This removes the need for an intermediate gather step between the layers,
+reducing communication overhead.

-This is why AllToShardedLinear does not aggregate results automatically and why ShardedToAllLinear does not shard inputs automatically. It is so that they can be placed in successive order and work together easily.
+This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
+automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
+automatically. It is so that they can be placed in successive order and work
+together easily.

-We can demonstrate this through a simple model using our two types of distributed layers.
+We can demonstrate this through a simple model using our two types of
+distributed layers.

 .. code-block:: python
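The no-intermediate-gather property described in the hunk above can be checked numerically in plain Python (no MLX; the helper is an illustrative stand-in for a linear layer): each device's partial output from the first (all-to-sharded) layer is exactly its input shard for the second (sharded-to-all) layer, and a single final sum reproduces the un-sharded two-layer model.

```python
# Pure-Python sketch (not the MLX API): composing an all-to-sharded layer
# with a sharded-to-all layer needs no gather in between, because each
# device's partial hidden output is already the right input shard.

def linear(x, W):
    """Compute x @ W for x a flat list and W a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

x = [1.0, 2.0]
W1 = [[1.0, 2.0], [3.0, 4.0]]   # first layer: sharded along the output dim
W2 = [[1.0, 0.0], [0.0, 1.0]]   # second layer: sharded along the input dim

# Reference: the un-sharded two-layer model.
full = linear(linear(x, W1), W2)

# Device d keeps column d of W1 and row d of W2.
out = [0.0, 0.0]
for d in range(2):
    h_d = linear(x, [[row[d]] for row in W1])  # partial hidden, stays on device d
    y_d = linear(h_d, [W2[d]])                 # consumes the local shard directly
    out = [a + b for a, b in zip(out, y_d)]    # the only communication: all_sum

assert out == full
```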
@@ -109,19 +128,24 @@ We can demonstrate this through a simple model using our two types of distribute
 .. raw:: html

    <div>
-     <img src="../_static/tp_inference/column-row-tp.png" alt="column-wise tensor parallelism" style="width: 100%">
-     <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using column-wise then row-wise tensor parallelism across 2 devices.</small></p>
+     <img src="../_static/tp_inference/column-row-tp.png" alt="two layer tensor parallelism" style="width: 100%">
+     <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.</small></p>
    </div>


 LLM Inference with Tensor Parallelism
 -------------------------------------

-We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.
+We can apply these TP techniques to LLMs in order to enable inference for much
+larger models by sharding parameters from huge layers across multiple devices.

-To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.
+To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
+Inference <llama-inference>` example. In this example, we will use the same
+inference script as the Llama Inference example, which can be found in
+`mlx-examples`_.

-Our first edit is to initialize the distributed communication group and get the current process rank:
+Our first edit is to initialize the distributed communication group and get the
+current process rank:

 .. code-block:: python
@@ -137,33 +161,57 @@ Next, let's look at the current architecture of the transformer block and see ho
    </div>


-This architecture has two natural places where our column-wise then row-wise tensor parallelism paradigm can be applied: the attention block and the FFN block. Both follow the same pattern: multiple parallel linear layers operating on the same input, followed by a single output linear layer. In the attention block, the Q, K, and V projections are sharded column-wise, and the output projection is sharded row-wise. In the FFN block, the gate and up projections are sharded column-wise, and the down projection is sharded row-wise.
-
-The intermediate operations between the linear layers (RoPE, softmax, scaled dot-product attention in the attention block, and element-wise multiplication in the FFN block) do not impede the use of our TP paradigm. These operations are either:
-
-- **Element-wise operations** (RoPE, element-wise multiplication): These operate independently on each element or position, preserving the sharding pattern without requiring cross-device communication.
-
-- **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or head dimensions), so they can be computed independently on each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, V tensors because the matrix multiplications are performed along the sharded feature dimension, and the results remain properly sharded for the subsequent row-wise TP layer.
-
-To implement sharding in our Llama inference, we use :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually in the :code:`__call__` function.
-
-The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to column-wise sharded layers (all-to-sharded), while the output projection is converted to a row-wise sharded layer (sharded-to-all). The number of heads and repeats are also adjusted to account for the sharding:
+This architecture has two natural places where tensor parallelism can be
+applied: the attention block and the FFN block. Both follow the same pattern:
+multiple parallel linear layers operating on the same input, followed by a
+single output linear layer. In the attention block, the Q, K, and V projections
+are sharded along the output dimension (all-to-sharded), and the output
+projection is sharded along the input dimension (sharded-to-all). Similarly, in
+the FFN block, the gate and up projections become all-to-sharded layers, and
+the down projection becomes a sharded-to-all layer.
+
+The intermediate operations between the linear layers (RoPE, softmax, scaled
+dot-product attention in the attention block, and element-wise multiplication
+in the FFN block) do not impede the use of our TP paradigm. These operations
+are either:
+
+- **Element-wise operations** (RoPE, element-wise multiplication): These
+  operate independently on each element or position, preserving the sharding
+  pattern without requiring cross-device communication.
+
+- **Operations on non-sharded dimensions** (softmax, scaled dot-product
+  attention): These operate along dimensions that are not sharded (such as the
+  sequence length or head dimensions), so they can be computed independently on
+  each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work
+  correctly with sharded Q, K, V tensors because the matrix multiplications are
+  performed along the sharded feature dimension, and the results remain
+  properly sharded for the subsequent sharded-to-all layer.
+
+To implement sharding in our Llama inference, we use :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
+distributed communication. This is easier than using :func:`shard_inplace
+<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually
+in the :code:`__call__` function.
+
+The following code shows how to shard the Attention block. The Q, K, and V
+projection layers are converted to all-to-sharded layers, while the output
+projection is converted to a sharded-to-all layer. The number of heads is also
+adjusted to account for the sharding:

 .. code-block:: python

     # ... in Attention class
     def shard(self, group: mx.distributed.Group):
         self.n_heads = self.n_heads // group.size()
         self.n_kv_heads = self.n_kv_heads // group.size()
-        self.repeats = self.n_heads // self.n_kv_heads

         self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
         self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
         self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
         self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)

-Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to column-wise sharded layers, and the down projection (w2) to a row-wise sharded layer:
-
+Similarly, the FeedForward block is sharded by converting the gate (w1) and up
+(w3) projections to all-to-sharded layers, and the down projection (w2) to
+a sharded-to-all layer:

 .. code-block:: python
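The head bookkeeping in the `shard` method above is simple integer division, sketched here in plain Python (no MLX; `shard_heads` is an illustrative stand-in, and divisibility of the head counts by the group size is an assumption):

```python
# Pure-Python sketch of the head adjustment in Attention.shard: with
# group.size() devices, each device keeps an equal slice of the query heads
# and key/value heads. Head counts are assumed divisible by the group size.

def shard_heads(n_heads, n_kv_heads, size):
    assert n_heads % size == 0 and n_kv_heads % size == 0
    return n_heads // size, n_kv_heads // size

# Example: a model with 32 query heads and 8 KV heads split across 2 devices.
local_heads, local_kv = shard_heads(32, 8, 2)
print(local_heads, local_kv)  # 16 4
```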
@@ -173,7 +221,8 @@ Similarly, the FeedForward block is sharded by converting the gate (w1) and up (
         self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
         self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)

-Finally, in our :code:`load_model` function, we need to apply our sharding functions to all transformer layers when using multiple devices:
+Finally, in our :code:`load_model` function, we need to apply our sharding
+functions to all transformer layers when using multiple devices:

 .. code-block:: python
@@ -184,6 +233,8 @@ Finally, in our :code:`load_model` function, we need to apply our sharding funct
         layer.attention.shard(group=world)
         layer.feed_forward.shard(group=world)

-This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.
+This allows us to use the llama inference file as normal when running
+:code:`python llama.py`, but now we can also run it across two (or more)
+devices via :code:`mlx.launch -n 2 llama.py`.

 .. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
