Skip to content

Commit 24f99f4

Browse files
committed
fix: some more index fixes, enable dynamo tracer in some tests
1 parent b0837f7 commit 24f99f4

3 files changed

Lines changed: 9 additions & 12 deletions

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import numpy as np
88
import torch
9+
from tensorrt import ITensor as TRTTensor
910
from torch.fx.node import Argument, Node, Target
11+
1012
from torch_tensorrt import ENABLED_FEATURES
1113
from torch_tensorrt._features import needs_not_tensorrt_rtx
1214
from torch_tensorrt._utils import is_tensorrt_version_supported
@@ -27,8 +29,6 @@
2729
)
2830
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
2931

30-
from tensorrt import ITensor as TRTTensor
31-
3232
_LOGGER: logging.Logger = logging.getLogger(__name__)
3333

3434

@@ -2904,7 +2904,6 @@ def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) -
29042904

29052905

29062906
def topk_sort_validator(k: int) -> bool:
2907-
29082907
# topk layer supports dynamic k value but we cannot determine supported dynamic topk value at
29092908
# compile time.
29102909
if k == DYNAMIC_DIM or not isinstance(k, int):

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from typing import List, Optional, Sequence, Union
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
7+
from tensorrt import ITensor
68
from torch.fx.node import Target
9+
710
from torch_tensorrt._enums import dtype
811
from torch_tensorrt.dynamo._SourceIR import SourceIR
912
from torch_tensorrt.dynamo.conversion import impl
@@ -21,9 +24,6 @@
2124
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
2225
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
2326

24-
import tensorrt as trt
25-
from tensorrt import ITensor
26-
2727
_LOGGER: logging.Logger = logging.getLogger(__name__)
2828

2929

@@ -875,9 +875,7 @@ def index_put_converter(
875875
K = len(I)
876876
# Determine the maximum size 'N' among the index tensors
877877
if K > 0:
878-
index_shapes = (
879-
[]
880-
) # [tensor.shape[0] for tensor in indices if tensor is not None]
878+
index_shapes = [] # [tensor.shape[0] for tensor in indices if tensor is not None]
881879
for _ni, idx_tensor in enumerate(indices):
882880
if idx_tensor is not None:
883881
if idx_tensor.shape[0] != DYNAMIC_DIM:

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def forward(self, x, index0):
160160

161161
input = torch.randn(2, 2)
162162
index0 = torch.tensor([True, False])
163-
self.run_test(TestModule(), [input, index0], enable_passes=True)
163+
self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
164164

165165
def test_index_zero_index_three_dim_ITensor(self):
166166
class TestModule(nn.Module):
@@ -172,7 +172,7 @@ def forward(self, x, index0):
172172
input = torch.randn(2, 2, 2)
173173
index0 = torch.randint(0, 1, (1, 1))
174174
index0 = index0.to(torch.int32)
175-
self.run_test(TestModule(), [input, index0])
175+
self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
176176

177177
@unittest.skipIf(
178178
ENABLED_FEATURES.tensorrt_rtx,
@@ -187,7 +187,7 @@ def forward(self, x, index0):
187187

188188
input = torch.randn(2, 2, 2)
189189
index0 = torch.tensor([True, False])
190-
self.run_test(TestModule(), [input, index0])
190+
self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
191191

192192

193193
class TestIndexDynamicConstantConverter(DispatchTestCase):

0 commit comments

Comments
 (0)