.. docs/src/examples/tensor_parallelism.rst
Tensor Parallelism
==================

MLX enables efficient implementation of tensor parallelism *(TP)* through its distributed layers. In this example, we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.

<p style="font-size: 0.85em; margin-top: 0.5em; color: gray;"><small>Check out the <a href="https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block">huggingface ultrascale-playbook</a> to learn more about tensor parallelism in LLMs.</small></p>
</div>
This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice <useful_design_choices>`.

:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
not included in any gradient computation.

We can create partial inputs based on rank. For example, for an input with 1024 features, each process can slice out its own contiguous block of the feature dimension based on its rank.

This splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :doc:`Distributed Communication <../usage/distributed>` page.
:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
not included in any gradient computation.

:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. It takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The original layer is not modified.

:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed.

.. _useful_design_choices:

Useful Design Choices
---------------------

The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.

Column-wise and row-wise tensor parallel layers naturally go together because the output of the column-wise TP layer is exactly the size needed for the sharded input of a row-wise TP layer. This removes the need for an intermediate gather step between the layers, reducing communication overhead.

This is why :class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>` does not aggregate results automatically and why :class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>` does not shard inputs automatically: the two layers can be placed in succession and work together without extra communication.

We can demonstrate this through a simple model using our two types of distributed layers.

<p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using column-wise then row-wise tensor parallelism across 2 devices.</small></p>
</div>

.. _llm_inference_with_tp:

LLM Inference with Tensor Parallelism
-------------------------------------

We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.

To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <../examples/llama-inference>` example. We will use the same inference script as that example, which can be found in `mlx-examples`_.

Our first edit is to initialize the distributed communication group and get the current process rank:

The intermediate operations between the linear layers (RoPE, softmax, scaled dot-product attention) work correctly on the sharded tensors without any extra communication:
- **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or head dimensions), so they can be computed independently on each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, V tensors because the matrix multiplications are performed along the sharded feature dimension, and the results remain properly sharded for the subsequent row-wise TP layer.

To implement sharding in our Llama inference, we use :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and implementing the communication steps manually in the :code:`__call__` function.

The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to column-wise sharded layers (all-to-sharded), while the output projection is converted to a row-wise sharded layer (sharded-to-all). The number of heads and repeats is also adjusted to account for the sharding:
Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to column-wise sharded layers, and the down projection (w2) to a row-wise sharded layer:

.. code-block:: python

    # ... in FeedForward class
    # (sketch: the method name and shard_linear calls mirror the Attention
    # example; w1/w3 are column-wise, w2 is row-wise, per the text above)
    def shard(self, group=None):
        self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
        self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)
        self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)

Finally, in our :code:`load_model` function, we need to apply our sharding functions to each layer:

.. code-block:: python

    # ... in load_model, for each transformer block `layer`:
    layer.attention.shard(group=world)
    layer.feed_forward.shard(group=world)

This allows us to use the Llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.