Skip to content

Commit 3c3ad79

Browse files
author
ssjia
committed
[ETVK][experimental] Route general conv2d through im2col + GEMM
Routes the `SlidingWindow` branch of `add_conv2d_node` (in `Convolution.cpp`) through the new `conv2d_gemm_impl` orchestrator from the previous diff, replacing the legacy direct `conv2d` shader for the general non-pointwise / non-depthwise / non-transposed case. Pointwise still uses `conv2d_pw_impl`, depthwise still uses `conv2d_dw_impl`, transposed still uses the legacy path (im2col doesn't support transposed yet). On `UNTRAINED_TinyCNNDepthEstimatorRealTime_Vulkan.pte` on Pixel 9 Pro XL (Mali → buffer im2col), total convolution time drops from 84.3 ms to 59.8 ms — a **29% reduction**. The previously-dominant `conv2d_float` (78.5 ms, ~93% of conv time) is replaced by `conv2d_im2col_buffer_float` (10.8 ms) + `conv2d_gemm_buffer_float` (42.5 ms). Pointwise and depthwise dispatches are unchanged. ``` kernel before (us) after (us) ------------------------------------------------------------------------ conv2d_float 78518.3 0.0 conv2d_gemm_buffer_float 0.0 42501.3 conv2d_im2col_buffer_float 0.0 10806.6 conv2d_pw_tiled_float 5525.0 6163.5 conv2d_dw_output_tile_3x3_b1x1_float 101.0 124.4 conv2d_dw_sned_output_tile_5x5_float 171.5 206.1 ------------------------------------------------------------------------ TOTAL conv time 84315.8 59801.9 ``` This is intentionally a thin, easily-revertible diff sitting on top of the im2col + GEMM prototype, marked experimental so we can land it as a kill-switch-able change while we validate other models. Differential Revision: [D105120965](https://our.internmc.facebook.com/intern/diff/D105120965/) [ghstack-poisoned]
1 parent 0186ccb commit 3c3ad79

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1212

1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1516

1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -473,6 +474,27 @@ void add_conv2d_node(
473474
out_max_val);
474475
}
475476

477+
// EXPERIMENTAL: route the general SlidingWindow case through the im2col
478+
// + GEMM implementation. The im2col path picks an im2col intermediate
479+
// storage (buffer / texture2d / texture3d) per device and consistently
480+
// outperforms the legacy conv2d shader on the depth-estimator hotspots.
481+
// Transposed conv keeps using the legacy path because conv2d_gemm_impl
482+
// doesn't support it yet.
483+
if (method == Conv2dMethod::SlidingWindow) {
484+
return conv2d_gemm_impl(
485+
graph,
486+
in,
487+
weight_data,
488+
bias,
489+
stride,
490+
padding,
491+
dilation,
492+
out,
493+
clamp_out,
494+
out_min_val,
495+
out_max_val);
496+
}
497+
476498
ValueRef arg_weight = prepack_weights(graph, weight_data, method);
477499
ValueRef arg_bias = prepack_biases(
478500
graph,

0 commit comments

Comments
 (0)