
Commit fe98297

Updated the export document with custom llm documentation (pytorch#19460)
### Summary

Fixes pytorch#8768

Update the custom LLM export documentation to better guide power users adapting non-standard transformer models for ExecuTorch. The page keeps the existing nanoGPT tutorial flow, then adds focused guidance on runner compatibility, KV cache ownership, attention/SDPA reuse, reusable LLM components, backend delegation considerations, and encoder-decoder model boundaries.

This is a docs-only PR. No public APIs, runtime behavior, exporter behavior, or backend behavior are changed.

### Test plan

```bash
lintrunner docs/source/llm/export-custom-llm.md
cd docs
make html
```

cc: @GregoryComer @nil-is-all
cc @mergennachin @AlannaBurke @larryliu0820 @cccclai @helunwencser @jackzhxng @byjlw
1 parent 99f1f0b commit fe98297

1 file changed

Lines changed: 197 additions & 2 deletions


docs/source/llm/export-custom-llm.md

@@ -1,9 +1,16 @@
 # Exporting custom LLMs
 
-If you have your own PyTorch model that is an LLM, this guide will show you how to manually export and lower to ExecuTorch, with many of the same optimizations as covered in the previous `export_llm` guide.
+If you have your own PyTorch model that is an LLM, this guide shows how to
+manually export and lower it to ExecuTorch. Use this flow when your model is not
+covered by the native [`export_llm`](export-llm.md) API, is not directly handled
+by [Optimum ExecuTorch](export-llm-optimum.md), or needs model-specific changes
+before it can use the standard ExecuTorch LLM runtime.
 
 This example uses Karpathy’s [nanoGPT](https://github.com/karpathy/nanoGPT), which is a minimal implementation of
-GPT-2 124M. This guide is applicable to other language models, as ExecuTorch is model-invariant.
+GPT-2 124M. The same manual export pattern applies broadly to PyTorch models.
+However, exporting a `.pte` file and running that file with the stock LLM runners
+are separate steps. To use the LLM runners, the exported model must also follow
+the runtime contract described below.
 
 
 ## Exporting to ExecuTorch (basic)
@@ -84,6 +91,165 @@ To export, run the script with `python export_nanogpt.py` (or python3, as approp
 For more information, see [Exporting to ExecuTorch](../tutorials/export-to-executorch-tutorial) <!-- @lint-ignore --> and
 [torch.export](https://pytorch.org/docs/stable/export.html).
 
+## Using the LLM runners with a custom model
+
+The exported `.pte` file can be loaded directly through the ExecuTorch runtime,
+but many text-generation applications use the higher-level LLM runners described
+in [Running LLMs with C++](run-with-c-plus-plus.md). These runners handle
+tokenization, prefill, decode, sampling, and streaming output. To use them with a
+custom model, shape the model boundary around autoregressive text generation:
+
+A KV cache stores previous attention key and value tensors so decode can append
+new tokens without recomputing attention over the full context.
+
+- The model should accept token IDs as the primary input.
+- If the model uses a KV cache, it should also accept a position input, often
+  named `input_pos` or `start_pos`.
+- The model should return a single logits tensor that the runner can sample from.
+- The tokenizer file and BOS/EOS token IDs should match the model.
+- Cache tensors should normally be model-owned buffers, not extra inputs and
+  outputs passed through the runner.
+
+A typical runner-compatible forward signature looks like:
+
+```python
+def forward(
+    self,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+) -> torch.Tensor:
+    ...
+    return logits
+```
+
+Models without a KV cache may expose only the token input, but generation will
+usually be much slower because the model recomputes attention over the full
+context on each decode step.
+
+The runner also reads metadata from the `.pte` file. At minimum, include the
+values that describe sequence limits, KV cache behavior, and tokenizer
+termination:
+
+- `get_max_seq_len`: maximum number of tokens processed by one model invocation.
+- `get_max_context_len`: maximum context length remembered by the model.
+- `use_kv_cache`: whether the model has an internal KV cache.
+- `enable_dynamic_shape`: whether prefill can use dynamic sequence lengths.
+- `get_bos_id` and `get_eos_ids`: token IDs used by the runner.
+
+For example:
+
+```python
+metadata = {
+    "get_bos_id": bos_id,
+    "get_eos_ids": [eos_id],
+    "get_max_seq_len": max_seq_len,
+    "get_max_context_len": max_context_len,
+    "use_kv_cache": True,
+    "enable_dynamic_shape": True,
+}
+```
+
+When manually exporting, serialize this metadata as constant methods. Constant
+methods are named values in the `.pte` file that the runner can query at load
+time:
+
+```python
+edge_manager = to_edge(
+    traced_model,
+    constant_methods=metadata,
+    compile_config=edge_config,
+)
+```
+
+If your model needs additional runtime inputs, such as explicit cache tensors,
+attention masks, encoder outputs, or cross-attention state, the default text LLM
+runner is probably not the right boundary. In that case, either wrap the model so
+that those values are stored inside the module, or build a custom runner or
+`IOManager` for the model-specific input and output protocol. An `IOManager` is
+the runner component that prepares model inputs and processes model outputs for
+prefill and decode.
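
For the wrapping approach, a minimal sketch of a thin module that keeps the runner-facing boundary at tokens and position while owning the extra state itself; the wrapped model's `(tokens, input_pos, mask)` signature here is only an assumption for illustration:

```python
import torch


class RunnerFriendlyWrapper(torch.nn.Module):
    """Hypothetical wrapper: hides the attention mask so the exported
    boundary stays (tokens, input_pos) -> logits."""

    def __init__(self, model: torch.nn.Module, max_context_len: int):
        super().__init__()
        # Assumed: `model` takes (tokens, input_pos, mask) and returns logits.
        self.model = model
        # Causal mask owned by the module instead of passed in by the runner.
        causal = torch.tril(torch.ones(max_context_len, max_context_len, dtype=torch.bool))
        self.register_buffer("mask", causal, persistent=False)

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # input_pos: 1-D tensor of absolute positions for the tokens in this call.
        mask = self.mask[input_pos]  # select mask rows for these positions
        return self.model(tokens, input_pos, mask)
```

The wrapper, not the runner, decides how the mask (or any other auxiliary tensor) is produced, which keeps the exported signature runner-compatible.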
+
+Encoder-decoder models, such as translation models from Fairseq, are a common
+case where this distinction matters. ExecuTorch can run the exported program,
+but the stock text-generation runner is oriented around decoder-only generation.
+If the model is supported by Optimum ExecuTorch, prefer that path. Otherwise,
+decide whether to wrap the model into the runner-compatible shape or expose a
+custom runtime interface.
+
+## Adapting attention and KV cache
+
+Optimized LLM exports work well in ExecuTorch when attention and decode state are
+structured in an export-friendly way. The important design choice is to keep the
+runtime interface simple while moving mutable decode state into the module.
+
+The optimized transformer implementations in ExecuTorch preserve a few
+properties that are useful to keep in a custom model:
+
+- The exported graph is static enough for `torch.export`: tensor operations are
+  traceable, and generation state does not depend on Python-side mutation.
+- The runner boundary stays small: tokens and optional position go in, logits
+  come out, and metadata describes how to drive generation.
+- KV cache state is stored in model buffers and updated by tensor position.
+- Attention is factored so standard scaled dot product attention (SDPA) can be
+  replaced by optimized or backend-specific SDPA implementations when the tensor
+  layout matches.
+- Large compute patterns, such as linear layers and attention, stay recognizable
+  to backend partitioners.
+
+For KV cache support:
+
+- Register key and value caches as module buffers so they are part of the
+  exported program state.
+- Update cache entries using the tensor position passed to the model, rather than
+  Python-side counters or data-dependent control flow.
+- Keep cache shapes predictable. Backends and custom operators often rely on
+  fixed cache layout assumptions.
+- Return logits only. The default runner does not expect cache tensors as model
+  outputs.
+- Reset or reinitialize cache state through the runner/runtime lifecycle, not by
+  changing Python attributes during generation.
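
A minimal sketch of the buffer-owned, position-indexed cache described by the first two points in this list (shapes and names are illustrative, not a required API):

```python
import torch


class SimpleKVCache(torch.nn.Module):
    """Illustrative per-layer cache: buffers registered on the module,
    updated in place at the positions passed in by the runner."""

    def __init__(self, max_context_len: int, n_heads: int, head_dim: int):
        super().__init__()
        shape = (1, n_heads, max_context_len, head_dim)  # fixed, predictable layout
        self.register_buffer("k_cache", torch.zeros(shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(shape), persistent=False)

    def update(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # k, v: [1, n_heads, seq_len, head_dim]; input_pos: [seq_len] absolute positions.
        # Write new entries at the given positions instead of using Python counters.
        self.k_cache.index_copy_(2, input_pos, k)
        self.v_cache.index_copy_(2, input_pos, v)
        return self.k_cache, self.v_cache
```

Because the update is driven by `input_pos` rather than a Python-side counter, the same exported graph can serve both prefill and decode.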
+
+For attention:
+
+- Prefer standard `torch.nn.functional.scaled_dot_product_attention` or an
+  equivalent module boundary that can later be swapped for backend-specific
+  attention.
+- Keep query, key, value, mask, and cache shapes explicit and stable.
+- First make the model exportable and correct, then apply SDPA, cache,
+  quantization, and backend transforms for the targets you care about.
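
A companion sketch of an attention step that keeps SDPA at a stable, swappable boundary while reusing such a cache (projections are omitted and shapes are illustrative):

```python
import torch
import torch.nn.functional as F


def attention_step(
    q: torch.Tensor,          # [1, n_heads, seq_len, head_dim]
    k: torch.Tensor,          # [1, n_heads, seq_len, head_dim]
    v: torch.Tensor,          # [1, n_heads, seq_len, head_dim]
    mask: torch.Tensor,       # [seq_len, max_context_len] boolean causal mask
    kv_cache: "SimpleKVCache",
    input_pos: torch.Tensor,  # [seq_len] absolute positions
) -> torch.Tensor:
    # Append the new keys/values, then attend over everything cached so far.
    k_all, v_all = kv_cache.update(input_pos, k, v)
    # Standard SDPA kept at a clean boundary so it can later be swapped for an
    # optimized or backend-specific implementation with the same tensor layout.
    return F.scaled_dot_product_attention(q, k_all, v_all, attn_mask=mask)
```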
+
+ExecuTorch includes optimized SDPA and cache-update custom operators used by the
+Llama export flow. You can leverage those paths when your model's attention
+layout matches the expected query/key/value/cache conventions. If your attention
+layout is different, it is usually better to adapt the module boundary first
+than to force the custom operator into an incompatible shape.
+
+## Reusing LLM components
+
+You do not need to copy the Llama implementation to build a custom model. The
+`extension/llm` tree contains reusable pieces that are useful when adapting a
+model for export:
+
+- [`extension/llm/modules`](https://github.com/pytorch/executorch/tree/main/extension/llm/modules)
+  contains export-friendly modules.
+- [`KVCache`](https://github.com/pytorch/executorch/blob/main/extension/llm/modules/kv_cache.py)
+  provides an export-friendly cache implementation adapted from torchtune.
+- [`MultiHeadAttention`](https://github.com/pytorch/executorch/blob/main/extension/llm/modules/attention.py)
+  factors SDPA out of the attention module so it can be replaced with optimized
+  implementations.
+- [`examples/models/llama/source_transformation`](https://github.com/pytorch/executorch/tree/main/examples/models/llama/source_transformation)
+  shows how the Llama flow swaps in custom SDPA, custom KV cache, quantized KV
+  cache, and backend-specific attention variants.
+
+These components are most useful as building blocks and reference
+implementations. Keep your model architecture readable and close to the original
+PyTorch version first, then replace individual pieces only when they improve
+export compatibility, runner compatibility, or backend performance.
+
+If you are authoring or fine-tuning a transformer model from scratch, also look
+at [torchtune](https://github.com/pytorch/torchtune). Several ExecuTorch LLM
+modules are adapted from torchtune modules with changes for export and inference.
+
 ## Backend delegation
 
 While ExecuTorch provides a portable, cross-platform implementation for all
@@ -110,6 +276,35 @@ To delegate the exported model to a specific backend, we need to import its
 partitioner as well as edge compile config from ExecuTorch codebase first, then
 call `to_edge_transform_and_lower`.
 
+If you also added runner metadata earlier, pass the same metadata through the
+`constant_methods` argument in this call so the delegated `.pte` keeps the same
+runner-visible values.
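
A minimal sketch of that call for an XNNPACK target, reusing the `metadata` and `edge_config` values defined earlier (illustrative only; the full nanoGPT example appears below):

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

edge_manager = to_edge_transform_and_lower(
    traced_model,                        # output of torch.export.export()
    partitioner=[XnnpackPartitioner()],
    constant_methods=metadata,           # same runner metadata as the non-delegated export
    compile_config=edge_config,
)
et_program = edge_manager.to_executorch()
```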
+
+For custom LLMs, backend performance depends on how much of the model graph the
+backend can recognize and delegate. Keep the following in mind when adapting the
+model:
+
+- Inspect delegated and non-delegated operators after lowering with
+  `get_delegation_info()`.
+- Prefer linear and attention patterns that leave model weights visible as
+  constants to the partitioner.
+- Be careful with dynamic shapes inside delegated subgraphs. Dynamic prefill can
+  be useful, but not every dynamic pattern is backend-friendly.
+- Export separate `.pte` files for different targets, such as XNNPACK for CPU
+  and Core ML for Apple devices.
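
For the first point, a short sketch of inspecting delegation after lowering, assuming `edge_manager` is the result of `to_edge_transform_and_lower` above:

```python
from executorch.devtools.backend_debug import get_delegation_info

# Summarize which operators were delegated to the backend and which were not.
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
```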
+
+When targeting XNNPACK, use the XNNPACK partitioner and quantization flow. For
+details, see the [XNNPACK backend overview](../backends/xnnpack/xnnpack-overview.md),
+[XNNPACK quantization](../backends/xnnpack/xnnpack-quantization.md), and
+[XNNPACK troubleshooting](../backends/xnnpack/xnnpack-troubleshooting.md).
+
+When targeting Core ML, follow the Core ML backend configuration and validate on
+the target Apple OS and hardware. Stateful KV cache and fused SDPA support can be
+backend- and OS-version dependent. For details, see the
+[Core ML backend overview](../backends/coreml/coreml-overview.md),
+[Core ML partitioner](../backends/coreml/coreml-partitioner.md), and
+[Core ML troubleshooting](../backends/coreml/coreml-troubleshooting.md).
+
 Here's an example of how to delegate nanoGPT to XNNPACK (if you're deploying to an Android phone for instance):
 
 ```python
