Skip to content

Commit dbe68c9

Browse files
committed
add guard in maxpool parser and kernel for kernel length and strides
1 parent cc810c0 commit dbe68c9

3 files changed

Lines changed: 12 additions & 2 deletions

File tree

Deeploy/Targets/Generic/Parsers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,13 @@ def __init__(self):
193193
def parseNode(self, node: gs.Node) -> bool:
194194

195195
ret = all([
196-
'ceil_mode' in node.attrs, 'kernel_shape' in node.attrs, 'pads' in node.attrs, 'strides' in node.attrs,
196+
'ceil_mode' in node.attrs,
197+
'kernel_shape' in node.attrs,
198+
'pads' in node.attrs,
199+
'strides' in node.attrs,
197200
len(node.inputs) == 1,
198-
len(node.outputs) >= 1
201+
len(node.outputs) >= 1,
202+
all([stride > 0 for stride in node.attrs['strides']]),
199203
])
200204

201205
if ret:

TargetLibraries/Generic/src/MaxPool_fp32.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ void MaxPool2d_fp32_fp32_NCHW(float32_t const *__restrict__ pSrcA, uint32_t C,
4343
void MaxPool1d_fp32_fp32(float32_t const *__restrict__ pSrcA, uint32_t C,
4444
uint32_t W, uint32_t K, uint32_t S,
4545
float32_t *__restrict__ pDstC) {
46+
if (W < K || S == 0) {
47+
return;
48+
}
4649
uint32_t W_out = (W - K) / S + 1;
4750
for (uint32_t c = 0; c < C; ++c) {
4851
for (uint32_t w_out = 0; w_out < W_out; ++w_out) {

TargetLibraries/Generic/src/MaxPool_s8.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ void MaxPool2d_s8_s8_NCHW(int8_t const *__restrict__ pSrcA, uint32_t C,
5252
void MaxPool1d_s8_s8(int8_t const *__restrict__ pSrcA, uint32_t C, uint32_t L,
5353
uint32_t K, uint32_t S, int8_t *__restrict__ pDstC,
5454
int32_t input_offset, int32_t output_offset) {
55+
if (L < K || S == 0) {
56+
return;
57+
}
5558
uint32_t L_out = (L - K) / S + 1;
5659
for (uint32_t c = 0; c < C; ++c) {
5760
for (uint32_t l_out = 0; l_out < L_out; ++l_out) {

0 commit comments

Comments
 (0)