
Commit e612445

Author: shoumikhin

[executorch][nvidia][tensorrt][19/n] Add SDPA converter

Add Scaled Dot-Product Attention (SDPA) converter to enable transformer-based attention layers.

Differential Revision: [D93275047](https://our.internmc.facebook.com/intern/diff/D93275047/)

[ghstack-poisoned]
1 parent 37a7e5a commit e612445

5 files changed

Lines changed: 180 additions & 0 deletions
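
For context, this converter targets graphs that contain torch.nn.functional.scaled_dot_product_attention. Below is a minimal sketch of the kind of module and export call involved; the module name and tensor shapes are illustrative, not taken from this commit:

import torch
import torch.nn.functional as F


class TinyAttention(torch.nn.Module):
    """Toy module whose exported graph contains scaled_dot_product_attention."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


# Example inputs shaped (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 8, 16, 64)
exported = torch.export.export(TinyAttention(), (q, k, v))
# The exported graph typically carries aten.scaled_dot_product_attention.default
# nodes, which the converter added in this commit lowers to TensorRT layers.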


backends/nvidia/tensorrt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 from executorch.backends.nvidia.tensorrt.converters import reduction  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import relu  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import reshape  # noqa: F401
+from executorch.backends.nvidia.tensorrt.converters import sdpa  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import sub  # noqa: F401
 from executorch.backends.nvidia.tensorrt.converters import upsample  # noqa: F401

backends/nvidia/tensorrt/converters/sdpa.py

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""TensorRT Converters for Scaled Dot-Product Attention (SDPA).

Supported operations:
- aten.scaled_dot_product_attention: Core attention mechanism for transformers

SDPA computes: softmax(Q @ K^T / sqrt(d_k)) @ V

For TensorRT, we implement this using:
1. Matrix multiply for Q @ K^T
2. Scale by 1/sqrt(d_k)
3. Optional attention mask application
4. Softmax
5. Matrix multiply with V
"""

import logging
import math
from typing import Any, Dict, Optional

import numpy as np
import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import (
    create_constant,
    get_trt_tensor,
)

logger: logging.Logger = logging.getLogger(__name__)


def validate_sdpa(node: torch.fx.Node) -> bool:
    """Validate that an SDPA node can be converted to TensorRT."""
    if node.op != "call_function":
        return False

    args = node.args
    if len(args) < 3:
        return False

    for i in range(3):
        if not isinstance(args[i], torch.fx.Node):
            return False

    return True


@converter("aten.scaled_dot_product_attention.default", validator_fn=validate_sdpa)
def convert_scaled_dot_product_attention(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """Convert PyTorch scaled_dot_product_attention to TensorRT.

    SDPA formula: softmax(Q @ K^T / sqrt(d_k) + mask) @ V
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required for convert_sdpa") from e

    args = node.args
    kwargs = node.kwargs

    query_node = args[0]
    key_node = args[1]
    value_node = args[2]
    attn_mask_node = args[3] if len(args) > 3 else kwargs.get("attn_mask", None)
    is_causal = args[5] if len(args) > 5 else kwargs.get("is_causal", False)
    scale = args[6] if len(args) > 6 else kwargs.get("scale", None)

    query_trt = input_map[query_node]
    key_trt = input_map[key_node]
    value_trt = input_map[value_node]

    query_shape = query_trt.shape
    d_k = query_shape[-1]

    # Calculate scale factor
    if scale is not None:
        scale_factor = float(scale)
    elif d_k > 0:
        scale_factor = 1.0 / math.sqrt(float(d_k))
    else:
        # Dynamic head dimension: fall back to the static shape recorded in node metadata.
        query_meta_shape = None
        if isinstance(query_node, torch.fx.Node) and "val" in query_node.meta:
            val = query_node.meta["val"]
            if hasattr(val, "shape"):
                query_meta_shape = val.shape
        if query_meta_shape is not None and len(query_meta_shape) > 0:
            d_k_static = query_meta_shape[-1]
            scale_factor = 1.0 / math.sqrt(float(d_k_static)) if d_k_static > 0 else 1.0
        else:
            raise RuntimeError(
                f"Cannot determine head dimension for SDPA node {node.name}."
            )

    # Step 1: Q @ K^T
    qk_layer = network.add_matrix_multiply(
        query_trt, trt.MatrixOperation.NONE,
        key_trt, trt.MatrixOperation.TRANSPOSE,
    )
    qk_layer.name = f"sdpa_qk_{node.name}"
    qk = qk_layer.get_output(0)

    # Step 2: Scale by 1/sqrt(d_k)
    scale_const = get_trt_tensor(
        network, scale_factor, f"sdpa_scale_{node.name}", dtype=torch.float32
    )
    scaled_qk_layer = network.add_elementwise(
        qk, scale_const, trt.ElementWiseOperation.PROD
    )
    scaled_qk_layer.name = f"sdpa_scale_{node.name}"
    scaled_qk = scaled_qk_layer.get_output(0)

    # Step 3: Apply attention mask if provided
    if attn_mask_node is not None and isinstance(attn_mask_node, torch.fx.Node):
        if attn_mask_node in input_map:
            attn_mask_trt = input_map[attn_mask_node]
            mask_layer = network.add_elementwise(
                scaled_qk, attn_mask_trt, trt.ElementWiseOperation.SUM
            )
            mask_layer.name = f"sdpa_mask_{node.name}"
            scaled_qk = mask_layer.get_output(0)

    # Step 4: Handle causal masking
    if is_causal:
        seq_len = query_shape[-2] if len(query_shape) >= 2 else -1
        if seq_len > 0:
            causal_mask = np.triu(
                np.full((seq_len, seq_len), float("-inf"), dtype=np.float32), k=1
            )
            causal_mask_trt = create_constant(
                network, causal_mask, f"sdpa_causal_mask_{node.name}"
            )
            causal_layer = network.add_elementwise(
                scaled_qk, causal_mask_trt, trt.ElementWiseOperation.SUM
            )
            causal_layer.name = f"sdpa_causal_{node.name}"
            scaled_qk = causal_layer.get_output(0)

    # Step 5: Softmax along the last dimension (TensorRT expects a bitmask of axes)
    softmax_layer = network.add_softmax(scaled_qk)
    softmax_layer.axes = 1 << (len(query_shape) - 1)
    softmax_layer.name = f"sdpa_softmax_{node.name}"
    attn_weights = softmax_layer.get_output(0)

    # Step 6: attn_weights @ V
    output_layer = network.add_matrix_multiply(
        attn_weights, trt.MatrixOperation.NONE,
        value_trt, trt.MatrixOperation.NONE,
    )
    output_layer.name = f"sdpa_output_{node.name}"

    return output_layer.get_output(0)


@converter("aten._scaled_dot_product_flash_attention.default", validator_fn=validate_sdpa)
def convert_flash_attention(node, network, input_map, edge_program=None):
    """Convert flash attention by reusing the SDPA implementation."""
    return convert_scaled_dot_product_attention(node, network, input_map, edge_program)


@converter("aten._scaled_dot_product_efficient_attention.default", validator_fn=validate_sdpa)
def convert_efficient_attention(node, network, input_map, edge_program=None):
    """Convert efficient attention by reusing the SDPA implementation."""
    return convert_scaled_dot_product_attention(node, network, input_map, edge_program)
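
As a sanity check on the decomposition described in the module docstring, the same pipeline can be written in eager PyTorch. This is only a reference sketch of the math the TensorRT layers implement; it is not code from this commit:

import math

import torch


def sdpa_reference(q, k, v, attn_mask=None, is_causal=False, scale=None):
    """Eager-mode mirror of the converter's Q@K^T -> scale -> mask -> softmax -> @V pipeline."""
    d_k = q.shape[-1]
    scale_factor = scale if scale is not None else 1.0 / math.sqrt(d_k)
    qk = torch.matmul(q, k.transpose(-2, -1)) * scale_factor  # Steps 1-2
    if attn_mask is not None:
        qk = qk + attn_mask  # Step 3: additive attention mask
    if is_causal:
        seq_len = q.shape[-2]
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf")), diagonal=1
        )
        qk = qk + causal  # Step 4: upper-triangular causal mask
    attn = torch.softmax(qk, dim=-1)  # Step 5
    return torch.matmul(attn, v)  # Step 6


# On the same inputs this should match torch.nn.functional.scaled_dot_product_attention
# up to numerical precision, which is a convenient way to spot-check converter output.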

backends/nvidia/tensorrt/converters/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def define_common_targets():
         "reduction.py",
         "relu.py",
         "reshape.py",
+        "sdpa.py",
         "sub.py",
         "upsample.py",
     ],

examples/nvidia/tensorrt/export.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
     "linear",
     "mul",
     "mv3",
+    "sdpa",
     "softmax",
     "w2l",
 }
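
The "sdpa" entry added above presumably names a small example model in the export script's model registry; that registry is not part of this diff. A hypothetical sketch of what such a model might look like (the class name, method name, and shapes are assumptions):

import torch
import torch.nn.functional as F


class SDPAExample(torch.nn.Module):
    """Hypothetical example model behind the "sdpa" registry entry."""

    def forward(self, q, k, v):
        # Causal attention over pre-projected Q, K, V tensors.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

    def get_example_inputs(self):
        # Shapes are illustrative: (batch, heads, seq_len, head_dim).
        shape = (1, 8, 32, 64)
        return (torch.randn(shape), torch.randn(shape), torch.randn(shape))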

examples/nvidia/tensorrt/tests/test_export.py

Lines changed: 3 additions & 0 deletions
@@ -118,3 +118,6 @@ def test_w2l(self) -> None:
 
     def test_ic3(self) -> None:
         _export_and_verify("ic3")
+
+    def test_sdpa(self) -> None:
+        _export_and_verify("sdpa")
