
Commit ae07d06

[ET-VK] Only partition copy.default to Vulkan when it is a no-op (#18576)
Summary:
Previously, copy.default was unconditionally added to the Vulkan partition as an ephemeral op (assumed to always be a no-op removed by RemoveRedundantOpsTransform). However, copy.default can also represent a dtype cast, and blindly partitioning casts to Vulkan causes runtime failures since the Vulkan backend has no copy implementation.

Replace the unconditional registration with a custom are_node_inputs_supported_fn that checks both dtype and shape match between src and dst. Only same-dtype, same-shape copies (true no-ops) are partitioned to Vulkan; dtype casts remain on CPU.

Authored with Claude.

Test Plan:
Exported EfficientNet for Vulkan, which contains copy.default for dtype casts, and confirmed it no longer errors during Vulkan compilation. No-op copies in other models (e.g. MobileNetV3) still partition correctly.

Reviewers:

Subscribers:

Tasks:

Tags:

---

[ET-VK] Prevent hardswish decomposition by adding to ops_not_to_decompose

Summary:
Without this change, aten.hardswish.default decomposes into several arithmetic ops (add, clamp, mul, div) which prevents the Vulkan conv2d depthwise fusion pass from recognizing and fusing the activation. Adding hardswish to ops_not_to_decompose keeps it as a single op so the conv2d output tile shader can apply the clamp fusion.

Also adds the conv2d_dw_output_tile_3x3_b1x1_clamp shader variant needed for the fused path.

Authored with Claude.

Test Plan:
Exported MobileNetV3 for Vulkan and verified hardswish no longer decomposes.

Reviewers:

Subscribers:

Tasks:

Tags:

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/executorch/pull/18576).
* __->__ #18576
* #18575

cc @manuelcandales @digantdesai @cbilgin
1 parent 330a567 commit ae07d06

3 files changed

Lines changed: 31 additions & 3 deletions


backends/vulkan/op_registry.py

Lines changed: 26 additions & 3 deletions
@@ -169,9 +169,6 @@ def update_features_impl(op: OpKey):
         # Guard and assert ops
         torch.ops.aten._assert_scalar.default,
         torch.ops.aten.sym_constrain_range_for_size.default,
-        # copy.default is a no-op when src dtype matches dst dtype; removed by
-        # RemoveRedundantOpsTransform before execution.
-        exir_ops.edge.aten.copy.default,
     ]
 )
 def register_ephemeral_ops():
@@ -181,6 +178,31 @@ def register_ephemeral_ops():
     )
 
 
+def _check_copy_is_noop(node: torch.fx.Node) -> bool:
+    """Only support copy.default when it's a no-op (same dtype and shape)."""
+    src = node.args[1]
+    if not isinstance(src, torch.fx.Node):
+        return False
+    src_val = src.meta.get("val")
+    dst_val = node.meta.get("val")
+    if src_val is None or dst_val is None:
+        return False
+    return src_val.dtype == dst_val.dtype and src_val.shape == dst_val.shape
+
+
+@update_features(
+    [
+        exir_ops.edge.aten.copy.default,
+    ]
+)
+def register_copy_op():
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=_check_copy_is_noop,
+    )
+
+
 # =============================================================================
 # UnaryOp.cpp
 # =============================================================================
@@ -193,6 +215,7 @@ def register_ephemeral_ops():
         exir_ops.edge.aten.exp.default,
         exir_ops.edge.aten.gelu.default,
         exir_ops.edge.aten.hardshrink.default,
+        exir_ops.edge.aten.hardswish.default,
         exir_ops.edge.aten.hardtanh.default,
         exir_ops.edge.aten.neg.default,
         exir_ops.edge.aten.relu.default,
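
The dtype-and-shape check in `_check_copy_is_noop` can be tried out without PyTorch or ExecuTorch installed. The following is a minimal sketch, not the backend's code: `FakeVal`, `FakeNode`, `copy_is_noop`, and `make_copy` are hypothetical stand-ins that mimic only the `args` and `meta["val"]` fields the real function reads from a `torch.fx.Node`, and dtypes are represented as plain strings.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple


@dataclass
class FakeVal:
    """Hypothetical stand-in for the tensor metadata kept in node.meta["val"]."""
    dtype: str
    shape: Tuple[int, ...]


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; only args and meta are modeled."""
    args: Tuple[Any, ...] = ()
    meta: dict = field(default_factory=dict)


def copy_is_noop(node: FakeNode) -> bool:
    # Same logic as _check_copy_is_noop in the diff: args[1] is the copy
    # source, and the copy node's own "val" describes the destination.
    src = node.args[1]
    if not isinstance(src, FakeNode):
        return False  # e.g. a Python scalar source
    src_val = src.meta.get("val")
    dst_val = node.meta.get("val")
    if src_val is None or dst_val is None:
        return False  # missing metadata: be conservative
    return src_val.dtype == dst_val.dtype and src_val.shape == dst_val.shape


def make_copy(src_val: Optional[FakeVal], dst_val: FakeVal) -> FakeNode:
    """Build a fake copy node with args (dst, src), illustrative only."""
    dst_in = FakeNode(meta={"val": dst_val})
    src = FakeNode(meta={} if src_val is None else {"val": src_val})
    return FakeNode(args=(dst_in, src), meta={"val": dst_val})


f32 = FakeVal("float32", (1, 3, 224, 224))
i64 = FakeVal("int64", (1, 3, 224, 224))

same_copy = copy_is_noop(make_copy(f32, f32))                       # true no-op
dtype_cast = copy_is_noop(make_copy(i64, f32))                      # cast
reshaped = copy_is_noop(make_copy(FakeVal("float32", (1,)), f32))   # shape differs
no_meta = copy_is_noop(make_copy(None, f32))                        # no "val" on src
scalar_src = copy_is_noop(FakeNode(args=(FakeNode(), 1.0), meta={"val": f32}))
```

Under these assumptions only `same_copy` is true; every other case returns false, which corresponds to the copy staying outside the Vulkan partition.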

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
 
 # pyre-ignore
 ops_not_to_decompose = [
+    torch.ops.aten.hardswish.default,
     torch.ops.aten.upsample_nearest2d.vec,
 ]
 
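
For context on the add, clamp, mul, and div ops named in the commit message: hardswish is conventionally defined as `x * clamp(x + 3, 0, 6) / 6`. The following scalar sketch in plain Python compares the single-op form with a four-step chain. The function names are hypothetical and are not ExecuTorch or PyTorch APIs, and the exact op sequence the real decomposition emits may differ.

```python
def hardswish(x: float) -> float:
    """Single-op form: x * clamp(x + 3, 0, 6) / 6."""
    return x * min(max(x + 3.0, 0.0), 6.0) / 6.0


def hardswish_decomposed(x: float) -> float:
    """The same computation written as four separate arithmetic steps."""
    t = x + 3.0                  # add
    t = min(max(t, 0.0), 6.0)    # clamp to [0, 6]
    t = x * t                    # mul
    return t / 6.0               # div


samples = [-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0]
pairs = [(hardswish(x), hardswish_decomposed(x)) for x in samples]
```

Both forms compute the same values. The difference that matters to the fusion pass is that the decomposed form shows up in the graph as four nodes instead of one recognizable activation.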

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 4 additions & 0 deletions
@@ -28,3 +28,7 @@ conv2d_dw_output_tile:
     - NAME: conv2d_dw_output_tile_3x3_b1x1
       BATCH_SIZE_X: 1
       BATCH_SIZE_Y: 1
+    - NAME: conv2d_dw_output_tile_3x3_b1x1_clamp
+      OPERATOR: clamp(X, A, B)
+      BATCH_SIZE_X: 1
+      BATCH_SIZE_Y: 1
