Skip to content

Commit c0f4fa9

Browse files
authored
fix: unwaive skipped/special TRT-RTX tests (#4156)
Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
1 parent 44ad55e commit c0f4fa9

5 files changed

Lines changed: 4 additions & 32 deletions

File tree

MODULE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ http_archive(
142142
build_file = "@//third_party/tensorrt_rtx/archive:BUILD",
143143
strip_prefix = "TensorRT-RTX-1.4.0.76",
144144
urls = [
145-
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/tensorrt-rtx-1.4.0.76-linux-x86_64-cuda-13.2-release-external.tar.gz",
145+
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/TensorRT-RTX-1.4.0.76-Linux-x86_64-cuda-13.2-Release-external.tar.gz",
146146
],
147147
)
148148

@@ -178,6 +178,6 @@ http_archive(
178178
build_file = "@//third_party/tensorrt_rtx/archive:BUILD",
179179
strip_prefix = "TensorRT-RTX-1.4.0.76",
180180
urls = [
181-
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/tensorrt-rtx-1.4.0.76-win10-amd64-cuda-13.2-release-external.zip",
181+
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/TensorRT-RTX-1.4.0.76-Windows-amd64-cuda-13.2-Release-external.zip",
182182
],
183183
)

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from tensorrt import ITensor as TRTTensor
77
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
88
from torch.fx.node import Target
9-
from torch_tensorrt import ENABLED_FEATURES
109
from torch_tensorrt.dynamo._SourceIR import SourceIR
1110
from torch_tensorrt.dynamo.conversion import impl
1211
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -35,19 +34,6 @@ def get_ir(target: Target) -> SourceIR:
3534
return SourceIR.UNKNOWN
3635

3736

38-
def validate_int8_activation_quantization(name: str, dtype: trt.DataType) -> None:
39-
if (
40-
dtype == trt.DataType.INT8
41-
and ".input_quantizer" in name
42-
and ENABLED_FEATURES.tensorrt_rtx
43-
):
44-
# RTX does not support int8 activation quantization
45-
# TODO: lan to remove this once rtx team has added the support for int8 activation quantization
46-
raise NotImplementedError(
47-
"TensorRT-RTX does not support int8 activation quantization, only support int8 weight quantization"
48-
)
49-
50-
5137
def quantize(
5238
ctx: ConversionContext,
5339
target: Target,
@@ -91,8 +77,6 @@ def quantize(
9177
dtype = trt.DataType.FP8
9278
max_bound = 448
9379

94-
validate_int8_activation_quantization(name, dtype)
95-
9680
axis = None
9781
# int8 weight quantization is per-channel quantization(it can have one or multiple amax values)
9882
if dtype == trt.DataType.INT8 and amax.numel() > 1:

tests/py/dynamo/models/test_models_export.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,6 @@ def test_base_int8(ir, dtype):
413413
import modelopt.torch.quantization as mtq
414414
from modelopt.torch.quantization.utils import export_torch_mode
415415

416-
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
417-
pytest.skip("TensorRT-RTX does not support bfloat16")
418-
419416
class SimpleNetwork(torch.nn.Module):
420417
def __init__(self):
421418
super(SimpleNetwork, self).__init__()
@@ -435,9 +432,6 @@ def calibrate_loop(model):
435432
input_tensor = torch.randn(1, 10).cuda().to(dtype)
436433
model = SimpleNetwork().eval().cuda().to(dtype)
437434
quant_cfg = mtq.INT8_DEFAULT_CFG
438-
# RTX does not support INT8 default quantization(weights+activations), only support INT8 weights only quantization
439-
if torchtrt.ENABLED_FEATURES.tensorrt_rtx:
440-
quant_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}
441435
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
442436
# model has INT8 qdq nodes at this point
443437
output_pyt = model(input_tensor)
@@ -474,9 +468,6 @@ def test_base_int8_dynamic_shape(ir, dtype):
474468
import modelopt.torch.quantization as mtq
475469
from modelopt.torch.quantization.utils import export_torch_mode
476470

477-
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
478-
pytest.skip("TensorRT-RTX does not support bfloat16")
479-
480471
class SimpleNetwork(torch.nn.Module):
481472
def __init__(self):
482473
super(SimpleNetwork, self).__init__()

tests/py/dynamo/runtime/test_004_weight_streaming.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,6 @@ def test_weight_streaming_cudagraphs(self, _, use_python_runtime):
292292
("cpp_runtime", False),
293293
]
294294
)
295-
@unittest.skipIf(
296-
torchtrt.ENABLED_FEATURES.tensorrt_rtx, "TensorRT-RTX has bug on cudagraphs"
297-
)
298295
@unittest.skipIf(
299296
is_orin(), "There is a bug on Orin platform, skip for now until bug is fixed"
300297
)

toolchains/ci_workspaces/MODULE.bazel.tmpl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ http_archive(
8686
build_file = "@//third_party/tensorrt_rtx/archive:BUILD",
8787
strip_prefix = "TensorRT-RTX-1.4.0.76",
8888
urls = [
89-
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/tensorrt-rtx-1.4.0.76-linux-x86_64-cuda-${CU_UPPERBOUND}-release-external.tar.gz",
89+
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/TensorRT-RTX-1.4.0.76-Linux-x86_64-cuda-${CU_UPPERBOUND}-Release-external.tar.gz",
9090
],
9191
)
9292
@@ -122,7 +122,7 @@ http_archive(
122122
build_file = "@//third_party/tensorrt_rtx/archive:BUILD",
123123
strip_prefix = "TensorRT-RTX-1.4.0.76",
124124
urls = [
125-
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/tensorrt-rtx-1.4.0.76-win10-amd64-cuda-${CU_UPPERBOUND}-release-external.zip",
125+
"https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.4/TensorRT-RTX-1.4.0.76-Windows-amd64-cuda-${CU_UPPERBOUND}-Release-external.zip",
126126
],
127127
)
128128

0 commit comments

Comments
 (0)