Skip to content

Commit ab6854e

Browse files
author
shoumikhin
committed
[executorch][nvidia][tensorrt][20/n] Add bmm converter
Add batch matrix multiplication (bmm) converter to enable transformer attention layers and batch matrix operations. Differential Revision: [D93275048](https://our.internmc.facebook.com/intern/diff/D93275048/) [ghstack-poisoned]
1 parent e612445 commit ab6854e

5 files changed

Lines changed: 65 additions & 5 deletions

File tree

backends/nvidia/tensorrt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.backends.nvidia.tensorrt.converters import add # noqa: F401
1212
from executorch.backends.nvidia.tensorrt.converters import addmm # noqa: F401
1313
from executorch.backends.nvidia.tensorrt.converters import batch_norm # noqa: F401
14+
from executorch.backends.nvidia.tensorrt.converters import bmm # noqa: F401
1415
from executorch.backends.nvidia.tensorrt.converters import clamp # noqa: F401
1516
from executorch.backends.nvidia.tensorrt.converters import concat # noqa: F401
1617
from executorch.backends.nvidia.tensorrt.converters import conv2d # noqa: F401
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Converter for batch matrix multiplication operations."""
8+
9+
from typing import Any, Dict, Optional
10+
11+
import tensorrt as trt
12+
import torch
13+
from executorch.backends.nvidia.tensorrt.converter_registry import converter
14+
from executorch.backends.nvidia.tensorrt.converter_utils import set_layer_name
15+
16+
17+
@converter("aten.bmm.default")
18+
def convert_bmm(
19+
node: torch.fx.Node,
20+
network: trt.INetworkDefinition,
21+
input_map: Dict[torch.fx.Node, Any],
22+
edge_program: Optional[Any] = None,
23+
) -> trt.ITensor:
24+
"""Convert aten.bmm.default to TensorRT MatrixMultiply.
25+
26+
Performs batch matrix multiplication of two 3D tensors (B, M, K) @ (B, K, N) -> (B, M, N).
27+
TensorRT's IMatrixMultiplyLayer supports batch matrix multiplication natively.
28+
"""
29+
lhs_arg = node.args[0]
30+
rhs_arg = node.args[1]
31+
32+
if lhs_arg not in input_map:
33+
raise ValueError(f"Input node '{lhs_arg.name}' not found in input_map for bmm")
34+
if rhs_arg not in input_map:
35+
raise ValueError(f"Input node '{rhs_arg.name}' not found in input_map for bmm")
36+
37+
lhs = input_map[lhs_arg]
38+
rhs = input_map[rhs_arg]
39+
40+
layer = network.add_matrix_multiply(
41+
lhs, trt.MatrixOperation.NONE, rhs, trt.MatrixOperation.NONE
42+
)
43+
set_layer_name(layer, node, "bmm")
44+
45+
return layer.get_output(0)

backends/nvidia/tensorrt/converters/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_common_targets():
1515
"add.py",
1616
"addmm.py",
1717
"batch_norm.py",
18+
"bmm.py",
1819
"clamp.py",
1920
"concat.py",
2021
"conv2d.py",

examples/nvidia/tensorrt/export.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
"conv1d",
3939
"dl3",
4040
"edsr",
41+
# "efficient_sam", # TODO: diff ~41 — likely bicubic interpolation decomposition or ConvTranspose2d issue
42+
"emformer_join",
43+
# "emformer_predict", # TODO: passes 1/3 seeds — precision sensitive with randomized inputs
44+
"emformer_transcribe",
4145
"ic3",
4246
"linear",
4347
"mul",
@@ -126,6 +130,7 @@ def _verify_correctness(
126130

127131
et_module = _load_for_executorch_from_buffer(pte_bytes)
128132

133+
129134
for seed in _TEST_SEEDS:
130135
inputs = _randomise_inputs(example_inputs, seed)
131136

examples/nvidia/tensorrt/tests/test_export.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,7 @@ def _populate_weight_cache() -> None:
4040
for env_var, filename in _WEIGHT_ENV_VARS.items():
4141
src = os.environ.get(env_var)
4242
if src and os.path.isfile(src):
43-
# dog.jpg goes to CWD (mv2 model downloads it there)
44-
if filename == "dog.jpg":
45-
dst = os.path.join(os.getcwd(), filename)
46-
else:
47-
dst = os.path.join(cache_dir, filename)
43+
dst = os.path.join(cache_dir, filename)
4844
if not os.path.exists(dst):
4945
shutil.copy2(src, dst)
5046
logger.info(f"Cached {filename} from {src}")
@@ -121,3 +117,15 @@ def test_ic3(self) -> None:
121117

122118
def test_sdpa(self) -> None:
123119
_export_and_verify("sdpa")
120+
121+
def test_emformer_join(self) -> None:
122+
_export_and_verify("emformer_join")
123+
124+
def test_softmax(self) -> None:
125+
_export_and_verify("softmax")
126+
127+
def test_mv3(self) -> None:
128+
_export_and_verify("mv3")
129+
130+
def test_ic3(self) -> None:
131+
_export_and_verify("ic3")

0 commit comments

Comments
 (0)