Skip to content

Commit d3be8be

Browse files
committed
Add dynamic shape tests for xnnpack model tests
Add dynamic shape support to xnnpack model tests that previously only tested with static inputs. This covers DeepLab V3, EDSR, Inception V3/V4, Emformer RNNT, and MobileBERT using the same DynamicWrapper + interpolate pattern established by ResNet and ViT. Addresses #11585 Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent afc9989 commit d3be8be

6 files changed

Lines changed: 189 additions & 0 deletions

File tree

backends/xnnpack/test/models/deeplab_v3.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,37 @@ def forward(self, *args):
2222
return self.m(*args)["out"]
2323

2424

25+
class DynamicDL3Wrapper(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.m = deeplabv3_resnet50(
29+
weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
30+
)
31+
32+
def forward(self, x):
33+
x = torch.nn.functional.interpolate(
34+
x,
35+
size=(224, 224),
36+
mode="bilinear",
37+
align_corners=True,
38+
antialias=False,
39+
)
40+
return self.m(x)["out"]
41+
42+
2543
class TestDeepLabV3(unittest.TestCase):
2644
def setUp(self):
2745
torch._dynamo.reset()
2846

2947
dl3 = DL3Wrapper()
3048
dl3 = dl3.eval()
3149
model_inputs = (torch.randn(1, 3, 224, 224),)
50+
dynamic_shapes = (
51+
{
52+
2: torch.export.Dim("height", min=224, max=455),
53+
3: torch.export.Dim("width", min=224, max=455),
54+
},
55+
)
3256

3357
def test_fp32_dl3(self):
3458

@@ -40,3 +64,13 @@ def test_fp32_dl3(self):
4064
.serialize()
4165
.run_method_and_compare_outputs()
4266
)
67+
68+
def test_fp32_dl3_dynamic(self):
69+
(
70+
Tester(DynamicDL3Wrapper(), self.model_inputs, self.dynamic_shapes)
71+
.export()
72+
.to_edge_transform_and_lower()
73+
.to_executorch()
74+
.serialize()
75+
.run_method_and_compare_outputs()
76+
)

backends/xnnpack/test/models/edsr.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,34 @@
1313
from torchsr.models import edsr_r16f64
1414

1515

16+
class DynamicEDSR(torch.nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
self.model = edsr_r16f64(2, False).eval()
20+
21+
def forward(self, x):
22+
x = torch.nn.functional.interpolate(
23+
x,
24+
size=(224, 224),
25+
mode="bilinear",
26+
align_corners=True,
27+
antialias=False,
28+
)
29+
return self.model(x)
30+
31+
1632
class TestEDSR(unittest.TestCase):
1733
def setUp(self):
1834
torch._dynamo.reset()
1935

2036
edsr = edsr_r16f64(2, False).eval() # noqa
2137
model_inputs = (torch.randn(1, 3, 224, 224),)
38+
dynamic_shapes = (
39+
{
40+
2: torch.export.Dim("height", min=224, max=455),
41+
3: torch.export.Dim("width", min=224, max=455),
42+
},
43+
)
2244

2345
def test_fp32_edsr(self):
2446
(
@@ -53,3 +75,13 @@ def test_qs8_edsr_no_calibrate(self):
5375
.serialize()
5476
.run_method_and_compare_outputs()
5577
)
78+
79+
def test_fp32_edsr_dynamic(self):
80+
(
81+
Tester(DynamicEDSR(), self.model_inputs, self.dynamic_shapes)
82+
.export()
83+
.to_edge_transform_and_lower()
84+
.to_executorch()
85+
.serialize()
86+
.run_method_and_compare_outputs()
87+
)

backends/xnnpack/test/models/emformer_rnnt.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ def test_fp32_emformer_joiner(self):
4848
.run_method_and_compare_outputs()
4949
)
5050

51+
def test_fp32_emformer_joiner_dynamic(self):
52+
joiner = self.Joiner()
53+
dynamic_shapes = (
54+
{0: torch.export.Dim("batch", min=1, max=4)},
55+
None,
56+
{0: torch.export.Dim("batch", min=1, max=4)},
57+
None,
58+
)
59+
(
60+
Tester(joiner, joiner.get_example_inputs(), dynamic_shapes=dynamic_shapes)
61+
.export()
62+
.to_edge_transform_and_lower()
63+
.check(["torch.ops.higher_order.executorch_call_delegate"])
64+
.to_executorch()
65+
.serialize()
66+
.run_method_and_compare_outputs()
67+
)
68+
5169
class Predictor(EmformerRnnt):
5270
def forward(self, a, b):
5371
return self.rnnt.predict(a, b, None)
@@ -96,3 +114,23 @@ def test_fp32_emformer_transcriber(self):
96114
.serialize()
97115
.run_method_and_compare_outputs()
98116
)
117+
118+
def test_fp32_emformer_transcriber_dynamic(self):
119+
transcriber = self.Transcriber()
120+
dynamic_shapes = (
121+
{0: torch.export.Dim("batch", min=1, max=4)},
122+
None,
123+
)
124+
(
125+
Tester(
126+
transcriber,
127+
transcriber.get_example_inputs(),
128+
dynamic_shapes=dynamic_shapes,
129+
)
130+
.export()
131+
.to_edge_transform_and_lower()
132+
.check(["torch.ops.higher_order.executorch_call_delegate"])
133+
.to_executorch()
134+
.serialize()
135+
.run_method_and_compare_outputs()
136+
)

