Skip to content

Commit 9a10876

Browse files
authored
Add dynamic shape tests for xnnpack model tests (#18701)
## Summary Add dynamic shape tests for xnnpack model tests that previously only tested with static inputs, as requested in #11585. **Models covered:** - **DeepLab V3** — dynamic height/width with DynamicWrapper + interpolate - **EDSR** — dynamic height/width with DynamicWrapper + interpolate - **Inception V3/V4** — dynamic height/width with DynamicWrapper + interpolate - **Emformer RNNT** — dynamic batch dimension for joiner and transcriber - **MobileBERT** — dynamic sequence length All dynamic tests follow the established pattern from `resnet.py` (`DynamicResNet`) and `torchvision_vit.py` (`DynamicViT`): a wrapper module that resizes dynamic spatial inputs to fixed dimensions via `F.interpolate` before feeding into the model. **Models already covered by existing dynamic tests (no changes):** - MobileNet V2, MobileNet V3, ResNet, ViT **Models skipped:** - `llama2_et_example.py` — LLM, requires separate dynamic shape strategy - `very_big_model.py` — synthetic test model <details> <summary>Before (only static tests)</summary> ``` $ python -m pytest backends/xnnpack/test/models/deeplab_v3.py -v backends/xnnpack/test/models/deeplab_v3.py::TestDeepLabV3::test_fp32_dl3 PASSED (no dynamic shape test exists) ``` </details> <details> <summary>After (static + dynamic tests pass)</summary> ``` $ python -m pytest backends/xnnpack/test/models/deeplab_v3.py::TestDeepLabV3::test_fp32_dl3_dynamic -v backends/xnnpack/test/models/deeplab_v3.py::TestDeepLabV3::test_fp32_dl3_dynamic PASSED [100%] ======================== 1 passed in 28.84s ============================== $ python -m pytest backends/xnnpack/test/models/mobilebert.py::TestMobilebert::test_fp32_mobilebert_dynamic -v backends/xnnpack/test/models/mobilebert.py::TestMobilebert::test_fp32_mobilebert_dynamic PASSED [100%] ======================== 1 passed in 160.79s (0:02:40) =================== ``` </details> ## Test plan - [x] `test_fp32_dl3_dynamic` passes locally - [x] `test_fp32_mobilebert_dynamic` passes locally - [ ] CI xnnpack test suite passes cc @GregoryComer @digantdesai @cbilgin --------- Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent 84aa213 commit 9a10876

File tree

6 files changed

+187
-0
lines changed

6 files changed

+187
-0
lines changed

backends/xnnpack/test/models/deeplab_v3.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,37 @@ def forward(self, *args):
2222
return self.m(*args)["out"]
2323

2424

25+
class DynamicDL3Wrapper(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.m = deeplabv3_resnet50(
29+
weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
30+
)
31+
32+
def forward(self, x):
33+
x = torch.nn.functional.interpolate(
34+
x,
35+
size=(224, 224),
36+
mode="bilinear",
37+
align_corners=True,
38+
antialias=False,
39+
)
40+
return self.m(x)["out"]
41+
42+
2543
class TestDeepLabV3(unittest.TestCase):
2644
def setUp(self):
2745
torch._dynamo.reset()
2846

2947
dl3 = DL3Wrapper()
3048
dl3 = dl3.eval()
3149
model_inputs = (torch.randn(1, 3, 224, 224),)
50+
dynamic_shapes = (
51+
{
52+
2: torch.export.Dim("height", min=224, max=455),
53+
3: torch.export.Dim("width", min=224, max=455),
54+
},
55+
)
3256

3357
def test_fp32_dl3(self):
3458

@@ -40,3 +64,13 @@ def test_fp32_dl3(self):
4064
.serialize()
4165
.run_method_and_compare_outputs()
4266
)
67+
68+
def test_fp32_dl3_dynamic(self):
69+
(
70+
Tester(DynamicDL3Wrapper(), self.model_inputs, self.dynamic_shapes)
71+
.export()
72+
.to_edge_transform_and_lower()
73+
.to_executorch()
74+
.serialize()
75+
.run_method_and_compare_outputs()
76+
)

backends/xnnpack/test/models/edsr.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,34 @@
1313
from torchsr.models import edsr_r16f64
1414

1515

16+
class DynamicEDSR(torch.nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
self.model = edsr_r16f64(2, False).eval()
20+
21+
def forward(self, x):
22+
x = torch.nn.functional.interpolate(
23+
x,
24+
size=(224, 224),
25+
mode="bilinear",
26+
align_corners=True,
27+
antialias=False,
28+
)
29+
return self.model(x)
30+
31+
1632
class TestEDSR(unittest.TestCase):
1733
def setUp(self):
1834
torch._dynamo.reset()
1935

2036
edsr = edsr_r16f64(2, False).eval() # noqa
2137
model_inputs = (torch.randn(1, 3, 224, 224),)
38+
dynamic_shapes = (
39+
{
40+
2: torch.export.Dim("height", min=224, max=455),
41+
3: torch.export.Dim("width", min=224, max=455),
42+
},
43+
)
2244

2345
def test_fp32_edsr(self):
2446
(
@@ -53,3 +75,13 @@ def test_qs8_edsr_no_calibrate(self):
5375
.serialize()
5476
.run_method_and_compare_outputs()
5577
)
78+
79+
def test_fp32_edsr_dynamic(self):
80+
(
81+
Tester(DynamicEDSR(), self.model_inputs, self.dynamic_shapes)
82+
.export()
83+
.to_edge_transform_and_lower()
84+
.to_executorch()
85+
.serialize()
86+
.run_method_and_compare_outputs()
87+
)

backends/xnnpack/test/models/emformer_rnnt.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ def test_fp32_emformer_joiner(self):
4848
.run_method_and_compare_outputs()
4949
)
5050

51+
def test_fp32_emformer_joiner_dynamic(self):
52+
joiner = self.Joiner()
53+
dynamic_shapes = (
54+
{0: torch.export.Dim("batch", min=1, max=4)},
55+
None,
56+
{0: torch.export.Dim("batch", min=1, max=4)},
57+
None,
58+
)
59+
(
60+
Tester(joiner, joiner.get_example_inputs(), dynamic_shapes=dynamic_shapes)
61+
.export()
62+
.to_edge_transform_and_lower()
63+
.check(["torch.ops.higher_order.executorch_call_delegate"])
64+
.to_executorch()
65+
.serialize()
66+
.run_method_and_compare_outputs()
67+
)
68+
5169
class Predictor(EmformerRnnt):
5270
def forward(self, a, b):
5371
return self.rnnt.predict(a, b, None)
@@ -96,3 +114,23 @@ def test_fp32_emformer_transcriber(self):
96114
.serialize()
97115
.run_method_and_compare_outputs()
98116
)
117+
118+
def test_fp32_emformer_transcriber_dynamic(self):
119+
transcriber = self.Transcriber()
120+
dynamic_shapes = (
121+
{0: torch.export.Dim("batch", min=1, max=4)},
122+
None,
123+
)
124+
(
125+
Tester(
126+
transcriber,
127+
transcriber.get_example_inputs(),
128+
dynamic_shapes=dynamic_shapes,
129+
)
130+
.export()
131+
.to_edge_transform_and_lower()
132+
.check(["torch.ops.higher_order.executorch_call_delegate"])
133+
.to_executorch()
134+
.serialize()
135+
.run_method_and_compare_outputs()
136+
)

