
Commit d98aa22

Arm backend: support depthwise Conv3D (pytorch#19902)
## Summary

Depthwise Conv3D (`in_channels == groups`, rank-5 input) previously crashed with a `RuntimeError` inside `RewriteConvPass` because TOSA has no `DEPTHWISE_CONV3D` op.

`DecomposeGroupedConvPass` already handles non-depthwise grouped Conv3D by splitting it into `groups==1` convolutions via slice→conv→cat, but it explicitly skipped the depthwise case since Conv2D depthwise maps to the native `DEPTHWISE_CONV2D` TOSA op. For Conv3D there is no such native op, so the fix is to extend `DecomposeGroupedConvPass` to stop skipping depthwise when the input is rank 5 (Conv3D). The existing slice→`CONV3D`→cat decomposition handles it correctly.

```mermaid
flowchart LR
    DW2D["Depthwise Conv2D\n(in_channels == groups, rank 4)"]
    DW3D["Depthwise Conv3D\n(in_channels == groups, rank 5)"]
    GRP["DecomposeGroupedConvPass"]
    RC2D["RewriteConvPass"]
    RC3D["RewriteConvPass"]
    DELEGATE_CONV2D["DEPTHWISE_CONV2D"]
    DELEGATE_CONV3D["CONV3D"]
    DW2D --> RC2D
    DW3D -->|"decomposed"| GRP
    GRP -->|"CONV3D (groups==1)"| RC3D
    RC2D -->|"delegated to native op"| DELEGATE_CONV2D
    RC3D -->|"delegated to native op"| DELEGATE_CONV3D
```

## Files changed

| File | Change |
| --- | --- |
| `backends/arm/_passes/decompose_grouped_conv_pass.py` | In `call_operator`, narrow the depthwise skip to Conv2D only (`len(input_node.data.shape) != 5`); for rank-5 inputs (Conv3D), fall through to the existing decomposition. |
| `backends/arm/_passes/rewrite_conv_pass.py` | Update the comment in `_is_conv3d` to reflect that both grouped and depthwise Conv3D are now decomposed upstream; retain the `RuntimeError` as defense-in-depth. |
| `backends/arm/test/ops/test_conv3d.py` | Rewrite `test_convolution_3d_tosa_FP_depthwise` to assert delegation. |

## Test result

```bash
python -m pytest backends/arm/test/ops/test_conv3d.py::test_convolution_u55_INT_not_delegated_3d
# 2 passed, 0 failed.
```

```bash
lintrunner -a \
    backends/arm/_passes/decompose_grouped_conv_pass.py \
    backends/arm/_passes/rewrite_conv_pass.py \
    backends/arm/test/ops/test_conv3d.py
# ok No lint issues.
```

Signed-off-by: Youngsik Yang <vacu9708@gmail.com>
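The slice→conv→cat idea described in the summary can be checked numerically outside the backend. The following is a minimal sketch in plain PyTorch, not the pass itself: `conv3d_by_groups` is a hypothetical helper written for illustration, and it assumes default stride and dilation.

```python
import torch
import torch.nn.functional as F


def conv3d_by_groups(x, weight, bias, groups, padding=0):
    """Emulate a grouped Conv3d with `groups` plain (groups == 1) convolutions.

    Hypothetical helper for illustration; assumes stride 1 and dilation 1.
    """
    in_per_group = x.shape[1] // groups
    out_per_group = weight.shape[0] // groups
    outputs = []
    for g in range(groups):
        out_lo, out_hi = g * out_per_group, (g + 1) * out_per_group
        # Slice this group's input channels and the filters that read them.
        x_g = x[:, g * in_per_group : (g + 1) * in_per_group]
        w_g = weight[out_lo:out_hi]
        b_g = None if bias is None else bias[out_lo:out_hi]
        outputs.append(F.conv3d(x_g, w_g, b_g, padding=padding))
    # Concatenate the per-group results along the channel dimension.
    return torch.cat(outputs, dim=1)


torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8, 8)

# Depthwise: in_channels == groups.
depthwise = torch.nn.Conv3d(4, 4, kernel_size=3, padding=1, groups=4)
# Non-depthwise grouped: in_channels != groups.
grouped = torch.nn.Conv3d(4, 4, kernel_size=3, padding=1, groups=2)

with torch.no_grad():
    for conv in (depthwise, grouped):
        expected = conv(x)
        actual = conv3d_by_groups(x, conv.weight, conv.bias, conv.groups, padding=1)
        assert torch.allclose(expected, actual, atol=1e-4)
```

Depthwise is simply the case where each slice holds a single input channel, which is why the existing grouped decomposition can cover it without a dedicated code path.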
1 parent e983693 commit d98aa22

3 files changed

Lines changed: 48 additions & 18 deletions

backends/arm/_passes/decompose_grouped_conv_pass.py

Lines changed: 4 additions & 2 deletions
@@ -257,8 +257,10 @@ def call_operator(self, op, args, kwargs, meta):

         input_node = args[0]
         if DecomposeGroupedConvPass._is_depthwise_conv(input_node, groups, transposed):
-            # This is a depthwise convolution which is handled elsewhere
-            return super().call_operator(op, args, kwargs, meta)
+            # Conv2D depthwise maps to TOSA DEPTHWISE_CONV2D — handled in RewriteConvPass.
+            # Conv3D has no DEPTHWISE_CONV3D, so fall through and decompose like grouped conv.
+            if len(input_node.data.shape) != 5:
+                return super().call_operator(op, args, kwargs, meta)

         weight_node = args[1]
         bias_node = args[2]
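The narrowed skip amounts to a small decision rule. The sketch below restates it as a standalone function. `should_decompose` is a hypothetical name, the `groups == 1` early exit is an assumption about how the pass treats ungrouped convolutions, and the `transposed` argument of `_is_depthwise_conv` is ignored here.

```python
def should_decompose(in_channels: int, groups: int, input_rank: int) -> bool:
    """Hypothetical restatement of the skip logic, not the pass's real API.

    Returns True when the convolution should be split into groups == 1
    convolutions via slice -> conv -> cat.
    """
    if groups == 1:
        # Assumption: ungrouped convolutions need no decomposition.
        return False
    is_depthwise = in_channels == groups
    if is_depthwise and input_rank != 5:
        # Depthwise Conv2D is left alone so it can map to DEPTHWISE_CONV2D.
        return False
    # Grouped convolutions of any rank, and depthwise Conv3D (rank 5).
    return True


print(should_decompose(in_channels=4, groups=4, input_rank=4))  # False
print(should_decompose(in_channels=4, groups=4, input_rank=5))  # True
print(should_decompose(in_channels=4, groups=2, input_rank=5))  # True
```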

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 5 additions & 5 deletions
@@ -129,13 +129,13 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:

     def _is_conv3d(self, rank, groups) -> bool:
         if rank == 5:
-            # A Conv3D is considered depthwise if Group == InChannels and
-            # Group * N == OutChannels, where N is a possitive integer.
-            # Currently we do not support depthwise or grouped conv3d.
-            # @TODO Add grouped/depthwise conv3d support or reject in partitioner.
+            # Both grouped and depthwise Conv3D are decomposed into groups==1
+            # convolutions by DecomposeGroupedConvPass before reaching here.
+            # This guard is defense-in-depth for paths that bypass that pass.
             if groups != 1:
                 raise RuntimeError(
-                    "CONV3D with groups != 1 is not supported in the Arm backend."
+                    "CONV3D with groups != 1 reached unexpectedly; "
+                    "DecomposeGroupedConvPass should have decomposed it first."
                 )
             return True
         return False

backends/arm/test/ops/test_conv3d.py

Lines changed: 39 additions & 11 deletions
@@ -212,6 +212,32 @@ def forward(self, x):
         return self.conv(x)


+class GroupedConv3d(torch.nn.Module):
+    """Non-depthwise grouped Conv3d (in_channels != groups).
+
+    Split into ``groups`` plain convolutions by DecomposeGroupedConvPass, so it
+    is delegated unlike the depthwise case.
+
+    """
+
+    def __init__(self, dtype=torch.float):
+        super().__init__()
+        self.dtype = dtype
+        self.conv = torch.nn.Conv3d(
+            in_channels=4,
+            out_channels=4,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            groups=2,
+        ).to(dtype)
+
+    def get_inputs(self):
+        return (torch.randn(1, 4, 8, 8, 8).to(self.dtype),)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 conv3d_2x2_3x2x14x14_nobias = Conv3d(
     in_channels=2,
     out_channels=3,
@@ -623,19 +649,21 @@ def test_convolution_3d_tosa_INT_multi_op():


 def test_convolution_3d_tosa_FP_depthwise():
-    """Depthwise or Grouped Conv3d should be rejected until grouped support
-    exists.
+    """Depthwise Conv3d should be delegated, decomposed into groups==1
+    convolutions by DecomposeGroupedConvPass.
     """
     model = DepthwiseConv3d()
-    pipeline = TosaPipelineFP[input_t](
-        model,
-        model.get_inputs(),
-        aten_op,
-        exir_op,
-        run_on_tosa_ref_model=False,
-    )
-    with pytest.raises(RuntimeError, match="CONV3D with groups != 1"):
-        pipeline.run()
+    pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op)
+    pipeline.run()
+
+
+def test_convolution_3d_tosa_FP_grouped():
+    """Non-depthwise grouped Conv3d should be delegated, decomposed into
+    groups==1 convolutions by DecomposeGroupedConvPass.
+    """
+    model = GroupedConv3d()
+    pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op)
+    pipeline.run()


 @common.parametrize("test_data", test_data_INT)
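For the `GroupedConv3d` test module (4 input channels, 4 output channels, `groups=2`), the channel bookkeeping behind the split can be sketched in pure Python. `group_channel_ranges` is a hypothetical helper for illustration. It only computes which input and output channel ranges each `groups==1` convolution would cover and is not taken from the pass.

```python
def group_channel_ranges(in_channels: int, out_channels: int, groups: int):
    """Return one ((in_lo, in_hi), (out_lo, out_hi)) pair per group.

    Hypothetical helper: half-open channel ranges for each per-group convolution.
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    in_step = in_channels // groups
    out_step = out_channels // groups
    return [
        ((g * in_step, (g + 1) * in_step), (g * out_step, (g + 1) * out_step))
        for g in range(groups)
    ]


# Same configuration as GroupedConv3d: two convolutions, each 2 -> 2 channels.
print(group_channel_ranges(4, 4, 2))
# [((0, 2), (0, 2)), ((2, 4), (2, 4))]
```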
