
Commit d4083cc

Author: Bowen Fu (committed)
feat(annotation): add tta.custom_plugin — Triton/CuTile/CuTeDSL QDP AOT integration
Adds the `torch_tensorrt.annotation` (tta) module, which lets users register
custom QDP plugins backed by Triton, CuTile, or CuTe DSL kernels and compile
them AOT into TensorRT engines — no Python interpreter required at inference
time.

Public API surface
------------------

* `tta.triton(launch_fn, configs=...)` → TritonSpec
* `tta.cutile(launch_fn, arch=..., configs=...)` → CuTileSpec
* `tta.cutedsl(launch_fn, configs=..., arch=...)` → CuTeDSLSpec
* `tta.custom_plugin(spec_or_list, meta_impl=...)` → CustomPluginSpec
* `tta.normalize_impl_to_spec(impl)` — coerce bare specs to CustomPluginSpec

Integration path
----------------

CustomPluginSpec plugs into the existing `trt_plugins.custom_op()` /
`generate_plugin_converter()` infrastructure: it registers the QDP descriptor,
autotune callbacks, and AOT impl, then hands the resulting converter to the
standard Dynamo lowering pipeline.

AOT compilation backends
------------------------

* **Triton** — kernel compiled to PTX via `triton.compile`; PTX header patched
  to match TRT 10.16's expected `.version` before handoff.
* **CuTile** — compiled via `cuda_tile.compile`.
* **CuTe DSL** — compiled via `cutlass.cute.compile` with optional arch
  override.

TvmFfiSpec (fourth backend) is planned but blocked on TVM FFI support landing
in QDP.

Key implementation details
--------------------------

* `FakeTensorMode` used for `meta_impl` shape/dtype inference — avoids real
  tensor allocation during registration.
* `_assign_recorded_grid` shared helper centralises the grid-assignment step
  reused by all three AOT backends.
* `_lowering.py` and `_impl.py` removed — both were dead code superseded by
  `_descriptor.py` and the dynamo plugin pipeline.
* 3 `test_gated_ffn_block` tests marked `@expectedFailure`: TRT's
  `mergeMatmulLayers` pass delivers non-contiguous sub-region buffers to
  `IPluginV3::enqueue`, violating the LINEAR stride contract. The
  `_contiguous` variants (separate inputs, no fusion) pass as a workaround.
1 parent 2e26bfa commit d4083cc

25 files changed: +7769 −20 lines
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# torch_tensorrt.annotation — custom_plugin descriptors

`torch_tensorrt.annotation` (aliased as `tta`) provides descriptor types and
factory functions for defining custom TensorRT AOT QDP plugins backed by
Triton, CuTile, or CuTeDSL kernels.

```python
import torch_tensorrt.annotation as tta
```

This module is **descriptor-only**: it builds spec objects that describe how a
plugin should be compiled and registered. It does not patch `torch.export`,
add compilation hooks, or modify any torch-trt core path.

---

## Table of contents

1. [Quick start](#1-quick-start)
2. [Factory functions](#2-factory-functions)
3. [Spec types](#3-spec-types)
4. [QDP plugin flow](#4-qdp-plugin-flow)
5. [Running tests](#5-running-tests)

---

## 1. Quick start

```python
import torch_tensorrt.annotation as tta

# Triton AOT plugin
spec = tta.custom_plugin(tta.triton(my_launch_fn, configs=[{"BLOCK_SIZE": 128}]))

# CuTile plugin (Blackwell sm_100+)
spec = tta.custom_plugin(tta.cutile(my_cutile_kernel, arch=120))

# CuTeDSL plugin
spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_kernel))
```
41+
42+
---
43+
44+
## 2. Factory functions
45+
46+
### `tta.triton(launch_fn, configs=None)`
47+
48+
Wraps a Triton kernel launch function.
49+
50+
- **`launch_fn`** — callable that launches the Triton kernel;
51+
signature `(input0, ..., output, stream, **config)`.
52+
- **`configs`** — list of `dict` tactic configs; each becomes a separate
53+
QDP tactic. Pass `None` for a single default tactic.
54+
55+
```python
56+
@triton.jit
57+
def _add_relu_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
58+
i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
59+
mask = i < n
60+
tl.store(out_ptr + i, tl.maximum(0, tl.load(x_ptr+i, mask=mask) + tl.load(y_ptr+i, mask=mask)), mask=mask)
61+
62+
def launch_add_relu(x, y, out, stream, BLOCK=256):
63+
_add_relu_kernel[(triton.cdiv(x.numel(), BLOCK),)](x, y, out, x.numel(), BLOCK=BLOCK)
64+
65+
spec = tta.custom_plugin(tta.triton(launch_add_relu, configs=[{"BLOCK": 128}, {"BLOCK": 256}]))
66+
```
67+
68+
### `tta.cutile(launch_fn, arch=None, configs=None)`
69+
70+
Wraps a CuTile (cuda-tile) kernel. Requires Blackwell (sm_100+) and the
71+
`cuda-tile` package.
72+
73+
- **`arch`** — SM architecture integer (e.g. `120` for sm_120).
74+
- **`configs`** — list of tactic dicts.
75+
76+
```python
77+
spec = tta.custom_plugin(tta.cutile(my_cutile_fn, arch=120, configs=[{"TILE_M": 64}]))
78+
```
79+
80+
### `tta.cutedsl(launch_fn, configs=None)`
81+
82+
Wraps a CuTeDSL kernel (`nvidia-cutlass-dsl`).
83+
84+
```python
85+
spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_fn))
86+
```
87+
88+
### `tta.custom_plugin(impl)`
89+
90+
Builds a `CustomPluginSpec` from a kernel spec (`TritonSpec`, `CuTileSpec`, or
91+
`CuTeDSLSpec`). Computes a deterministic QDP `op_name` from the kernel
92+
function identity and config hash.
93+
94+
```python
95+
spec = tta.custom_plugin(tta.triton(launch_fn, configs=[{"BLOCK": 256}]))
96+
# spec.op_name — deterministic string like "tta::launch_fn_a3f2c1"
97+
```
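The deterministic op-name derivation can be sketched as follows. This is a hypothetical reconstruction of the idea, not the module's actual implementation; `derive_op_name` and its hashing scheme are assumptions made for illustration:

```python
import hashlib
import json

def derive_op_name(launch_fn, configs):
    """Hypothetical sketch: fingerprint = kernel name + hash of sorted configs."""
    payload = json.dumps(configs or [], sort_keys=True).encode()
    digest = hashlib.sha256(launch_fn.__name__.encode() + b"|" + payload).hexdigest()[:6]
    return f"tta::{launch_fn.__name__}_{digest}"

def launch_fn(x, out, stream):
    pass

# Same kernel + same configs always yields the same op_name.
print(derive_op_name(launch_fn, [{"BLOCK": 256}]))
```

Determinism matters here because the same spec may be constructed independently in multiple processes (e.g. pytest-xdist workers) and must resolve to the same QDP registry entry.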
---

## 3. Spec types

All spec types are plain frozen dataclasses — they carry no mutable state and
are safe to hash, compare, and cache.

| Type | Returned by | Description |
|------|-------------|-------------|
| `CustomPluginSpec` | `custom_plugin()` | AOT QDP plugin descriptor; holds `impl` (`TritonSpec` \| `CuTileSpec` \| `CuTeDSLSpec`) and computed `op_name` |
| `TritonSpec` | `triton()` | Triton kernel launch function + tactic configs |
| `CuTileSpec` | `cutile()` | CuTile kernel + target arch + tactic configs |
| `CuTeDSLSpec` | `cutedsl()` | CuTeDSL kernel + tactic configs |
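Because the specs are frozen dataclasses, equal field values mean equal (and equally hashable) specs, so they can serve directly as dictionary keys — for example, to cache compiled artifacts per spec. A minimal sketch of the pattern, using a hypothetical stand-in type rather than the real `TritonSpec`:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SpecSketch:  # hypothetical stand-in for TritonSpec et al.
    fn_name: str
    configs: tuple  # tuples rather than lists, so the dataclass stays hashable

s1 = SpecSketch("launch_add_relu", (("BLOCK", 128),))
s2 = SpecSketch("launch_add_relu", (("BLOCK", 128),))

compiled_cache = {s1: "ptx-blob"}
# A separately constructed but field-equal spec hits the same cache entry.
assert s1 == s2 and compiled_cache[s2] == "ptx-blob"
```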
---

## 4. QDP plugin flow

`tta.custom_plugin` produces a descriptor. When you call
`register_custom_plugin(spec, inputs)` (from `_custom_plugin._descriptor`),
the module:

1. Derives a deterministic `op_name` from the kernel function + config hash.
2. Registers `@trtp.register("tta::op_name")` — the shape/dtype descriptor
   function, derived symbolically from the kernel signature.
3. Registers `@trtp.aot_impl("tta::op_name")` — the AOT implementation
   function that returns `(kernel_name, ptx_or_cubin, KernelLaunchParams,
   SymIntExprs)`.
4. Uses a process-level lock with double-checked locking to prevent duplicate
   registration in multi-threaded pytest-xdist workers.
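The double-checked locking in step 4 follows the standard idiom: check without the lock on the fast path, then re-check under the lock before registering. A self-contained sketch of the pattern (names are illustrative, not the module's actual helpers):

```python
import threading

_registered: set = set()
_lock = threading.Lock()

def register_once(op_name, do_register):
    # Fast path: already registered, no lock acquisition needed.
    if op_name in _registered:
        return False
    with _lock:
        # Re-check under the lock: another thread may have won the race
        # between our unlocked check and acquiring the lock.
        if op_name in _registered:
            return False
        do_register(op_name)
        _registered.add(op_name)
        return True
```

The unlocked first check keeps the common case (plugin already registered) contention-free; the locked second check guarantees `do_register` runs at most once per `op_name`.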
129+
130+
The QDP AOT path means **no Python is needed at TRT engine runtime** — the
131+
compiled kernel is embedded directly.
132+
133+
```
134+
tta.triton(launch_fn, configs)
135+
└─► TritonSpec
136+
└─► tta.custom_plugin(spec)
137+
└─► CustomPluginSpec(op_name, impl)
138+
└─► register_custom_plugin(spec, inputs)
139+
├─► @trtp.register (shape descriptor)
140+
└─► @trtp.aot_impl (PTX/cubin → TRT)
141+
```
142+
143+
---
144+
145+
## 5. Running tests
146+
147+
Unit tests are CPU-only (no GPU required) and live in
148+
`tests/py/annotation/unit/`.
149+
150+
```bash
151+
# From inside the dev Docker container:
152+
python -m pytest tests/py/annotation/unit/ -n 4 --tb=short -v
153+
```
154+
155+
Test files:
156+
157+
| File | What it covers |
158+
|------|---------------|
159+
| `test_specs.py` | `TritonSpec`, `CuTileSpec`, `CuTeDSLSpec` construction and hashing |
160+
| `test_specs_custom_plugin.py` | `CustomPluginSpec` and `custom_plugin()` factory |
161+
| `test_signature_binder.py` | TRT signature derivation and binding |
162+
| `test_layer_metadata.py` | `AnnotationMetadata` encode/decode round-trip |
163+
| `test_plugin_lowering.py` | QDP plugin lowering path |
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
"""
Torch-TensorRT Annotation Layer (TTA) — custom_plugin API.

Provides descriptor types and factory functions for defining custom TensorRT
AOT QDP plugins backed by Triton, CuTile, or CuTeDSL kernels.

Usage::

    import torch_tensorrt.annotation as tta

    # Triton kernel descriptor
    spec = tta.custom_plugin(
        tta.triton(my_triton_kernel, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    )

    # CuTile kernel descriptor
    spec = tta.custom_plugin(
        tta.cutile(my_cutile_kernel, arch=120),
        meta_impl=lambda x: x.new_empty(x.shape),
    )

    # CuTeDSL kernel descriptor
    spec = tta.custom_plugin(
        tta.cutedsl(my_cutedsl_kernel),
        meta_impl=lambda x: x.new_empty(x.shape),
    )
"""

from ._errors import TTADiagnosticError
from ._specs import (
    KernelImplSpec,
    CuTeDSLSpec,
    CuTileSpec,
    TritonSpec,
    cutedsl,
    cutile,
    triton,
    normalize_impl_to_spec,
)
from ._custom_plugin._descriptor import CustomPluginSpec, custom_plugin

__all__ = [
    # Error types
    "TTADiagnosticError",
    # Descriptor types
    "KernelImplSpec",
    "TritonSpec",
    "CuTileSpec",
    "CuTeDSLSpec",
    "CustomPluginSpec",
    # Factory functions
    "custom_plugin",
    "triton",
    "cutile",
    "cutedsl",
    # Utilities
    "normalize_impl_to_spec",
]
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
"""Custom plugin sub-package: QDP-backed plugin descriptor and lowering.

This sub-package bridges user-supplied GPU kernels (Triton, cuTILE, CuTe DSL)
to TensorRT's Quickly Deployable Plugin (QDP) framework, enabling annotated
boundary ops to be lowered to first-class ``IPluginV3`` layers at TRT engine
compile time.

Role in the compilation pipeline
--------------------------------
1. **Annotation** — the user calls ``tta.custom_plugin(kernel, meta_impl=...)``
   (re-exported here as ``custom_plugin``), which returns a
   ``CustomPluginSpec`` that is stored in the boundary op's
   ``AnnotationMetadata``.
2. **Registration** — at compile time, ``register_custom_plugin`` calls
   ``@trtp.register`` / ``@trtp.aot_impl`` on the descriptor, making the
   plugin visible to TRT's global plugin registry.
3. **Lowering** — ``lower_custom_plugin_descriptor`` calls ``trtp.op.<ns>.<name>``
   to insert an ``IPluginV3`` layer into the ``INetworkDefinition``.

Public surface
--------------
``CustomPluginSpec``
    Dataclass returned by ``custom_plugin()``. Carries the op name, kernel
    specs, meta-shape implementation, and optional tactic table.

``custom_plugin``
    Factory that builds a ``CustomPluginSpec`` from a kernel spec,
    auto-computing a deterministic QDP op name from the kernel fingerprint.

``lower_custom_plugin_descriptor``
    Converts a ``CustomPluginSpec`` into a TRT ``IPluginV3`` layer and
    returns the output ``trt.ITensor`` (or tuple thereof).

``register_custom_plugin``
    Registers ``@trtp.register`` / ``@trtp.aot_impl`` handlers for a
    descriptor's op name. Idempotent at the process level.

``QDPRuntimeError``
    Raised when TRT's QDP framework encounters a runtime error (e.g. shape
    mismatch, unsupported dtype) during plugin execution.

``TTAPluginError``
    Raised for TTA-level plugin configuration errors (e.g. missing meta_impl,
    invalid kernel spec).

``SymbolicTensor`` / ``TensorRole``
    Proxy used during AOT kernel compilation to carry symbolic shape
    expressions. ``TensorRole`` distinguishes input from output tensors so
    that ``analyze_launch_args`` can reconstruct the correct QDP binding
    indices.
"""

# ---------------------------------------------------------------------------
# Re-exports from sub-modules
# ---------------------------------------------------------------------------

from ._descriptor import (
    CustomPluginSpec,
    custom_plugin,
    lower_custom_plugin_descriptor,
    register_custom_plugin,
)
from ._qdp_utils import QDPRuntimeError, TTAPluginError
from ._symbolic import SymbolicTensor, TensorRole

# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

__all__ = [
    "CustomPluginSpec",
    "custom_plugin",
    "lower_custom_plugin_descriptor",
    "register_custom_plugin",
    "QDPRuntimeError",
    "TTAPluginError",
    "SymbolicTensor",
    "TensorRole",
]
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
"""AOT kernel backend implementations for the TTA custom plugin system.

Each sub-module implements the ``aot_impl_<backend>`` function that drives a
backend-specific compile pipeline and returns the QDP AOT 4-tuple::

    (kernel_name, code_bytes, KernelLaunchParams, SymIntExprs)

Backends
--------
_triton  — Triton JIT kernels compiled to PTX via triton.compile()
_cutile  — cuTILE programs compiled to CUBIN via cuda.tile compile_tile()
_cutedsl — CuTe DSL @cute.jit kernels compiled via cutlass.cute.compile()

All three backends share ``_qdp_utils`` helpers for sandboxing, launch-arg
analysis, artifact dumping, and the unified ``AOTMetadata`` result type.
"""
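The shape of that 4-tuple can be made concrete with a small stand-in. This is illustrative only: the real backends return a plain tuple, and the field names below (and `AOTResult` itself) are assumptions, not names from the codebase:

```python
from typing import NamedTuple

class AOTResult(NamedTuple):
    # Illustrative container mirroring the QDP AOT 4-tuple described above.
    kernel_name: str     # entry-point symbol in the compiled module
    code_bytes: bytes    # PTX (Triton) or CUBIN (cuTILE / CuTe DSL)
    launch_params: dict  # stand-in for KernelLaunchParams (grid/block/shmem)
    symint_exprs: tuple  # stand-in for SymIntExprs (symbolic shape args)

result = AOTResult("add_relu_kernel", b"...ptx...", {"grid": (1, 1, 1)}, ())
# A NamedTuple unpacks exactly like the plain 4-tuple the backends return.
kernel_name, code, params, exprs = result
```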
