Skip to content

Commit d52e297

Browse files
committed
fix: Lack of epsilon-aware comparison in NETopKV for FP32 data type.
Align NETopKV FP32 comparison semantics with the scalar reference. Resolves COMPMID-8829. Change-Id: I995a07e0f38587b69b0ac08ee2c85949704f0e60. Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
1 parent 6305151 commit d52e297

File tree

1 file changed

+9
-5
lines changed
  • src/cpu/kernels/topkv/generic/neon

1 file changed

+9
-5
lines changed

src/cpu/kernels/topkv/generic/neon/fp32.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "src/cpu/kernels/topkv/generic/neon/impl.h"
2727

2828
#include <arm_neon.h>
29+
#include <limits>
2930

3031
namespace arm_compute
3132
{
@@ -44,21 +45,24 @@ static inline uint32_t reduce_u32x4(uint32x4_t v)
4445
#endif
4546
}
4647

47-
// Explicit specialization for float: may use float32x4_t etc (only in this TU)
48+
// Explicit specialization for float: may use float32x4_t
4849
template <>
4950
uint32_t count_gt_block<float>(const float *ptr, float threshold)
5051
{
5152
using Tag = wrapper::traits::neon_bitvector_tag_t<float, wrapper::traits::BitWidth::W128>;
5253

53-
const auto thr_vec = wrapper::vdup_n(threshold, Tag{});
54-
const auto v = wrapper::vloadq(ptr);
55-
const auto mask = wrapper::vcgt(v, thr_vec); // underlying uint32x4_t
54+
const auto v = wrapper::vloadq(ptr);
55+
56+
// epsilon-aware compare: treat a > b only when (a - b) > epsilon
57+
const float eps_val = std::numeric_limits<float>::epsilon();
58+
const float thr_with_eps = threshold + eps_val;
59+
const auto thr_eps_vec = wrapper::vdup_n(thr_with_eps, Tag{});
60+
const auto mask = wrapper::vcgt(v, thr_eps_vec); // new: v > (threshold + eps)
5661

5762
const uint32x4_t m = mask;
5863
const uint32x4_t b = vshrq_n_u32(m, 31);
5964
return reduce_u32x4(b);
6065
}
61-
6266
} // namespace detail
6367

6468
void topkv_fp32_neon(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)

0 commit comments

Comments
 (0)