Skip to content

Commit 57b265e

Browse files
hariharans29, Copilot, and github-actions[bot]
authored
[MLAS] Add depthwise with multiplier conv special kernel for NCHW data layout on Avx512 (#27874)
### Description Adds a special AVX512 kernel for depthwise conv with multiplier = 2. These improve the performance of 3 costly conv operations (7x7 kernels) in the MobileClip model by approx 2.4x (will share MLAS benchmark numbers). These are 3 ops with 1) Cin=64, Cout=128, group=64, H=64, W=64, kH=7, kW=7 2) Cin=128, Cout=256, group=128, H=32, W=32, kH=7, kW=7 3) Cin=256, Cout=512, group=256, H=16, W=16, kH=7, kW=7 These Conv operations cannot be dispatched to NCHWc as the Cout per group is sub-block size. On AVX512, the block size is 16 and the Cout per group is only 2. There is a special depthwise kernel in the NCHWc suite but it can only handle Cout per group = 1. MLAS Benchmark Before and After comparison: | Benchmark | BEFORE mean (ns) | AFTER mean (ns) | Speedup | |---|---:|---:|---:| | SCONV_NCHW G64 | 3,151,190 | 1,391,419 | 2.26x | | SCONV_NCHW G128 | 1,646,040 | 824,654 | 2.00x | | SCONV_NCHW G256 | 978,843 | 533,375 | 1.84x | | SCONV_NCHW_THREADED G64 | 873,283 | 367,722 | 2.37x | | SCONV_NCHW_THREADED G128 | 445,786 | 226,777 | 1.97x | | SCONV_NCHW_THREADED G256 | 264,473 | 147,997 | 1.79x | ### Motivation and Context Just by optimizing these 3 conv operations, MobileClip is about 700us-850us faster and the entire model is <14ms on an AVX512 machine. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bec2792 commit 57b265e

13 files changed

Lines changed: 998 additions & 90 deletions

cmake/onnxruntime_mlas.cmake

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
2323
${MLAS_SRC_DIR}/qgemm.cpp
2424
${MLAS_SRC_DIR}/qdwconv.cpp
2525
${MLAS_SRC_DIR}/convolve.cpp
26+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_greater_than_1.cpp
2627
${MLAS_SRC_DIR}/convsym.cpp
2728
${MLAS_SRC_DIR}/pooling.cpp
2829
${MLAS_SRC_DIR}/transpose.cpp
@@ -118,7 +119,7 @@ function(setup_mlas_source_for_windows)
118119
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
119120
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
120121
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
121-
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
122+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp
122123
)
123124

124125
set(mlas_platform_preprocess_srcs
@@ -207,6 +208,7 @@ function(setup_mlas_source_for_windows)
207208
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
208209
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
209210
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
211+
${MLAS_SRC_DIR}/intrinsics/avx512/sconv_nchw_depthwise_multiplier_greater_than_1_avx512f.cpp
210212
)
211213

212214
set_source_files_properties(${mlas_platform_srcs_avx512} PROPERTIES COMPILE_FLAGS "/arch:AVX512")
@@ -501,7 +503,7 @@ else()
501503
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
502504
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
503505
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
504-
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
506+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp
505507
)
506508

