@@ -413,9 +413,6 @@ def test_base_int8(ir, dtype):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.quantization.utils import export_torch_mode
 
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
@@ -435,9 +432,6 @@ def calibrate_loop(model):
     input_tensor = torch.randn(1, 10).cuda().to(dtype)
     model = SimpleNetwork().eval().cuda().to(dtype)
     quant_cfg = mtq.INT8_DEFAULT_CFG
-    # RTX does not support the default INT8 quantization (weights + activations); it only supports INT8 weight-only quantization
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx:
-        quant_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has INT8 qdq nodes at this point
     output_pyt = model(input_tensor)
@@ -474,9 +468,6 @@ def test_base_int8_dynamic_shape(ir, dtype):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.quantization.utils import export_torch_mode
 
-    if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
-        pytest.skip("TensorRT-RTX does not support bfloat16")
-
     class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
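For context, the lines removed above implemented an RTX-specific workaround: TensorRT-RTX rejects the default INT8 configuration (which quantizes both weights and activations), so the tests fell back to weight-only INT8 by disabling the input quantizers. Below is a minimal, self-contained sketch of that pattern outside the test suite, assuming nvidia-modelopt and a CUDA device are available; TinyNetwork and calibrate are illustrative stand-ins, while the mtq calls mirror the removed lines in the diff.

# Sketch only: reproduces the removed weight-only INT8 fallback, not the test itself.
import torch
import modelopt.torch.quantization as mtq


class TinyNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))


input_tensor = torch.randn(1, 10).cuda()
model = TinyNetwork().eval().cuda()


def calibrate(m):
    # A single forward pass is enough to collect activation ranges for this toy model.
    m(input_tensor)


quant_cfg = mtq.INT8_DEFAULT_CFG
# Disable activation (input) quantizers so only weight quantizers insert Q/DQ nodes,
# matching the RTX fallback the commit removes. (A deepcopy of the default config may
# be preferable in real code to avoid mutating the shared dict.)
quant_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}
mtq.quantize(model, quant_cfg, forward_loop=calibrate)
output = model(input_tensor)  # model now carries INT8 Q/DQ nodes on weights only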