-
Notifications
You must be signed in to change notification settings - Fork 962
Expand file tree
/
Copy path test_while.py
More file actions
287 lines (233 loc) · 8.41 KB
/
test_while.py
File metadata and controls
287 lines (233 loc) · 8.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Tuple
import torch
import torch.fx
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU85PipelineINT,
OpNotSupportedPipeline,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from pytest import mark
# Example-input type for cases whose forward() takes a single tensor.
input_single = Tuple[torch.Tensor]
# Example-input type for cases whose forward() takes two tensors.
input_double = Tuple[torch.Tensor, torch.Tensor]
class WhileTwoInputsTwoOutputs(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self, lhs: torch.Tensor, rhs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def cond_fn(lhs_val: torch.Tensor, rhs_val: torch.Tensor) -> torch.Tensor:
total = torch.sum(rhs_val)
zero = torch.zeros_like(total)
return torch.gt(total, zero).squeeze()
def body_fn(
lhs_val: torch.Tensor, rhs_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
next_lhs = torch.add(lhs_val, rhs_val)
next_rhs = torch.sub(rhs_val, torch.full((1,), 1.0))
return (next_lhs, next_rhs)
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(lhs, rhs),
(),
)
return result # type: ignore
class WhileOneInputOneBufferTwoOutputs(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("threshold", torch.tensor((30.0,)))
def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.lt(total, limit).squeeze()
def body_fn(
value: torch.Tensor, limit: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return (torch.add(value, value), limit.clone())
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value, self.threshold),
(),
)
return result # type: ignore
class DecreasingOutput(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.gt(total, torch.full((1,), 60.0)).squeeze()
def body_fn(value: torch.Tensor) -> Tuple[torch.Tensor]:
return (torch.div(value, torch.full((1,), 2.0)),)
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value,),
(),
)
return result[0] # type: ignore
class WhileAdditionalArg(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("threshold", torch.tensor((128.0,)))
def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.lt(total, limit).squeeze()
def body_fn(value: torch.Tensor, limit: torch.Tensor) -> tuple[torch.Tensor]:
return (torch.add(value, value),)
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value,),
(self.threshold,),
)
return result # type: ignore
class WhileSingleCapturedOutput(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("threshold", torch.tensor((128.0,)))
def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.lt(total, limit).squeeze()
def body_fn(
value: torch.Tensor, limit: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return (torch.add(value, value), limit.clone())
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value, self.threshold),
(),
)
return result[0] # type: ignore
class WhileLargeThreshold(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("threshold", torch.tensor((400.0,)))
def forward(self, value: torch.Tensor) -> torch.Tensor:
def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
total = value.sum()
return torch.lt(total, limit).squeeze()
def body_fn(value: torch.Tensor, limit: torch.Tensor) -> tuple[torch.Tensor]:
return (torch.add(value, value),)
result = torch.ops.higher_order.while_loop(
cond_fn,
body_fn,
(value,),
(self.threshold,),
)
return result # type: ignore
def _single_input_case(
    module_factory: Callable[[], torch.nn.Module],
) -> Callable[[], Tuple[torch.nn.Module, input_single]]:
    """Wrap *module_factory* into a case builder with one example input.

    Each call of the returned builder instantiates a fresh module. It pairs
    that module with a new ``torch.ones(2, 3, 2, 2)`` tensor.
    """

    def build_case() -> Tuple[torch.nn.Module, input_single]:
        fresh_module = module_factory()
        example = (torch.ones(2, 3, 2, 2),)
        return fresh_module, example

    return build_case
def _dual_input_case(
    module_factory: Callable[[], torch.nn.Module],
) -> Callable[[], Tuple[torch.nn.Module, input_double]]:
    """Wrap *module_factory* into a case builder with two example inputs.

    The returned builder creates a new module on every call. It pairs the
    module with ``torch.zeros(2, 3)`` and a (2, 3) tensor filled with -2.0.
    """

    def build_case() -> Tuple[torch.nn.Module, input_double]:
        fresh_module = module_factory()
        first = torch.zeros(2, 3)
        second = torch.full((2, 3), -2.0)
        return fresh_module, (first, second)

    return build_case
# Case name -> zero-argument builder returning a fresh (module, example_inputs)
# pair. The dict is handed to ``common.parametrize`` by every test below; only
# "two_in_two_out" uses the two-input builder.
test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = dict(
    two_in_two_out=_dual_input_case(WhileTwoInputsTwoOutputs),
    one_in_one_buffer_two_out=_single_input_case(WhileOneInputOneBufferTwoOutputs),
    decreasing_output=_single_input_case(DecreasingOutput),
    additional_arg=_single_input_case(WhileAdditionalArg),
    two_in_one_captured_out=_single_input_case(WhileSingleCapturedOutput),
    large_threshold=_single_input_case(WhileLargeThreshold),
)
@common.parametrize("case", test_cases)
def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Lower every while_loop case through the TOSA FP pipeline.

    The ``cf`` TOSA extension is requested explicitly.
    """
    model, inputs = case()
    TosaPipelineFP[tuple](
        model,
        inputs,
        "torch.ops.higher_order.while_loop",
        tosa_extensions=["cf"],
    ).run()
@common.parametrize("case", test_cases)
def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Lower every while_loop case through the TOSA INT pipeline (``cf`` enabled).

    An extra ``ArmTester.check_not`` stage is inserted after
    ``to_edge_transform_and_lower``. It asserts the while_loop op is no longer
    found at that point, presumably because it was delegated.
    """
    model, inputs = case()
    op_name = "torch.ops.higher_order.while_loop"
    pipeline = TosaPipelineINT[tuple](
        model,
        inputs,
        op_name,
        tosa_extensions=["cf"],
    )
    pipeline.add_stage_after(
        "to_edge_transform_and_lower",
        ArmTester.check_not,
        pipeline.tester,
        [op_name],
    )
    pipeline.run()
@common.parametrize("case", test_cases)
def test_while_loop_u55_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Check that while_loop is not delegated on the Ethos-U55 subset.

    The expectation passed to the pipeline is one non-delegated while_loop op.
    """
    model, inputs = case()
    expected_fallback = {"torch.ops.higher_order.while_loop": 1}
    pipeline = OpNotSupportedPipeline[tuple](
        model,
        inputs,
        non_delegated_ops=expected_fallback,
        u55_subset=True,
    )
    pipeline.run()
@common.parametrize("case", test_cases)
@common.XfailIfNoCorstone320
def test_while_loop_u85_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Run every while_loop case through the Ethos-U85 INT pipeline.

    The test is marked xfail when the Corstone-320 target is unavailable.
    """
    model, inputs = case()
    pipeline = EthosU85PipelineINT[tuple](
        model,
        inputs,
        "torch.ops.higher_order.while_loop",
    )
    pipeline.run()
@mark.skip("While not supported in model_converter.")
@common.parametrize("case", test_cases)
@common.SkipIfNoModelConverter
def test_while_loop_vgf_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Run every while_loop case through the VGF pipeline at TOSA-1.0+FP.

    Currently skipped unconditionally (see the ``mark.skip`` reason).

    NOTE(review): unlike the TOSA tests, no ``cf`` extension is requested
    here — revisit when the skip is lifted.
    """
    model, inputs = case()
    pipeline = VgfPipeline[tuple](
        model,
        inputs,
        "torch.ops.higher_order.while_loop",
        tosa_version="TOSA-1.0+FP",
    )
    pipeline.run()
@mark.skip("While not supported in model_converter.")
@common.parametrize("case", test_cases)
@common.SkipIfNoModelConverter
def test_while_loop_vgf_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
    """Run every while_loop case through the VGF pipeline's default TOSA settings.

    Currently skipped unconditionally (see the ``mark.skip`` reason).
    """
    model, inputs = case()
    pipeline = VgfPipeline[tuple](
        model,
        inputs,
        "torch.ops.higher_order.while_loop",
    )
    pipeline.run()