507509
# Conditionally add the SVE implementation if compiler supports it
@@ -778,6 +780,7 @@ endif()
778780
${MLAS_SRC_DIR}/intrinsics/avx512/gelu_avx512f.cpp
779781
${MLAS_SRC_DIR}/intrinsics/avx512/silu_avx512f.cpp
780782
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
783+
${MLAS_SRC_DIR}/intrinsics/avx512/sconv_nchw_depthwise_multiplier_greater_than_1_avx512f.cpp
781784
)
782785
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")
783786

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ enum MLAS_CONV_ALGORITHM {
877877
MlasConvAlgorithmGemmDirect,
878878
MlasConvAlgorithmExpandThenGemm,
879879
MlasConvAlgorithmExpandThenGemmSegmented,
880+
MlasConvAlgorithmDepthwiseMultiplierGreaterThan1,
880881
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
881882
MlasConvAlgorithmDepthwise,
882883
#endif

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,53 @@ struct MLAS_CONV_WORK_BLOCK {
4242
ptrdiff_t TargetThreadCount;
4343
};
4444

45+
static
46+
void
47+
MlasDepthwiseMultiplierGreaterThan1Threaded(
48+
void* Context,
49+
ptrdiff_t Index
50+
)
51+
{
52+
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
53+
54+
const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
55+
const float* Zeros = nullptr;
56+
57+
const size_t GroupCount = Parameters->GroupCount;
58+
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
59+
60+
size_t BatchGroupStart;
61+
size_t BatchGroupRemaining;
62+
63+
MlasPartitionWork(Index, WorkBlock->TargetThreadCount, BatchGroupCount,
64+
&BatchGroupStart, &BatchGroupRemaining);
65+
66+
const size_t BatchGroupEnd = BatchGroupStart + BatchGroupRemaining;
67+
68+
const size_t FilterCount = Parameters->FilterCount;
69+
const size_t OutputSize = Parameters->OutputSize;
70+
const size_t K = Parameters->K;
71+
72+
const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
73+
const size_t OutputGroupSize = FilterCount * OutputSize;
74+
const size_t FilterGroupSize = FilterCount * K;
75+
76+
for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
77+
size_t group = bg % GroupCount;
78+
79+
const float* input = WorkBlock->Input + bg * InputGroupSize;
80+
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
81+
float* output = WorkBlock->Output + bg * OutputGroupSize;
82+
const float* bias = WorkBlock->Bias;
83+
if (bias != nullptr) {
84+
bias += group * FilterCount;
85+
}
86+
87+
MlasConvDepthwiseWithMultiplierFloat_CHW(Parameters, input, filter, output, Zeros);
88+
MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize);
89+
}
90+
}
91+
4592
void
4693
MlasConvIm2Col(
4794
const MLAS_CONV_PARAMETERS* Parameters,
@@ -1106,6 +1153,30 @@ Return Value:
11061153
return;
11071154
}
11081155

1156+
if (Algorithm == MlasConvAlgorithmDepthwiseMultiplierGreaterThan1 && ((BatchCount > 1) || (GroupCount > 1))) {
1157+
const size_t BatchGroupCount = BatchCount * GroupCount;
1158+
1159+
ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
1160+
1161+
if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
1162+
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
1163+
}
1164+
1165+
MLAS_CONV_WORK_BLOCK WorkBlock;
1166+
1167+
WorkBlock.Parameters = Parameters;
1168+
WorkBlock.Input = Input;
1169+
WorkBlock.Filter = Filter;
1170+
WorkBlock.Bias = Bias;
1171+
WorkBlock.WorkingBuffer = nullptr;
1172+
WorkBlock.Output = Output;
1173+
WorkBlock.TargetThreadCount = TargetThreadCount;
1174+
1175+
MlasExecuteThreaded(MlasDepthwiseMultiplierGreaterThan1Threaded, &WorkBlock, TargetThreadCount, ThreadPool);
1176+
1177+
return;
1178+
}
1179+
11091180

11101181
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
11111182

@@ -1198,6 +1269,14 @@ Return Value:
11981269
break;
11991270
}
12001271

1272+
case MlasConvAlgorithmDepthwiseMultiplierGreaterThan1:
1273+
{
1274+
const float* Zeros = nullptr;
1275+
MlasConvDepthwiseWithMultiplierFloat_CHW(Parameters, Input, filter, Output, Zeros);
1276+
MlasActivation(Parameters->Activation, Output, bias, FilterCount, OutputSize, OutputSize);
1277+
break;
1278+
}
1279+
12011280
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
12021281

12031282
case MlasConvAlgorithmDepthwise:
@@ -1453,6 +1532,23 @@ Return Value:
14531532

