
Commit c90a466

docs: update torch.compile documentation

1 parent c341d98

3 files changed: 63 additions & 6 deletions


docsrc/dynamo/torch_compile.rst

Lines changed: 56 additions & 3 deletions
@@ -46,7 +46,20 @@ Custom Setting Usage
     "optimization_level": 4,
     "use_python_runtime": False,})
 
-.. note:: Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+.. note:: Torch-TensorRT supports FP32, FP16, and INT8 precision layers. For INT8 quantization, use the TensorRT Model Optimizer (modelopt) for post-training quantization (PTQ). See :ref:`vgg16_ptq` for an example.
+
+Advanced Precision Control
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For fine-grained control over mixed precision execution, TensorRT 10.12+ provides additional settings:
+
+* ``use_explicit_typing``: Enable explicit type specification (required for TensorRT 10.12+)
+* ``enable_autocast``: Enable rule-based autocast for automatic precision selection
+* ``autocast_low_precision_type``: Target precision for autocast (e.g., ``torch.float16``)
+* ``autocast_excluded_nodes``: Specific nodes to exclude from autocast
+* ``autocast_excluded_ops``: Operation types to exclude from autocast
+
+For detailed information and examples, see :ref:`mixed_precision`.
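
As a minimal sketch of how these added options might be combined, assuming the ``torch_tensorrt`` backend accepts them through ``options`` just as it does the settings shown earlier (the small ``Sequential`` model and the option values are illustrative only):

.. code-block:: python

    import torch
    import torch_tensorrt  # registers the "torch_tensorrt" backend

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()

    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        dynamic=False,
        options={
            "use_explicit_typing": True,                   # explicit type specification (TensorRT 10.12+)
            "enable_autocast": True,                       # rule-based autocast
            "autocast_low_precision_type": torch.float16,  # target precision for autocast
            # "autocast_excluded_nodes": {...},            # node names to keep at full precision (assumed form)
            # "autocast_excluded_ops": {...},              # op types to keep at full precision (assumed form)
        },
    )
    optimized_model(inputs)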
 
 Compilation
 -----------------
@@ -98,14 +111,54 @@ Compilation can also be helpful in demonstrating graph breaks and the feasibilit
     print(f"Graph breaks: {explanation.graph_break_count}")
     optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, options={"truncate_long_and_double": True})
 
+Engine Caching
+^^^^^^^^^^^^^^^^^
+Engine caching can significantly reduce recompilation times by saving built TensorRT engines to disk and reusing them when possible. This is particularly useful for JIT workflows where graphs may be invalidated and recompiled. When enabled, engines are saved with a hash of their corresponding PyTorch subgraph and can be reloaded in subsequent compilations, even across different Python sessions.
+
+To enable engine caching, use the ``cache_built_engines`` and ``reuse_cached_engines`` options:
+
+.. code-block:: python
+
+    import torch_tensorrt
+    ...
+    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
+                                    options={"cache_built_engines": True,
+                                             "reuse_cached_engines": True,
+                                             "immutable_weights": False,
+                                             "engine_cache_dir": "/tmp/torch_trt_cache",
+                                             "engine_cache_size": 1 << 30})  # 1GB
+
+.. note:: To use engine caching, ``immutable_weights`` must be set to ``False`` to allow engine refitting. When a cached engine is loaded, weights are refitted rather than rebuilding the entire engine, which can reduce compilation times by orders of magnitude.
+
+For more details and examples, see :ref:`engine_caching_example`.
+
 Dynamic Shape Support
 -----------------
 
-The Torch-TensorRT `torch.compile` backend will currently require recompilation for each new batch size encountered, and it is preferred to use the `dynamic=False` argument when compiling with this backend. Full dynamic shape support is planned for a future release.
+The Torch-TensorRT `torch.compile` backend now supports dynamic shapes, allowing models to handle varying input dimensions without recompilation. You can specify dynamic dimensions using the ``torch._dynamo.mark_dynamic`` API:
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    ...
+    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
+    # Mark dimension 0 (batch) as dynamic with range [1, 8]
+    torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
+    optimized_model = torch.compile(model, backend="tensorrt")
+    optimized_model(inputs)  # First compilation
+
+    # No recompilation for a different batch size within the dynamic range
+    inputs_bs4 = torch.randn((4, 3, 224, 224), dtype=torch.float32).cuda()
+    optimized_model(inputs_bs4)
+
+Without dynamic shapes, the model will recompile for each new input shape encountered. For more control over dynamic shapes, consider using the AOT compilation path with ``torch_tensorrt.compile`` as described in :ref:`dynamic_shapes`. For a complete tutorial on dynamic shape compilation, see :ref:`compile_with_dynamic_inputs`.
 
 Recompilation Conditions
 -----------------
 
-Once the model has been compiled, subsequent inference inputs with the same shape and data type, which traverse the graph in the same way will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session. Support for engine cache serialization is planned for a future release.
+Once the model has been compiled, subsequent inference inputs with the same shape and data type that traverse the graph in the same way will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session.
+
+To persist engine caches across Python sessions, use the ``cache_built_engines`` and ``reuse_cached_engines`` options as described in the Engine Caching section above.
 
 Recompilation is generally triggered by one of two events: encountering inputs of different sizes or inputs which traverse the model code differently. The latter scenario can occur when the model code includes conditional logic, complex loops, or data-dependent shapes. `torch.compile` handles guarding in both of these scenarios and determines when recompilation is necessary.
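
Recompilation driven by control flow can be surfaced with ``torch._dynamo.explain``, which this page already uses for graph-break inspection. A minimal sketch, assuming a toy model with a data-dependent branch:

.. code-block:: python

    import torch

    class ConditionalModel(torch.nn.Module):  # illustrative model with a data-dependent branch
        def forward(self, x):
            if x.sum() > 0:  # condition on a tensor value: triggers a graph break and guards
                return torch.relu(x)
            return -x

    model = ConditionalModel().cuda()
    inputs = torch.randn((2, 4)).cuda()

    # Inspect graph breaks before choosing a compilation strategy
    explanation = torch._dynamo.explain(model)(inputs)
    print(f"Graph breaks: {explanation.graph_break_count}")

    # An input taking the other branch traverses the model differently,
    # so torch.compile's guards would force a recompilation for it.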

docsrc/getting_started/installation.rst

Lines changed: 4 additions & 3 deletions
@@ -15,6 +15,7 @@ You need to have CUDA, PyTorch, and TensorRT (python package is sufficient) inst
 
 * https://developer.nvidia.com/cuda
 * https://pytorch.org
+* TensorRT 10.0 or later (TensorRT 10.12+ recommended for latest features like explicit typing)
 
 
 Installing Torch-TensorRT
@@ -208,9 +209,9 @@ A tarball with the include files and library can then be found in ``bazel-bin``
 Choosing the Right ABI
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
-For the old versions, there were two ABI options to compile Torch-TensorRT which were incompatible with each other,
-pre-cxx11-abi and cxx11-abi. The complexity came from the different distributions of PyTorch. Fortunately, PyTorch
-has switched to cxx11-abi for all distributions. Below is a table with general pairings of PyTorch distribution
+In older versions, there were two mutually incompatible ABI options for compiling Torch-TensorRT:
+pre-cxx11-abi and cxx11-abi. The complexity came from the different distributions of PyTorch. Fortunately, PyTorch
+has switched to cxx11-abi for all distributions. Below is a table with general pairings of PyTorch distribution
 sources and the recommended commands:
 
 +-------------------------------------------------------------+----------------------------------------------------------+--------------------------------------------------------------------+

docsrc/ts/ptq.rst

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,9 @@
 Post Training Quantization (PTQ)
 =================================
 
+.. warning::
+    This guide describes the legacy PTQ workflow for the TorchScript frontend. **For new projects, use the TensorRT Model Optimizer (modelopt) with the Dynamo frontend instead.** See :ref:`vgg16_ptq` for the recommended approach.
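
As a minimal sketch of the recommended modelopt flow, assuming ``mtq.INT8_DEFAULT_CFG`` and a small stand-in model with synthetic calibration data (see :ref:`vgg16_ptq` for the full worked example):

.. code-block:: python

    import torch
    import torch_tensorrt
    import modelopt.torch.quantization as mtq

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()

    # Stand-in calibration data; use samples from the target domain in practice
    calib_data = [torch.randn((8, 3, 32, 32)).cuda() for _ in range(16)]

    def forward_loop(m):
        for batch in calib_data:
            m(batch)

    # Post-training INT8 quantization via modelopt
    model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

    # Compile the quantized model with the Dynamo frontend
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=[torch.randn((8, 3, 32, 32)).cuda()],
        enabled_precisions={torch.int8},
    )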
+
 Post Training Quantization (PTQ) is a technique to reduce the required computational resources for inference
 while still preserving the accuracy of your model by mapping the traditional FP32 activation space to a reduced
 INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain
