Skip to content

Commit a3a08c1

Browse files
authored
Fix avg_pool2d divisor_override being ignored when count_include_pad is false. (#18616)
Differential Revision: D98942362 Pull Request resolved: #18616
1 parent 5ba654f commit a3a08c1

1 file changed

Lines changed: 7 additions & 9 deletions

File tree

backends/cadence/generic/operators/op_avg_pool2d.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ void avg_pool2d_nchw(
3737
IntArrayRef stride,
3838
IntArrayRef padding,
3939
bool count_include_pad,
40-
int64_t divisor,
40+
optional<int64_t> divisor_override,
4141
int leading_dims,
4242
int ih,
4343
int iw,
@@ -50,6 +50,10 @@ void avg_pool2d_nchw(
5050
int p0 = padding[0];
5151
int p1 = padding[1];
5252

53+
bool is_fixed_divisor = divisor_override.has_value() || count_include_pad;
54+
float fixed_inv_divisor =
55+
1.f / (divisor_override.has_value() ? divisor_override.value() : kh * kw);
56+
5357
for (int _n = 0; _n < leading_dims; ++_n) {
5458
for (int _ih = 0, _oh = 0; _oh < oh; ++_oh, _ih += s0) {
5559
int input_offset = _n * ih * iw;
@@ -70,9 +74,7 @@ void avg_pool2d_nchw(
7074
acc += in_data[input_addr];
7175
}
7276
}
73-
// The divisor changes depending on whether the count includes
74-
// padded cells or not.
75-
float inv_divisor = 1. / (count_include_pad ? divisor : count);
77+
float inv_divisor = is_fixed_divisor ? fixed_inv_divisor : 1.f / count;
7678
float val = acc * inv_divisor;
7779
if (quantized) {
7880
int32_t min_val =
@@ -105,10 +107,6 @@ Tensor& avg_pool2d_out(
105107
const int32_t in_zero_point = in_zero_point_t.has_value()
106108
? in_zero_point_t.value().const_data_ptr<int32_t>()[0]
107109
: 0;
108-
const int64_t divisor = divisor_override.has_value()
109-
? divisor_override.value()
110-
: kernel_size[0] * kernel_size[1];
111-
112110
const int odim = out.dim();
113111
const int on = getLeadingDims(out, odim - 2);
114112
const int oh = out.size(odim - 2);
@@ -128,7 +126,7 @@ Tensor& avg_pool2d_out(
128126
stride, \
129127
padding, \
130128
count_include_pad, \
131-
divisor, \
129+
divisor_override, \
132130
on, \
133131
ih, \
134132
iw, \

0 commit comments

Comments
 (0)