backends/xnnpack/test/models/inception_v3.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,34 @@
1212
from torchvision import models
1313

1414

15+
class DynamicInceptionV3(torch.nn.Module):
16+
def __init__(self):
17+
super().__init__()
18+
self.model = models.inception_v3(weights="IMAGENET1K_V1").eval()
19+
20+
def forward(self, x):
21+
x = torch.nn.functional.interpolate(
22+
x,
23+
size=(224, 224),
24+
mode="bilinear",
25+
align_corners=True,
26+
antialias=False,
27+
)
28+
return self.model(x)
29+
30+
1531
class TestInceptionV3(unittest.TestCase):
1632
def setUp(self):
1733
torch._dynamo.reset()
1834

1935
ic3 = models.inception_v3(weights="IMAGENET1K_V1").eval() # noqa
2036
model_inputs = (torch.randn(1, 3, 224, 224),)
37+
dynamic_shapes = (
38+
{
39+
2: torch.export.Dim("height", min=224, max=455),
40+
3: torch.export.Dim("width", min=224, max=455),
41+
},
42+
)
2143

2244
all_operators = {
2345
"executorch_exir_dialects_edge__ops_aten_addmm_default",
@@ -82,3 +104,14 @@ def test_qs8_ic3_no_calibration(self):
82104
.serialize()
83105
.run_method_and_compare_outputs()
84106
)
107+
108+
def test_fp32_ic3_dynamic(self):
109+
(
110+
Tester(DynamicInceptionV3(), self.model_inputs, self.dynamic_shapes)
111+
.export()
112+
.to_edge_transform_and_lower()
113+
.check(["torch.ops.higher_order.executorch_call_delegate"])
114+
.to_executorch()
115+
.serialize()
116+
.run_method_and_compare_outputs()
117+
)