14541533
} else {
14551534

1535+
#if defined(MLAS_TARGET_AMD64)
1536+
1537+
if (Dimensions == 2
1538+
&& GroupCount > 1
1539+
&& Parameters->FilterCount == 2 && Parameters->InputChannels == 1
1540+
&& Parameters->KernelShape[0] == 7 && Parameters->KernelShape[1] == 7
1541+
&& Parameters->Padding[0] == 3 && Parameters->Padding[1] == 3
1542+
&& Parameters->Padding[2] == 3 && Parameters->Padding[3] == 3
1543+
&& Parameters->StrideShape[0] == 2 && Parameters->StrideShape[1] == 2
1544+
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1
1545+
&& GetMlasPlatform().ConvNchwFloatKernel == MlasConvNchwFloatKernelAvx512F) {
1546+
1547+
Parameters->Algorithm = MlasConvAlgorithmDepthwiseMultiplierGreaterThan1;
1548+
return;
1549+
}
1550+
#endif
1551+
14561552
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
14571553

14581554
// Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*++
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
Module Name:
5+
sconv_nchw_depthwise_multiplier_greater_than_1_avx512f.cpp
6+
Abstract:
7+
This module implements the AVX512F kernel for the exact MobileClip grouped
8+
projection case:
9+
10+
- CHW input/output layout per group slice
11+
- input channels per group = 1
12+
- output channels per group = 2
13+
- kernel = 7x7
14+
- stride = 2x2
15+
- padding = 3,3,3,3
16+
- dilation = 1x1
17+
18+
The outer dispatch is expected to guarantee these constraints.
19+
--*/
20+
21+
#include "mlasi.h"
22+
23+
#if defined(MLAS_TARGET_AMD64)
24+
25+
namespace {
26+
27+
MLAS_FORCEINLINE
28+
void
29+
MlasConv2dSingleChannelCHWKernel7x7Pad3Stride2Dilation1DepthMultiplier2Scalar(
30+
const float* Input,
31+
size_t InputHeight,
32+
size_t InputWidth,
33+
const float* Filter0,
34+
const float* Filter1,
35+
float* Output0,
36+
float* Output1,
37+
size_t OutputWidth,
38+
size_t oh,
39+
size_t ow,
40+
float Beta
41+
)
42+
/*++
43+
44+
Routine Description:
45+
46+
Computes one border output point for the exact MobileClip
47+
7x7/pad-3/stride-2/dilation-1, multiplier-2 case.
48+
49+
This helper is only used by the AVX512 implementation for border handling;
50+
it is not a generic fallback dispatch path despite the scalar
51+
implementation.
52+
53+
--*/
54+
{
55+
const ptrdiff_t input_origin_y = static_cast<ptrdiff_t>(oh * 2) - 3;
56+
const ptrdiff_t input_origin_x = static_cast<ptrdiff_t>(ow * 2) - 3;
57+
const size_t output_index = oh * OutputWidth + ow;
58+
59+
float acc0 = (Beta == 0.0f) ? 0.0f : Output0[output_index] * Beta;
60+
float acc1 = (Beta == 0.0f) ? 0.0f : Output1[output_index] * Beta;
61+
62+
for (size_t kh = 0; kh < 7; ++kh) {
63+
const ptrdiff_t ih = input_origin_y + static_cast<ptrdiff_t>(kh);
64+
if (ih < 0 || ih >= static_cast<ptrdiff_t>(InputHeight)) {
65+
continue;
66+
}
67+
68+
const float* input_row = Input + static_cast<size_t>(ih) * InputWidth;
69+
const float* filter0_row = Filter0 + kh * 7;
70+
const float* filter1_row = Filter1 + kh * 7;
71+
72+
for (size_t kw = 0; kw < 7; ++kw) {
73+
const ptrdiff_t iw = input_origin_x + static_cast<ptrdiff_t>(kw);
74+
if (iw < 0 || iw >= static_cast<ptrdiff_t>(InputWidth)) {
75+
continue;
76+
}
77+
78+
const float input_value = input_row[static_cast<size_t>(iw)];
79+
acc0 += input_value * filter0_row[kw];
80+
acc1 += input_value * filter1_row[kw];
81+
}
82+
}
83+
84+
Output0[output_index] = acc0;
85+
Output1[output_index] = acc1;
86+
}
87+
88+
} // namespace
89+
90+
void
91+
MlasConvDepthwiseMultiplier2CHWKernel7x7S2Avx512F(
92+
const float* Input,
93+
size_t InputHeight,
94+
size_t InputWidth,
95+
const float* Filter,
96+
float* Output,
97+
size_t OutputHeight,
98+
size_t OutputWidth,
99+
float Beta
100+
)
101+
/*++
102+
103+
Routine Description:
104+
105+
Computes one group slice of the exact MobileClip grouped projection case.
106+
107+
Assumptions:
108+
109+
- Input and output are CHW tensors for a single group slice.
110+
- Filter is OIHW for a single group slice with exactly two output channels.
111+
- Kernel = 7x7, stride = 2, padding = 3, dilation = 1.
112+
- OutputHeight and OutputWidth match the supplied input geometry.
113+
114+
Return Value:
115+
116+
None.
117+
118+
--*/
119+
{
120+
constexpr size_t KernelSize = 7;
121+
constexpr __mmask16 ValidKernelMask = 0x007F;
122+
123+
const float* Filter0 = Filter;
124+
const float* Filter1 = Filter + KernelSize * KernelSize;
125+
float* Output0 = Output;
126+
float* Output1 = Output + (OutputHeight * OutputWidth);
127+
128+
for (size_t oh = 0; oh < OutputHeight; ++oh) {
129+
const ptrdiff_t input_origin_y = static_cast<ptrdiff_t>(oh * 2) - 3;
130+
const bool interior_y = input_origin_y >= 0 &&
131+
(input_origin_y + static_cast<ptrdiff_t>(KernelSize)) <= static_cast<ptrdiff_t>(InputHeight);
132+
133+
for (size_t ow = 0; ow < OutputWidth; ++ow) {
134+
const ptrdiff_t input_origin_x = static_cast<ptrdiff_t>(ow * 2) - 3;
135+
const bool interior_x = input_origin_x >= 0 &&
136+
(input_origin_x + static_cast<ptrdiff_t>(KernelSize)) <= static_cast<ptrdiff_t>(InputWidth);
137+
138+
if (!(interior_y && interior_x)) {
139+
MlasConv2dSingleChannelCHWKernel7x7Pad3Stride2Dilation1DepthMultiplier2Scalar(
140+
Input, InputHeight, InputWidth, Filter0, Filter1, Output0, Output1, OutputWidth, oh, ow, Beta);
141+
continue;
142+
}
143+
144+
__m512 acc0 = _mm512_setzero_ps();
145+
__m512 acc1 = _mm512_setzero_ps();
146+
147+
for (size_t kh = 0; kh < KernelSize; ++kh) {
148+
const float* input_row = Input + (static_cast<size_t>(input_origin_y) + kh) * InputWidth + static_cast<size_t>(input_origin_x);
149+
const __m512 input_vec = _mm512_maskz_loadu_ps(ValidKernelMask, input_row);
150+
const __m512 filter0_vec = _mm512_maskz_loadu_ps(ValidKernelMask, Filter0 + kh * KernelSize);
151+
const __m512 filter1_vec = _mm512_maskz_loadu_ps(ValidKernelMask, Filter1 + kh * KernelSize);
152+
153+
acc0 = _mm512_fmadd_ps(input_vec, filter0_vec, acc0);
154+
acc1 = _mm512_fmadd_ps(input_vec, filter1_vec, acc1);
155+
}
156+
157+
const size_t output_index = oh * OutputWidth + ow;
158+
float acc0_scalar = _mm512_reduce_add_ps(acc0);
159+
float acc1_scalar = _mm512_reduce_add_ps(acc1);
160+
161+
if (Beta != 0.0f) {
162+
acc0_scalar += Output0[output_index] * Beta;
163+
acc1_scalar += Output1[output_index] * Beta;
164+
}
165+
166+
Output0[output_index] = acc0_scalar;
167+
Output1[output_index] = acc1_scalar;
168+
}
169+
}
170+
}
171+
172+
#endif

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,8 +1688,6 @@ MlasFp32FromBits(
16881688
#endif
16891689

16901690
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
1691-
1692-
16931691
void
16941692
MLASCALL
16951693
MlasConvDepthwiseFloat_CHW(
@@ -1702,6 +1700,28 @@ MlasConvDepthwiseFloat_CHW(
17021700

17031701
#endif
17041702

1703+
void
1704+
MlasConvDepthwiseWithMultiplierFloat_CHW(
1705+
const MLAS_CONV_PARAMETERS* Parameters,
1706+
const float* Input,
1707+
const float* Filter,
1708+
float* Output,
1709+
const float* Zeros
1710+
);
1711+
1712+
#if defined(MLAS_TARGET_AMD64)
1713+
void
1714+
MlasConvDepthwiseMultiplier2CHWKernel7x7S2Avx512F(
1715+
const float* Input,
1716+
size_t InputHeight,
1717+
size_t InputWidth,
1718+
const float* Filter,
1719+
float* Output,
1720+
size_t OutputHeight,
1721+
size_t OutputWidth,
1722+
float Beta
1723+
);
1724+
#endif
17051725

17061726
//
17071727
// Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable

0 commit comments

Comments
 (0)