Skip to content

Commit 503ca51

Browse files
committed
fix: some more index fixes, enable dynamo tracer in some tests
1 parent b0837f7 commit 503ca51

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2904,7 +2904,6 @@ def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) -
29042904

29052905

29062906
def topk_sort_validator(k: int) -> bool:
2907-
29082907
# topk layer supports dynamic k value but we cannot determine supported dynamic topk value at
29092908
# compile time.
29102909
if k == DYNAMIC_DIM or not isinstance(k, int):

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def forward(self, x, index0):
160160

161161
input = torch.randn(2, 2)
162162
index0 = torch.tensor([True, False])
163-
self.run_test(TestModule(), [input, index0], enable_passes=True)
163+
self.run_test(
164+
TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
165+
)
164166

165167
def test_index_zero_index_three_dim_ITensor(self):
166168
class TestModule(nn.Module):
@@ -172,7 +174,9 @@ def forward(self, x, index0):
172174
input = torch.randn(2, 2, 2)
173175
index0 = torch.randint(0, 1, (1, 1))
174176
index0 = index0.to(torch.int32)
175-
self.run_test(TestModule(), [input, index0])
177+
self.run_test(
178+
TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
179+
)
176180

177181
@unittest.skipIf(
178182
ENABLED_FEATURES.tensorrt_rtx,
@@ -187,7 +191,9 @@ def forward(self, x, index0):
187191

188192
input = torch.randn(2, 2, 2)
189193
index0 = torch.tensor([True, False])
190-
self.run_test(TestModule(), [input, index0])
194+
self.run_test(
195+
TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
196+
)
191197

192198

193199
class TestIndexDynamicConstantConverter(DispatchTestCase):

0 commit comments

Comments
 (0)