@@ -37,7 +37,7 @@ void avg_pool2d_nchw(
3737 IntArrayRef stride,
3838 IntArrayRef padding,
3939 bool count_include_pad,
40- int64_t divisor ,
40+ optional< int64_t > divisor_override ,
4141 int leading_dims,
4242 int ih,
4343 int iw,
@@ -50,6 +50,10 @@ void avg_pool2d_nchw(
5050 int p0 = padding[0 ];
5151 int p1 = padding[1 ];
5252
53+ bool is_fixed_divisor = divisor_override.has_value () || count_include_pad;
54+ float fixed_inv_divisor =
55+ 1 .f / (divisor_override.has_value () ? divisor_override.value () : kh * kw);
56+
5357 for (int _n = 0 ; _n < leading_dims; ++_n) {
5458 for (int _ih = 0 , _oh = 0 ; _oh < oh; ++_oh, _ih += s0) {
5559 int input_offset = _n * ih * iw;
@@ -70,9 +74,7 @@ void avg_pool2d_nchw(
7074 acc += in_data[input_addr];
7175 }
7276 }
73- // The divisor changes depending on whether the count includes
74- // padded cells or not.
75- float inv_divisor = 1 . / (count_include_pad ? divisor : count);
77+ float inv_divisor = is_fixed_divisor ? fixed_inv_divisor : 1 .f / count;
7678 float val = acc * inv_divisor;
7779 if (quantized) {
7880 int32_t min_val =
@@ -105,10 +107,6 @@ Tensor& avg_pool2d_out(
105107 const int32_t in_zero_point = in_zero_point_t .has_value ()
106108 ? in_zero_point_t .value ().const_data_ptr <int32_t >()[0 ]
107109 : 0 ;
108- const int64_t divisor = divisor_override.has_value ()
109- ? divisor_override.value ()
110- : kernel_size[0 ] * kernel_size[1 ];
111-
112110 const int odim = out.dim ();
113111 const int on = getLeadingDims (out, odim - 2 );
114112 const int oh = out.size (odim - 2 );
@@ -128,7 +126,7 @@ Tensor& avg_pool2d_out(
128126 stride, \
129127 padding, \
130128 count_include_pad, \
131- divisor, \
129+ divisor_override, \
132130 on, \
133131 ih, \
134132 iw, \
0 commit comments