backends/xnnpack/test/models/inception_v3.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,34 @@
1212
from torchvision import models
1313

1414

15+
class DynamicInceptionV3(torch.nn.Module):
16+
def __init__(self):
17+
super().__init__()
18+
self.model = models.inception_v3(weights="IMAGENET1K_V1").eval()
19+
20+
def forward(self, x):
21+
x = torch.nn.functional.interpolate(
22+
x,
23+
size=(224, 224),
24+
mode="bilinear",
25+
align_corners=True,
26+
antialias=False,
27+
)
28+
return self.model(x)
29+
30+
1531
class TestInceptionV3(unittest.TestCase):
1632
def setUp(self):
1733
torch._dynamo.reset()
1834

1935
ic3 = models.inception_v3(weights="IMAGENET1K_V1").eval() # noqa
2036
model_inputs = (torch.randn(1, 3, 224, 224),)
37+
dynamic_shapes = (
38+
{
39+
2: torch.export.Dim("height", min=224, max=455),
40+
3: torch.export.Dim("width", min=224, max=455),
41+
},
42+
)
2143

2244
all_operators = {
2345
"executorch_exir_dialects_edge__ops_aten_addmm_default",
@@ -82,3 +104,14 @@ def test_qs8_ic3_no_calibration(self):
82104
.serialize()
83105
.run_method_and_compare_outputs()
84106
)
107+
108+
def test_fp32_ic3_dynamic(self):
109+
(
110+
Tester(DynamicInceptionV3(), self.model_inputs, self.dynamic_shapes)
111+
.export()
112+
.to_edge_transform_and_lower()
113+
.check(["torch.ops.higher_order.executorch_call_delegate"])
114+
.to_executorch()
115+
.serialize()
116+
.run_method_and_compare_outputs()
117+
)

backends/xnnpack/test/models/inception_v4.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,34 @@
1111
from timm.models import inception_v4
1212

1313

14+
class DynamicInceptionV4(torch.nn.Module):
15+
def __init__(self):
16+
super().__init__()
17+
self.model = inception_v4(pretrained=False).eval()
18+
19+
def forward(self, x):
20+
x = torch.nn.functional.interpolate(
21+
x,
22+
size=(299, 299),
23+
mode="bilinear",
24+
align_corners=True,
25+
antialias=False,
26+
)
27+
return self.model(x)
28+
29+
1430
class TestInceptionV4(unittest.TestCase):
1531
def setUp(self):
1632
torch._dynamo.reset()
1733

1834
ic4 = inception_v4(pretrained=False).eval()
1935
model_inputs = (torch.randn(3, 299, 299).unsqueeze(0),)
36+
dynamic_shapes = (
37+
{
38+
2: torch.export.Dim("height", min=299, max=455),
39+
3: torch.export.Dim("width", min=299, max=455),
40+
},
41+
)
2042

2143
all_operators = {
2244
"executorch_exir_dialects_edge__ops_aten_addmm_default",
@@ -60,3 +82,14 @@ def test_qs8_ic4(self):
6082
.serialize()
6183
.run_method_and_compare_outputs()
6284
)
85+
86+
def test_fp32_ic4_dynamic(self):
87+
(
88+
Tester(DynamicInceptionV4(), self.model_inputs, self.dynamic_shapes)
89+
.export()
90+
.to_edge_transform_and_lower()
91+
.check(["torch.ops.higher_order.executorch_call_delegate"])
92+
.to_executorch()
93+
.serialize()
94+
.run_method_and_compare_outputs()
95+
)

backends/xnnpack/test/models/mobilebert.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def setUp(self):
3131
"executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
3232
}
3333

34+
dynamic_shapes = (
35+
{1: torch.export.Dim("seq_length", min=2, max=32)},
36+
)
37+
3438
def test_fp32_mobilebert(self):
3539
(
3640
Tester(self.mobilebert, self.example_inputs)
@@ -53,3 +57,18 @@ def test_qs8_mobilebert(self):
5357
.serialize()
5458
.run_method_and_compare_outputs(inputs=self.example_inputs)
5559
)
60+
61+
def test_fp32_mobilebert_dynamic(self):
62+
(
63+
Tester(
64+
self.mobilebert,
65+
self.example_inputs,
66+
dynamic_shapes=self.dynamic_shapes,
67+
)
68+
.export()
69+
.to_edge_transform_and_lower()
70+
.check(["torch.ops.higher_order.executorch_call_delegate"])
71+
.to_executorch()
72+
.serialize()
73+
.run_method_and_compare_outputs(inputs=self.example_inputs)
74+
)

0 commit comments

Comments
 (0)