backends/xnnpack/test/models/inception_v4.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,34 @@
1111
from timm.models import inception_v4
1212

1313

14+
class DynamicInceptionV4(torch.nn.Module):
15+
def __init__(self):
16+
super().__init__()
17+
self.model = inception_v4(pretrained=False).eval()
18+
19+
def forward(self, x):
20+
x = torch.nn.functional.interpolate(
21+
x,
22+
size=(299, 299),
23+
mode="bilinear",
24+
align_corners=True,
25+
antialias=False,
26+
)
27+
return self.model(x)
28+
29+
1430
class TestInceptionV4(unittest.TestCase):
1531
def setUp(self):
1632
torch._dynamo.reset()
1733

1834
ic4 = inception_v4(pretrained=False).eval()
1935
model_inputs = (torch.randn(3, 299, 299).unsqueeze(0),)
36+
dynamic_shapes = (
37+
{
38+
2: torch.export.Dim("height", min=299, max=455),
39+
3: torch.export.Dim("width", min=299, max=455),
40+
},
41+
)
2042

2143
all_operators = {
2244
"executorch_exir_dialects_edge__ops_aten_addmm_default",
@@ -60,3 +82,14 @@ def test_qs8_ic4(self):
6082
.serialize()
6183
.run_method_and_compare_outputs()
6284
)
85+
86+
def test_fp32_ic4_dynamic(self):
87+
(
88+
Tester(DynamicInceptionV4(), self.model_inputs, self.dynamic_shapes)
89+
.export()
90+
.to_edge_transform_and_lower()
91+
.check(["torch.ops.higher_order.executorch_call_delegate"])
92+
.to_executorch()
93+
.serialize()
94+
.run_method_and_compare_outputs()
95+
)

backends/xnnpack/test/models/mobilebert.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def setUp(self):
3131
"executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
3232
}
3333

34+
dynamic_shapes = ({1: torch.export.Dim("seq_length", min=2, max=32)},)
35+
3436
def test_fp32_mobilebert(self):
3537
(
3638
Tester(self.mobilebert, self.example_inputs)
@@ -53,3 +55,18 @@ def test_qs8_mobilebert(self):
5355
.serialize()
5456
.run_method_and_compare_outputs(inputs=self.example_inputs)
5557
)
58+
59+
def test_fp32_mobilebert_dynamic(self):
60+
(
61+
Tester(
62+
self.mobilebert,
63+
self.example_inputs,
64+
dynamic_shapes=self.dynamic_shapes,
65+
)
66+
.export()
67+
.to_edge_transform_and_lower()
68+
.check(["torch.ops.higher_order.executorch_call_delegate"])
69+
.to_executorch()
70+
.serialize()
71+
.run_method_and_compare_outputs(inputs=self.example_inputs)
72+
)

0 commit comments

Comments
 (0)