diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 6bf1b58f71..5a81f74ec7 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -114,7 +114,7 @@ def test_compile_weight_stripped_engine(self): ) def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + example_inputs = (torch.randn((2, 3, 224, 224)).to("cuda"),) exp_program = torch.export.export(pyt_model, example_inputs) weight_included_engine = convert_exported_program_to_serialized_trt_engine( exp_program, @@ -159,14 +159,14 @@ def test_weight_stripped_engine_sizes(self): ) def test_weight_stripped_engine_results(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + example_inputs = (torch.randn((2, 3, 224, 224)).to("cuda"),) # Mark the dim0 of inputs as dynamic batch = torch.export.Dim("batch", min=1, max=200) exp_program = torch.export.export( pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} ) - inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + inputs = [torch.rand((2, 3, 224, 224)).to("cuda")] trt_gm = torch_trt.dynamo.compile( exp_program, @@ -551,12 +551,12 @@ def forward(self, x): ) def test_two_TRTRuntime_in_refitting(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + example_inputs = (torch.randn((2, 3, 224, 224)).to("cuda"),) batch = torch.export.Dim("batch", min=1, max=200) exp_program = torch.export.export( pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} ) - inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + inputs = [torch.rand((2, 3, 224, 224)).to("cuda")] pyt_results = pyt_model(*inputs)