Commit 074cb2c

stefpi authored and Awni Hannun committed
docs: cleanup up tp inference
1 parent 50c2d71 commit 074cb2c

1 file changed

Lines changed: 33 additions & 13 deletions

File tree

docs/src/examples/tensor_parallelism.rst

@@ -1,10 +1,12 @@
 Tensor Parallelism
 ==================

-MLX enables efficient implementation of tensor parallelism *(TP)* through its implementation of distributed layers. In this example we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.
+MLX enables efficient implementation of tensor parallelism *(TP)* through its implementation of distributed layers. In this example, we will explore what these layers are and create a small inference script for Llama family transformer models using MLX tensor parallelism.
+
+
+Sharded Layers
+--------------

-MLX Sharded Layers
-------------------

 :class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -17,15 +19,16 @@ For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with ``input_d

    <div>
    <img src="../_static/tp_inference/all-to-sharded-linear.png" alt="column-wise tensor parallelism" style="width: 100%">
-   <p style="font-size: 0.85em; margin-top: 0.5em;"><small>Check out <a href="https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block">huggingface ultrascale-playbook</a> to learn about TP more in depth.</small></p>
+   <p style="font-size: 0.85em; margin-top: 0.5em; color:gray;"><small>Check out <a href="https://huggingface.co/spaces/gxa33/ultrascale-playbook?section=tensor_parallelism_in_a_transformer_block">huggingface ultrascale-playbook</a> to learn more about tensor parallelism in LLMs.</small></p>
    </div>

 This layer does not automatically gather all outputs from each device. This is an intended and :ref:`useful design choice <useful_design_choices>`.

 :class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
 will not be included in any gradient computation.

+
 :class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

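Editor's note: the column-wise sharding that the docs describe here can be checked with a plain NumPy sketch. This is illustrative only, not the MLX API; the weight slicing and device loop simply simulate what ``world_size`` devices would each hold and compute, and it shows why no gather is needed to recover the full output.

```python
import numpy as np

rng = np.random.default_rng(0)

input_dims, output_dims, world_size = 8, 12, 4
x = rng.standard_normal((1, input_dims))
W = rng.standard_normal((output_dims, input_dims))  # full weight, y = x @ W.T

# Column-wise TP: each simulated "device" keeps output_dims // world_size
# output features (rows of W) and sees the full input x.
shard_out = output_dims // world_size
shards = [W[r * shard_out : (r + 1) * shard_out] for r in range(world_size)]

# Each device computes only its slice of the output; nothing is gathered.
partial_outputs = [x @ Wr.T for Wr in shards]

# Concatenating the per-device slices recovers the full layer output.
full = x @ W.T
assert np.allclose(np.concatenate(partial_outputs, axis=-1), full)
```

Because the per-device outputs are already the right shards for a following row-wise TP layer, leaving them ungathered is exactly the design choice the next section motivates.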
@@ -60,16 +63,32 @@ We can create partial inputs based on rank. For example, for an input with 1024
 This code splits the 1024 input features into ``world.size()`` different groups which are assigned continuously based on ``world.rank()``. More information about distributed communication can be found in the :doc:`Distributed Communication <../usage/distributed>` page.

 :class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
-Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+Similar to :class:`mlx.nn.QuantizedLinear`, its parameters are frozen and
 will not be included in any gradient computation.

+
+Shard Utility Functions
+-----------------------
+
+:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Converts a regular linear layer into a tensor parallel layer that distributes computation across multiple devices. Takes an existing :class:`mlx.nn.Linear` or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer (either :class:`mlx.nn.AllToShardedLinear` or :class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The original layer is not modified.
+
+:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Splits the parameters of an existing layer across multiple devices by modifying the layer in-place. Unlike :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`, this function does not create a new layer or add distributed communication. The layer itself must handle distributed communication if needed.
+
+
 .. _useful_design_choices:
+
 Useful Design Choices
 ---------------------

-There are design choices made above related to when things are done automatically that are done on purpose to make model training / inference easier.
+The design choices above regarding when operations are done automatically are intentional and make model training and inference easier.

-Column-wise and row-wise tensor parallel layers naturally go together due to the output of the column-wise TP layer being the exact size needed for the sharded input of a row-wise TP layer. This removes the need for an intermediary gather step between the layers, reducing communication overhead.
+Column-wise and row-wise tensor parallel layers naturally go together because the output of the column-wise TP layer is exactly the size needed for the sharded input of a row-wise TP layer. This removes the need for an intermediate gather step between the layers, reducing communication overhead.

 This is why AllToShardedLinear does not aggregate results automatically and why ShardedToAllLinear does not shard inputs automatically. It is so that they can be placed in successive order and work together easily.

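Editor's note: the column-wise-into-row-wise composition described above can be verified with a small NumPy sketch. This is an illustration of the math, not MLX code: the loop simulates ``world_size`` devices, and the final sum stands in for the all-reduce a sharded-to-all layer would perform.

```python
import numpy as np

rng = np.random.default_rng(1)
d_in, d_hidden, d_out, world_size = 6, 8, 4, 2
x = rng.standard_normal((1, d_in))
W1 = rng.standard_normal((d_hidden, d_in))   # column-wise sharded layer
W2 = rng.standard_normal((d_out, d_hidden))  # row-wise sharded layer

h = d_hidden // world_size
partials = []
for r in range(world_size):
    W1_r = W1[r * h : (r + 1) * h]       # device r's slice of output features
    W2_r = W2[:, r * h : (r + 1) * h]    # the matching slice of input features
    hidden_r = x @ W1_r.T                # sharded hidden activation, no gather needed
    partials.append(hidden_r @ W2_r.T)   # device r's partial output

# The row-wise layer's all-reduce is just a sum of the partial outputs.
y_tp = sum(partials)
y_ref = (x @ W1.T) @ W2.T
assert np.allclose(y_tp, y_ref)
```

The column-wise layer's ungathered output is exactly the shard the row-wise layer expects, so the only communication in the pair is the final all-reduce.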
@@ -92,13 +111,13 @@ We can demonstrate this through a simple model using our two types of distribute
 <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using column-wise then row-wise tensor parallelism across 2 devices.</small></p>
 </div>

-.. _llm_inference_with_tp:
+
 LLM Inference with Tensor Parallelism
 -------------------------------------

 We can apply these TP techniques to LLMs in order to enable inference for much larger models by sharding parameters from huge layers across multiple devices.

-To demonstrate how it is possible to do this, let's apply TP to the Transformer block of our :doc:`Llama Inference <../examples/llama-inference>` example. In this example we will use the same inference script as the Llama Inference example which can be found in `mlx-examples`_.
+To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama Inference <../examples/llama-inference>` example. In this example, we will use the same inference script as the Llama Inference example, which can be found in `mlx-examples`_.

 Our first edit is to initialize the distributed communication group and get the current process rank:

@@ -124,7 +143,7 @@ The intermediate operations between the linear layers (RoPE, softmax, scaled dot

 - **Operations on non-sharded dimensions** (softmax, scaled dot-product attention): These operate along dimensions that are not sharded (such as the sequence length or head dimensions), so they can be computed independently on each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work correctly with sharded Q, K, V tensors because the matrix multiplications are performed along the sharded feature dimension, and the results remain properly sharded for the subsequent row-wise TP layer.

-The :func:`mlx.nn.layers.distributed.shard_linear` utility function simplifies creating tensor parallel layers based on existing Linear layers. Similarly, :func:`mlx.nn.layers.distributed.shard_inplace` does the same thing but changes the existing Linear layer instead of creating a new one. If the input is a :class:`mlx.nn.QuantizedLinear` layer, it automatically returns the corresponding quantized tensor parallel layer.
+To implement sharding in our Llama inference, we use :func:`shard_linear <mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with distributed communication. This is easier than using :func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually in the :code:`__call__` function.

 The following code shows how to shard the Attention block. The Q, K, and V projection layers are converted to column-wise sharded layers (all-to-sharded), while the output projection is converted to a row-wise sharded layer (sharded-to-all). The number of heads and repeats are also adjusted to account for the sharding:

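Editor's note: the claim above that attention works per-device on sharded heads can be checked with a NumPy sketch. This simulates two "devices" each holding ``n_heads // world_size`` heads; names and sizes here are illustrative, not from the MLX example.

```python
import numpy as np

rng = np.random.default_rng(2)
n_heads, head_dim, seq, world_size = 4, 8, 5, 2
Q = rng.standard_normal((n_heads, seq, head_dim))
K = rng.standard_normal((n_heads, seq, head_dim))
V = rng.standard_normal((n_heads, seq, head_dim))

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # scaled dot-product attention, computed independently per head
    scores = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    return scores @ v

# Each device holds n_heads // world_size heads and attends independently,
# with no communication during the attention computation itself.
h = n_heads // world_size
per_device = [attention(Q[r*h:(r+1)*h], K[r*h:(r+1)*h], V[r*h:(r+1)*h])
              for r in range(world_size)]

# Concatenating the head groups matches the unsharded computation; each
# device's result is already the shard the row-wise output projection expects.
assert np.allclose(np.concatenate(per_device, axis=0), attention(Q, K, V))
```

Because heads never mix with each other inside attention, sharding along the head dimension needs no gather until the row-wise output projection's all-reduce.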
@@ -140,9 +159,10 @@ The following code shows how to shard the Attention block. The Q, K, and V proje
     self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
     self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
     self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)
-
+
 Similarly, the FeedForward block is sharded by converting the gate (w1) and up (w3) projections to column-wise sharded layers, and the down projection (w2) to a row-wise sharded layer:

+
 .. code-block:: python

     # ... in FeedForward class
@@ -162,6 +182,6 @@ Finally, in our :code:`load_model` function, we need to apply our sharding funct
     layer.attention.shard(group=world)
     layer.feed_forward.shard(group=world)

-This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also use it running across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.
+This allows us to use the llama inference file as normal when running :code:`python llama.py`, but now we can also run it across two (or more) devices via :code:`mlx.launch -n 2 llama.py`.

 .. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
