
Commit 8e9e63a

Author: Awni Hannun
Commit message: Nits
Parent: 33d4211

4 files changed: 121 additions & 72 deletions

docs/src/examples/data_parallelism.rst
Lines changed: 5 additions & 4 deletions
@@ -3,7 +3,8 @@
 Data Parallelism
 ================
 
-MLX enables efficient data parallel distributed training through its distributed communication primitives.
+MLX enables efficient data parallel distributed training through its
+distributed communication primitives.
 
 .. _training_example:
 
@@ -15,7 +16,7 @@ distributed training. Namely, we will average the gradients across a set of
 hosts before applying them to the model.
 
 Our training loop looks like the following code snippet if we omit the model,
-dataset and optimizer initialization.
+dataset, and optimizer initialization.
 
 .. code:: python
 
@@ -63,7 +64,7 @@ everything else remaining the same.
         optimizer.update(model, grads)
         return loss
 
-Utilizing ``nn.average_gradients``
+Using ``nn.average_gradients``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Although the code example above works correctly; it performs one communication
@@ -87,4 +88,4 @@ almost identical to the example above:
 
     for x, y in dataset:
         loss = step(model, x, y)
-        mx.eval(loss, model.parameters())
+        mx.eval(loss, model.parameters())
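The gradient-averaging scheme this file documents can be sanity-checked outside MLX. Below is a plain-Python sketch (the toy model, data, and the `all_sum` helper are hypothetical stand-ins for `mx.distributed.all_sum`): averaging each host's gradient over equal-size batches matches the gradient over the combined batch, which is why data parallel training produces the same update as single-host training.

```python
# Toy model: loss(w) = mean over the batch of (w * x - y) ** 2,
# so dloss/dw = mean of 2 * x * (w * x - y).

def local_grad(w, batch):
    # Gradient of the mean squared error over this host's batch.
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)

def all_sum(values):
    # Stand-in for mx.distributed.all_sum: every host receives the global sum.
    total = sum(values)
    return [total] * len(values)

w = 0.5
batches = [                      # one equal-size batch per "host"
    [(1.0, 2.0), (2.0, 3.0)],
    [(3.0, 5.0), (4.0, 7.0)],
]

# Each host computes a local gradient, then all-reduce and divide by the
# number of hosts -- the same averaging nn.average_gradients performs.
grads = [local_grad(w, b) for b in batches]
avg = [g / len(batches) for g in all_sum(grads)]

# Averaging per-host gradients equals the gradient over the combined batch.
combined = local_grad(w, batches[0] + batches[1])
assert all(abs(a - combined) < 1e-12 for a in avg)
```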

docs/src/examples/tensor_parallelism.rst
Lines changed: 111 additions & 61 deletions
@@ -3,70 +3,72 @@
 Tensor Parallelism
 ==================
 
-MLX enables efficient implementation of tensor parallelism *(TP)* through its implementation of distributed layers. In this example, we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.
-
+In this example, we will explore how tensor parallelism (TP) works in MLX. We
+will start with an overview of the distributed layers in ``mlx.nn`` and then
+show how to do tensor parallelism with Llama-style transformer models.
 
 Sharded Layers
 --------------
 
-
 :class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Column-wise tensor parallelism. This layer replicates a common input and shards the weight matrix along the output dimension (column-wise) across all devices in the :class:`mlx.core.distributed.Group`. The layer produces a sharded output.
+This layer replicates a common input and shards the weight matrix along the
+output dimension across all devices in the :class:`mlx.core.distributed.Group`.
+The layer produces a sharded output.
 
-For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix column-wise across the two devices, where each device receives the full input and computes a partial output.
+For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
+and a device group with 2 devices. The layer shards the weight matrix along the
+output dimension across the two devices, where each device receives the full
+input and computes a partial output.
 
 .. raw:: html
 
   <div>
     <img src="../_static/tp_inference/all-to-sharded-linear.png" alt="column-wise tensor parallelism" style="width: 100%">
-    <p style="font-size: 0.85em; margin-top: 0.5em; color:gray;"><small>Check out <a href="https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block">huggingface ultrascale-playbook</a> to learn more about tensor parallelism in LLMs.</small></p>
   </div>
 
-This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice <useful_design_choices>`.
+This layer does not automatically gather all outputs from each device. This is
+an intended and :ref:`useful design choice <useful_design_choices>`.
 
-:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
-will not be included in any gradient computation.
+:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
+the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.
 
 
 :class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Row-wise tensor parallelism. This layer expects inputs that are sharded along the feature dimension (column-wise) and shards the weight matrix along the input dimension (row-wise) across all devices in the :class:`mlx.core.distributed.Group`. The layer automatically aggregates the results using :class:`mlx.core.distributed.all_sum`, so all devices in the group will have the same result.
+This layer expects inputs that are sharded along the feature dimension and
+shards the weight matrix along the input dimension across all devices in the
+:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
+results using :class:`mlx.core.distributed.all_sum`, so all devices in the
+group will have the same result.
 
-For example, consider an :class:`mlx.nn.ShardedToAllLinear` layer with ``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``, and a device group with 2 devices. The layer shards the weight matrix row-wise and the input column-wise across the two devices. Each device computes a ``(4,2)`` output, which is then aggregated with all other device outputs to get layer output.
+For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
+``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
+and a device group with 2 devices. The layer shards the weight matrix along the
+input dimension across the two devices. Each device computes a ``(4, 2)``
+output, which is then aggregated with all other device outputs to get the
+layer output.
 
-.. raw:: html
+.. raw:: html
 
   <div>
     <img src="../_static/tp_inference/sharded-to-all-linear.png" alt="row-wise tensor parallelism" style="width: 100%">
   </div>
 
-This layer does not automatically shard the inputs along the feature dimension for you. It is necessary to create a "partial" input structure to feed into the layer. This is an intended and :ref:`useful design choice <useful_design_choices>`.
-
-We can create partial inputs based on rank. For example, for an input with 1024 features, we can create sharded inputs in the following manner:
-
-.. code-block:: python
-
-    world = mx.distributed.init()
-    part = (
-        slice(None),  # batch dimension: keep all batches per feature
-        slice(
-            world.rank() * 1024 // world.size(),  # start
-            (world.rank() + 1) * 1024 // world.size(),  # end
-        ),
-    )
-
-    layer = nn.ShardedToAllLinear(1024, 1024, bias=False)  # initialize the layer
-    y = layer(x[part])  # process sharded input
-
-This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :ref:`Distributed Communication <usage_distributed>` page.
+This layer does not automatically shard the inputs along the feature dimension
+for you. It is necessary to create a "partial" input structure to feed into the
+layer. This is an intended and :ref:`useful design choice
+<useful_design_choices>`.
 
-:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
-will not be included in any gradient computation.
+:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
+the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
+:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
+included in any gradient computation.
 
 
 Shard Utility Functions
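The arithmetic behind the two sharded layer types documented in this file can be checked without MLX. Below is a plain-Python sketch (toy 2x2 matrices and a hand-rolled `matmul` helper, both hypothetical): sharding the weight along the output dimension yields output slices that concatenate to the full result, while sharding along the input dimension yields partial full-size outputs that sum to it.

```python
def matmul(x, w):
    # Naive matrix multiply over nested lists: (rows x inner) @ (inner x cols).
    rows, inner, cols = len(x), len(w), len(w[0])
    return [[sum(x[i][k] * w[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

x = [[1.0, 2.0], [3.0, 4.0]]   # batched input, shape (2, 2)
W = [[1.0, -1.0], [0.5, 2.0]]  # weight matrix, shape (2, 2)
full = matmul(x, W)            # unsharded reference: y = x @ W

# AllToShardedLinear-style: shard W along the output dimension (columns).
# Every "device" sees the full input and produces one slice of the output;
# concatenating the slices recovers the full output.
W_cols = [[[row[0]] for row in W], [[row[1]] for row in W]]
outs = [matmul(x, w) for w in W_cols]
gathered = [outs[0][i] + outs[1][i] for i in range(len(x))]
assert gathered == full

# ShardedToAllLinear-style: shard W along the input dimension (rows).
# Each device sees a slice of the input features and produces a partial
# full-size output; summing the partials (an all_sum) recovers y = x @ W.
x_cols = [[[row[0]] for row in x], [[row[1]] for row in x]]
W_rows = [[W[0]], [W[1]]]
partials = [matmul(xi, wi) for xi, wi in zip(x_cols, W_rows)]
summed = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*partials)]
assert summed == full
```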
@@ -75,12 +77,21 @@ Shard Utility Functions
 :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. Takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The original layer is not modified.
+Converts a regular linear layer into a tensor parallel layer that distributes
+computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
+or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
+(either :class:`mlx.nn.AllToShardedLinear` or
+:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
+original layer is not modified.
 
 :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed.
+Splits the parameters of an existing layer across multiple devices by modifying
+the layer in-place. Unlike :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
+layer or add distributed communication. The layer itself must handle
+distributed communication if needed.
 
 
 .. _useful_design_choices:
@@ -90,11 +101,18 @@ Useful Design Choices
 
 The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.
 
-Column-wise and row-wise tensor parallel layers naturally go together because the output of the column-wise TP layer is exactly the size needed for the sharded input of a row-wise TP layer. This removes the need for an intermediate gather step between the layers, reducing communication overhead.
+All-to-sharded and sharded-to-all layers naturally go together because the
+output of the former layer is exactly the input needed for the latter.
+This removes the need for an intermediate gather step between the layers,
+reducing communication overhead.
 
-This is why AllToShardedLinear does not aggregate results automatically and why ShardedToAllLinear does not shard inputs automatically. It is so that they can be placed in successive order and work together easily.
+This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
+automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
+automatically. It is so that they can be placed in successive order and work
+together easily.
 
-We can demonstrate this through a simple model using our two types of distributed layers.
+We can demonstrate this through a simple model using our two types of
+distributed layers.
 
 .. code-block:: python
 
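The no-gather pairing described in this hunk can be illustrated outside MLX. Below is a plain-Python sketch (toy matrices and a hypothetical `matmul` helper) of a two-layer forward pass: each device holds one column shard of the first weight and one row shard of the second, computes entirely locally, and a single final sum over devices reproduces the unsharded result.

```python
def matmul(x, w):
    # Naive matrix multiply over nested lists.
    return [[sum(xi * wk[j] for xi, wk in zip(row, w)) for j in range(len(w[0]))]
            for row in x]

x  = [[1.0, 2.0]]              # input, shape (1, 2)
W1 = [[1.0, 2.0], [3.0, 4.0]]  # first (all-to-sharded) weight, shape (2, 2)
W2 = [[1.0, 0.0], [0.0, 1.0]]  # second (sharded-to-all) weight, shape (2, 2)

# Reference: unsharded two-layer forward pass.
reference = matmul(matmul(x, W1), W2)

# Device d holds column d of W1 and row d of W2. The sharded hidden
# activation produced by the first layer is exactly the sharded input the
# second layer expects, so no gather happens between the layers.
partials = []
for d in range(2):
    w1_d = [[row[d]] for row in W1]     # column shard of W1
    w2_d = [W2[d]]                      # row shard of W2
    h_d = matmul(x, w1_d)               # sharded hidden activation
    partials.append(matmul(h_d, w2_d))  # partial second-layer output

# One all_sum over devices reproduces the unsharded result.
total = [[a + b for a, b in zip(*rows)] for rows in zip(*partials)]
assert total == reference
```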
@@ -109,19 +127,24 @@ We can demonstrate this through a simple model using our two types of distribute
 .. raw:: html
 
   <div>
-    <img src="../_static/tp_inference/column-row-tp.png" alt="column-wise tensor parallelism" style="width: 100%">
-    <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using column-wise then row-wise tensor parallelism across 2 devices.</small></p>
+    <img src="../_static/tp_inference/column-row-tp.png" alt="two layer tensor parallelism" style="width: 100%">
+    <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.</small></p>
   </div>
 
 
 LLM Inference with Tensor Parallelism
 -------------------------------------
 
-We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.
+We can apply these TP techniques to LLMs in order to enable inference for much
+larger models by sharding parameters from huge layers across multiple devices.
 
-To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.
+To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
+Inference <llama-inference>` example. In this example, we will use the same
+inference script as the Llama Inference example, which can be found in
+`mlx-examples`_.
 
-Our first edit is to initialize the distributed communication group and get the current process rank:
+Our first edit is to initialize the distributed communication group and get the
+current process rank:
 
 .. code-block:: python
 
@@ -137,33 +160,57 @@ Next, let's look at the current architecture of the transformer block and see ho
   </div>
 
 
-This architecture has two natural places where our column-wise then row-wise tensor parallelism paradigm can be applied: the attention block and the FFN block. Both follow the same pattern: multiple parallel linear layers operating on the same input, followed by a single output linear layer. In the attention block, the Q, K, and V projections are sharded column-wise, and the output projection is sharded row-wise. In the FFN block, the gate and up projections are sharded column-wise, and the down projection is sharded row-wise.
-
-The intermediate operations between the linear layers (RoPE, softmax, scaled dot-product attention in the attention block, and element-wise multiplication in the FFN block) do not impede the use of our TP paradigm. These operations are either:
-
-- **Element-wise operations** (RoPE, element-wise multiplication): These operate independently on each element or position, preserving the sharding pattern without requiring cross-device communication.
-
-- **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or head dimensions), so they can be computed independently on each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, V tensors because the matrix multiplications are performed along the sharded feature dimension, and the results remain properly sharded for the subsequent row-wise TP layer.
-
-To implement sharding in our Llama inference, we use :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually in the :code:`__call__` function.
-
-The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to column-wise sharded layers (all-to-sharded), while the output projection is converted to a row-wise sharded layer (sharded-to-all). The number of heads and repeats are also adjusted to account for the sharding:
+This architecture has two natural places where tensor parallelism can be
+applied: the attention block and the FFN block. Both follow the same pattern:
+multiple parallel linear layers operating on the same input, followed by a
+single output linear layer. In the attention block, the Q, K, and V projections
+are sharded along the output dimension (all-to-sharded), and the output
+projection is sharded along the input dimension (sharded-to-all). Similarly, in
+the FFN block, the gate and up projections become all-to-sharded layers, and
+the down projection becomes a sharded-to-all layer.
+
+The intermediate operations between the linear layers (RoPE, softmax, scaled
+dot-product attention in the attention block, and element-wise multiplication
+in the FFN block) do not impede the use of our TP paradigm. These operations
+are either:
+
+- **Element-wise operations** (RoPE, element-wise multiplication): These
+  operate independently on each element or position, preserving the sharding
+  pattern without requiring cross-device communication.
+
+- **Operations on non-sharded dimensions** (softmax, scaled dot-product
+  attention): These operate along dimensions that are not sharded (such as the
+  sequence length or head dimensions), so they can be computed independently on
+  each device. The attention computations ``Q @ K^T`` and ``scores @ V`` work
+  correctly with sharded Q, K, V tensors because the matrix multiplications are
+  performed along the sharded feature dimension, and the results remain
+  properly sharded for the subsequent sharded-to-all layer.
+
+To implement sharding in our Llama inference, we use :func:`shard_linear
+<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
+distributed communication. This is easier than using :func:`shard_inplace
+<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually
+in the :code:`__call__` function.
+
+The following code shows how to shard the Attention block. The Q, K, and V
+projection layers are converted to all-to-sharded layers, while the output
+projection is converted to a sharded-to-all layer. The number of heads is also
+adjusted to account for the sharding:
 
 .. code-block:: python
 
     # ... in Attention class
     def shard(self, group: mx.distributed.Group):
         self.n_heads = self.n_heads // group.size()
         self.n_kv_heads = self.n_kv_heads // group.size()
-        self.repeats = self.n_heads // self.n_kv_heads
 
         self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
         self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
         self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
         self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)
 
-Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to column-wise sharded layers, and the down projection (w2) to a row-wise sharded layer:
-
+Similarly, the FeedForward block is sharded by converting the gate (w1) and up
+(w3) projections to all-to-sharded layers, and the down projection (w2) to
+a sharded-to-all layer:
 
 .. code-block:: python
 
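The FFN sharding this hunk documents can likewise be checked in plain Python. Below is a sketch (toy matrices; the `silu` gate activation mirrors the one used in Llama-style FFNs, an assumption here): gate and up weights are sharded along the output dimension, the down weight along the input dimension; the element-wise gate-times-up product acts on sharded columns and so needs no communication, and one final sum over devices recovers the unsharded output.

```python
from math import exp

def matmul(x, w):
    # Naive matrix multiply over nested lists.
    return [[sum(xi * wk[j] for xi, wk in zip(row, w)) for j in range(len(w[0]))]
            for row in x]

def silu(h):
    # SiLU activation, applied element-wise.
    return [[v / (1.0 + exp(-v)) for v in row] for row in h]

x  = [[0.5, -1.0]]
w1 = [[1.0, 2.0], [0.0, 1.0]]   # gate projection (toy values)
w3 = [[1.0, 0.0], [2.0, 1.0]]   # up projection
w2 = [[1.0, 1.0], [0.5, -1.0]]  # down projection

# Reference: y = (silu(x @ w1) * (x @ w3)) @ w2, unsharded.
gate, up = silu(matmul(x, w1)), matmul(x, w3)
h = [[g * u for g, u in zip(gr, ur)] for gr, ur in zip(gate, up)]
reference = matmul(h, w2)

partials = []
for d in range(2):
    w1_d = [[row[d]] for row in w1]  # column shard of gate projection
    w3_d = [[row[d]] for row in w3]  # column shard of up projection
    w2_d = [w2[d]]                   # row shard of down projection
    g_d, u_d = silu(matmul(x, w1_d)), matmul(x, w3_d)
    # Element-wise product on the local shard: no communication needed.
    h_d = [[g * u for g, u in zip(gr, ur)] for gr, ur in zip(g_d, u_d)]
    partials.append(matmul(h_d, w2_d))

# One all_sum over device partials recovers the unsharded FFN output.
total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
assert all(abs(a - b) < 1e-12
           for ra, rb in zip(total, reference) for a, b in zip(ra, rb))
```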
@@ -173,7 +220,8 @@ Similarly, the FeedForward block is sharded by converting the gate (w1) and up (
         self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
         self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)
 
-Finally, in our :code:`load_model` function, we need to apply our sharding functions to all transformer layers when using multiple devices:
+Finally, in our :code:`load_model` function, we need to apply our sharding
+functions to all transformer layers when using multiple devices:
 
 .. code-block:: python
 
@@ -184,6 +232,8 @@ Finally, in our :code:`load_model` function, we need to apply our sharding funct
         layer.attention.shard(group=world)
         layer.feed_forward.shard(group=world)
 
-This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.
+This allows us to use the llama inference file as normal when running
+:code:`python llama.py`, but now we can also run it across two (or more)
+devices via :code:`mlx.launch -n 2 llama.py`.
 
 .. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama