Commit ab98a78

chore: resolve failures in CI - disable the plugin for Orin and for TRT-RTX
1 parent 65a6c46

5 files changed

Lines changed: 47 additions & 9 deletions


core/plugins/BUILD

Lines changed: 10 additions & 0 deletions
@@ -62,6 +62,11 @@ config_setting(
 cc_library(
     name = "torch_tensorrt_plugins",
     srcs = select({
+        ":jetpack": [
+            "impl/interpolate_plugin.cpp",
+            "impl/normalize_plugin.cpp",
+            "register_plugins.cpp",
+        ],
         ":rtx_win": [],
         ":rtx_x86_64": [],
         "//conditions:default": [
@@ -72,6 +77,11 @@ cc_library(
         ],
     }),
     hdrs = select({
+        ":jetpack": [
+            "impl/interpolate_plugin.h",
+            "impl/normalize_plugin.h",
+            "plugins.h",
+        ],
         ":rtx_win": [],
         ":rtx_x86_64": [],
         "//conditions:default": [

core/plugins/register_plugins.cpp

Lines changed: 4 additions & 0 deletions
@@ -1,8 +1,12 @@
 #include "NvInferPlugin.h"
 #include "NvInferPluginUtils.h"
+#include "NvInferVersion.h"
 #include "core/plugins/impl/interpolate_plugin.h"
 #include "core/plugins/impl/normalize_plugin.h"
+// ScatterAdd plugin is not available on Jetpack (TRT 10.3.x / L4T) or TRT-RTX builds
+#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 3)
 #include "core/plugins/impl/scatter_add_plugin.h"
+#endif
 #include "core/plugins/plugins.h"
 #include "core/util/prelude.h"
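The guard keys off NvInferVersion.h, so the ScatterAdd sources compile only against TensorRT newer than 10.3. For reference, the same "newer than TRT 10.3" gate can be expressed at the Python level; a minimal sketch, where the helper name and version parsing are illustrative and not part of this commit:

import tensorrt as trt

def _newer_than_trt_10_3() -> bool:
    # Mirrors the C++ guard: major > 10, or major == 10 with minor > 3.
    major, minor = (int(v) for v in trt.__version__.split(".")[:2])
    return major > 10 or (major == 10 and minor > 3)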

py/torch_tensorrt/_features.py

Lines changed: 12 additions & 9 deletions
@@ -52,17 +52,20 @@
 _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()
 _TRTLLM_AVAIL = load_tensorrt_llm_for_nccl()
 
-if importlib.util.find_spec("tensorrt.plugin") and importlib.util.find_spec(
-    "tensorrt.plugin._lib"
-):
-    # there is a bug in tensorrt 10.14.* and 10.15.* that causes the plugin to not work, disable it for now
-    if tensorrt.__version__.startswith("10.15.") or tensorrt.__version__.startswith(
-        "10.14."
+try:
+    if importlib.util.find_spec("tensorrt.plugin") and importlib.util.find_spec(
+        "tensorrt.plugin._lib"
     ):
-        _QDP_PLUGIN_AVAIL = False
+        # there is a bug in tensorrt 10.14.* and 10.15.* that causes the plugin to not work, disable it for now
+        if tensorrt.__version__.startswith("10.15.") or tensorrt.__version__.startswith(
+            "10.14."
+        ):
+            _QDP_PLUGIN_AVAIL = False
+        else:
+            _QDP_PLUGIN_AVAIL = True
     else:
-        _QDP_PLUGIN_AVAIL = True
-else:
+        _QDP_PLUGIN_AVAIL = False
+except Exception:
     _QDP_PLUGIN_AVAIL = False
 
 ENABLED_FEATURES = FeatureSet(
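The try/except matters because find_spec on a dotted name imports the parent package first, so a broken tensorrt.plugin on Jetpack or TRT-RTX wheels raises at detection time instead of returning None. A standalone sketch of the same detection logic, with the function wrapper added purely for illustration:

import importlib.util
import tensorrt

def _qdp_plugin_available() -> bool:
    # Hypothetical wrapper around the module-level logic in _features.py.
    try:
        # find_spec("tensorrt.plugin._lib") imports tensorrt.plugin first,
        # so any import-time failure surfaces here as an exception.
        if not (
            importlib.util.find_spec("tensorrt.plugin")
            and importlib.util.find_spec("tensorrt.plugin._lib")
        ):
            return False
        # tensorrt 10.14.* / 10.15.* ship a known plugin bug; keep it disabled there.
        return not tensorrt.__version__.startswith(("10.14.", "10.15."))
    except Exception:
        return False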

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 0 deletions
@@ -1095,6 +1095,9 @@ def aten_ops_select(
 def _scatter_add_plugin_available() -> bool:
     import tensorrt as trt
 
+    # ScatterAdd plugin is not built for Jetpack (TRT 10.3.x) or TRT-RTX builds
+    if trt.__version__.startswith("10.3.") or trt._package_name == "tensorrt_rtx":
+        return False
     return (
         trt.get_plugin_registry().get_creator("ScatterAdd", "1", "torch_tensorrt")
         is not None
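The short-circuit avoids probing the registry on builds that never ship the plugin. A minimal standalone sketch of the same check, where the function name and the defensive getattr are illustrative additions:

import tensorrt as trt

def scatter_add_plugin_available() -> bool:
    # Jetpack pins TRT 10.3.x and TRT-RTX wheels omit the plugin library entirely.
    if trt.__version__.startswith("10.3.") or getattr(trt, "_package_name", "") == "tensorrt_rtx":
        return False
    # TRT 10 registry lookup by (plugin name, version, namespace).
    creator = trt.get_plugin_registry().get_creator("ScatterAdd", "1", "torch_tensorrt")
    return creator is not None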

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 18 additions & 0 deletions
@@ -1,7 +1,9 @@
+import pytest
 import torch
 import torch_tensorrt as torchtrt
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import ENABLED_FEATURES
 
 from .harness import DispatchTestCase
 
@@ -409,6 +411,9 @@ class TestIndexPutConverter(DispatchTestCase):
     def test_index_put(
         self, test_name, source_tensor, indices_tensor, value_tensor, accumulate=False
     ):
+        if accumulate and ENABLED_FEATURES.tensorrt_rtx:
+            pytest.skip("ScatterAdd plugin not available in TRT RTX")
+
         @torch._dynamo.assume_constant_result
         def get_indices_tensor():
             return indices_tensor
@@ -425,6 +430,7 @@ def forward(self, source_tensor, value_tensor):
             inputs=[source_tensor, value_tensor],
             enable_passes=True,
             use_dynamo_tracer=True,
+            use_explicit_typing=True,
         )
 
     def test_index_add_dynamic_shape(self):
@@ -868,6 +874,9 @@ def forward(self, src, values, idx0, idx1):
     # duplicate positions when index_put is embedded in a larger graph.
     # ------------------------------------------------------------------
 
+    @pytest.mark.skipif(
+        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+    )
     def test_kv_cache_duplicate_slot_writes(self):
         """KV-cache style: linear projection → index_put(accumulate=True) into
         a flat cache with duplicate slot indices → output projection.
@@ -903,8 +912,12 @@ def forward(self, tokens, cache):
             inputs=[tokens, cache],
             use_dynamo_tracer=True,
             enable_passes=True,
+            use_explicit_typing=True,
         )
 
+    @pytest.mark.skipif(
+        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+    )
     def test_sparse_embedding_duplicate_seq_ids(self):
         """Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
         into per-sequence accumulators where many tokens map to the same sequence → ReLU.
@@ -943,8 +956,12 @@ def forward(self, token_ids, accum):
             inputs=[token_ids, accum],
             use_dynamo_tracer=True,
             enable_passes=True,
+            use_explicit_typing=True,
        )
 
+    @pytest.mark.skipif(
+        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+    )
     def test_histogram_conv_duplicate_bin_ids(self):
         """Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
         bins where many frames land in the same bin → mean pool → linear.
@@ -983,6 +1000,7 @@ def forward(self, signal, hist):
             inputs=[signal, hist],
             use_dynamo_tracer=True,
             enable_passes=True,
+            use_explicit_typing=True,
         )
 
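The two skip styles are deliberate: a decorator-level skipif for tests whose every variant needs the ScatterAdd plugin, and a runtime pytest.skip inside the parameterized converter test so the accumulate=False variants still run on TRT-RTX. A condensed sketch of the pattern, with test names and bodies invented for illustration:

import pytest
from torch_tensorrt import ENABLED_FEATURES

@pytest.mark.skipif(
    ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
)
def test_always_accumulates():
    ...  # every variant of this test depends on the ScatterAdd plugin

@pytest.mark.parametrize("accumulate", [False, True])
def test_index_put_variants(accumulate):
    if accumulate and ENABLED_FEATURES.tensorrt_rtx:
        pytest.skip("ScatterAdd plugin not available in TRT RTX")
    ...  # accumulate=False still runs under TRT-RTX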
