Commit d98aa22
authored
Arm backend: support depthwise Conv3D (pytorch#19902)
## Summary
Depthwise Conv3D (`in_channels == groups`, rank-5 input) previously
crashed with a `RuntimeError` inside `RewriteConvPass` because TOSA has
no `DEPTHWISE_CONV3D` op. `DecomposeGroupedConvPass` already handles
non-depthwise grouped Conv3D by splitting it into `groups==1`
convolutions via slice→conv→cat, but it explicitly skipped the depthwise
case since Conv2D depthwise maps to the native `DEPTHWISE_CONV2D` TOSA
op.
For Conv3D there is no such native op, so the fix is to extend
`DecomposeGroupedConvPass` to stop skipping depthwise when the input is
rank 5(Conv3D).
The existing slice→`CONV3D`→cat decomposition can handle it correctly.
```mermaid
flowchart LR
DW2D["Depthwise Conv2D\n(in_channels == groups, rank 4)"]
DW3D["Depthwise Conv3D\n(in_channels == groups, rank 5)"]
GRP["DecomposeGroupedConvPass"]
RC2D["RewriteConvPass"]
RC3D["RewriteConvPass"]
DELEGATE_CONV2D["DEPTHWISE_CONV2D"]
DELEGATE_CONV3D["CONV3D"]
DW2D --> RC2D
DW3D -->|"decomposed"| GRP
GRP -->|"CONV3D (groups==1)"| RC3D
RC2D -->|"delegated to native op"| DELEGATE_CONV2D
RC3D -->|"delegated to native op"| DELEGATE_CONV3D
```
## Files changed:
| File | Change |
| --- | --- |
| `backends/arm/_passes/decompose_grouped_conv_pass.py` | In
`call_operator`, narrow the depthwise skip to Conv2D only
(`len(input.data.shape) != 5`); for rank-5 inputs(Conv3D) fall through
to the existing decomposition. |
| `backends/arm/_passes/rewrite_conv_pass.py` | Update comment in
`_is_conv3d` to reflect that both grouped and depthwise Conv3D are now
decomposed upstream; retain the `RuntimeError` as defense-in-depth. |
| `backends/arm/test/ops/test_conv3d.py` | Rewrite
`test_convolution_3d_tosa_FP_depthwise` to assert delegation |
## Test result
```bash
python -m pytest backends/arm/test/ops/test_conv3d.py::test_convolution_u55_INT_not_delegated_3d
# 2 passed, 0 failed.
```
```bash
lintrunner -a \
backends/arm/_passes/decompose_grouped_conv_pass.py \
backends/arm/_passes/rewrite_conv_pass.py \
backends/arm/test/ops/test_conv3d.py
# ok No lint issues.
```
Signed-off-by: Youngsik Yang <vacu9708@gmail.com>1 parent e983693 commit d98aa22
3 files changed
Lines changed: 48 additions & 18 deletions
File tree
- backends/arm
- _passes
- test/ops
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
257 | 257 | | |
258 | 258 | | |
259 | 259 | | |
260 | | - | |
261 | | - | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
262 | 264 | | |
263 | 265 | | |
264 | 266 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
129 | 129 | | |
130 | 130 | | |
131 | 131 | | |
132 | | - | |
133 | | - | |
134 | | - | |
135 | | - | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
136 | 135 | | |
137 | 136 | | |
138 | | - | |
| 137 | + | |
| 138 | + | |
139 | 139 | | |
140 | 140 | | |
141 | 141 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
212 | 212 | | |
213 | 213 | | |
214 | 214 | | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
215 | 241 | | |
216 | 242 | | |
217 | 243 | | |
| |||
623 | 649 | | |
624 | 650 | | |
625 | 651 | | |
626 | | - | |
627 | | - | |
| 652 | + | |
| 653 | + | |
628 | 654 | | |
629 | 655 | | |
630 | | - | |
631 | | - | |
632 | | - | |
633 | | - | |
634 | | - | |
635 | | - | |
636 | | - | |
637 | | - | |
638 | | - | |
| 656 | + | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
639 | 667 | | |
640 | 668 | | |
641 | 669 | | |
| |||
0 commit comments