Skip to content

Commit 3795941

Browse files
gunes-arm authored and morgolock committed
fix: Remove epsilon from the comparisons in NE/CPPTopKV
The comparisons in the CPP and NE TopKV functions no longer use a type-dependent epsilon; a simple comparison is used instead. This is categorized as a fix because it is a very minor behavioural change in some edge-case numerical situations. Partially Resolves: MLCE-1719 Change-Id: Ic0d02f6df05ce45886bcee458c2bcd578a388ee2 Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
1 parent e7ed1af commit 3795941

File tree

9 files changed

+49
-69
lines changed

9 files changed

+49
-69
lines changed

src/core/CPP/kernels/CPPTopKVKernel.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2020 Arm Limited.
2+
* Copyright (c) 2019-2020, 2026 Arm Limited.
33
*
44
* SPDX-License-Identifier: MIT
55
*
@@ -34,14 +34,7 @@ namespace arm_compute
3434
{
3535
namespace
3636
{
37-
template <typename T, typename std::enable_if<utils::traits::is_floating_point<T>::value, int>::type = 0>
38-
inline bool greater_than(T a, T b)
39-
{
40-
const T epsilon = std::numeric_limits<T>::epsilon();
41-
return (a - b > epsilon);
42-
}
43-
44-
template <typename T, typename std::enable_if<!utils::traits::is_floating_point<T>::value, int>::type = 0>
37+
template <typename T>
4538
inline bool greater_than(T a, T b)
4639
{
4740
return (a > b);

src/core/NEON/wrapper/intrinsics/shr.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022 Arm Limited.
2+
* Copyright (c) 2022, 2026 Arm Limited.
33
*
44
* SPDX-License-Identifier: MIT
55
*
@@ -22,8 +22,8 @@
2222
* SOFTWARE.
2323
*/
2424

25-
#ifndef ARM_COMPUTE_WRAPPER_SHR_H
26-
#define ARM_COMPUTE_WRAPPER_SHR_H
25+
#ifndef ACL_SRC_CORE_NEON_WRAPPER_INTRINSICS_SHR_H
26+
#define ACL_SRC_CORE_NEON_WRAPPER_INTRINSICS_SHR_H
2727

2828
#include <arm_neon.h>
2929
#include <type_traits>
@@ -103,6 +103,8 @@ VSHR_IMPL(int8x8_t, vshr_n, s8)
103103
{ \
104104
return prefix##_##postfix(a, b); \
105105
}
106+
VSHRQ_IMPL(uint16x8_t, vshrq_n, u16)
107+
VSHRQ_IMPL(int16x8_t, vshrq_n, s16)
106108
VSHRQ_IMPL(uint32x4_t, vshrq_n, u32)
107109
VSHRQ_IMPL(int32x4_t, vshrq_n, s32)
108110
#undef VSHRQ_IMPL
@@ -145,4 +147,4 @@ VQRSHRN_EX_SCALAR_IMPL(int32_t, int64_t, vqrshrnd_n, vqrshrund_n, s64)
145147

146148
} // namespace wrapper
147149
} // namespace arm_compute
148-
#endif /* ARM_COMPUTE_WRAPPER_SHR_H */
150+
#endif // ACL_SRC_CORE_NEON_WRAPPER_INTRINSICS_SHR_H

src/cpu/kernels/topkv/generic/neon/fp16.cpp

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
*/
2424
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
2525

26+
#include "src/core/NEON/wrapper/wrapper.h"
2627
#include "src/cpu/kernels/topkv/generic/neon/impl.h"
2728

2829
#include <arm_neon.h>
@@ -33,33 +34,24 @@ namespace cpu
3334
{
3435
namespace detail
3536
{
36-
static inline uint32_t reduce_u32x4(uint32x4_t v)
37+
static inline uint32_t reduce_u16x8(uint16x8_t v)
3738
{
3839
#if defined(__aarch64__)
39-
return vaddvq_u32(v);
40+
return vaddvq_u16(v);
4041
#else
41-
uint32x2_t s = vadd_u32(vget_low_u32(v), vget_high_u32(v));
42-
s = vpadd_u32(s, s);
43-
return vget_lane_u32(s, 0);
42+
uint16x4_t s = vadd_u16(vget_low_u16(v), vget_high_u16(v));
43+
s = vpadd_u16(s, s);
44+
return vget_lane_u16(s, 0);
4445
#endif
4546
}
4647

47-
// Explicit specialization for float16_t
4848
template <>
49-
uint32_t count_gt_block<float16_t>(const float16_t *ptr, float16_t threshold)
49+
uint32_t count_gt_block<float16_t>(const float16_t *ptr, float16x8_t thr_vec)
5050
{
51-
// Load 8 fp16
52-
const float16x8_t v16 = vld1q_f16(reinterpret_cast<const __fp16 *>(ptr));
53-
54-
const float32x4_t thr = vdupq_n_f32(threshold);
55-
56-
const float32x4_t v0 = vcvt_f32_f16(vget_low_f16(v16));
57-
const float32x4_t v1 = vcvt_f32_f16(vget_high_f16(v16));
58-
59-
const uint32x4_t b0 = vshrq_n_u32(vcgtq_f32(v0, thr), 31);
60-
const uint32x4_t b1 = vshrq_n_u32(vcgtq_f32(v1, thr), 31);
61-
62-
return reduce_u32x4(b0) + reduce_u32x4(b1);
51+
const float16x8_t v = wrapper::vloadq(ptr);
52+
const uint16x8_t mask = wrapper::vcgt(v, thr_vec); // new: v > (threshold)
53+
const uint16x8_t b = wrapper::vshrq_n<15>(mask);
54+
return reduce_u16x8(b);
6355
}
6456

6557
} // namespace detail

src/cpu/kernels/topkv/generic/neon/fp32.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,11 @@ static inline uint32_t reduce_u32x4(uint32x4_t v)
4747

4848
// Explicit specialization for float: may use float32x4_t
4949
template <>
50-
uint32_t count_gt_block<float>(const float *ptr, float threshold)
50+
uint32_t count_gt_block<float>(const float *ptr, float32x4_t thr_vec)
5151
{
52-
using Tag = wrapper::traits::neon_bitvector_tag_t<float, wrapper::traits::BitWidth::W128>;
53-
54-
const auto v = wrapper::vloadq(ptr);
55-
56-
// epsilon-aware compare: treat a > b only when (a - b) > epsilon
57-
const float eps_val = std::numeric_limits<float>::epsilon();
58-
const float thr_with_eps = threshold + eps_val;
59-
const auto thr_eps_vec = wrapper::vdup_n(thr_with_eps, Tag{});
60-
const auto mask = wrapper::vcgt(v, thr_eps_vec); // new: v > (threshold + eps)
61-
62-
const uint32x4_t m = mask;
63-
const uint32x4_t b = vshrq_n_u32(m, 31);
52+
const float32x4_t v = wrapper::vloadq(ptr);
53+
const uint32x4_t mask = wrapper::vcgt(v, thr_vec); // new: v > (threshold)
54+
const uint32x4_t b = wrapper::vshrq_n<31>(mask);
6455
return reduce_u32x4(b);
6556
}
6657
} // namespace detail

src/cpu/kernels/topkv/generic/neon/impl.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "arm_compute/core/Helpers.h"
2828
#include "arm_compute/core/Window.h"
2929

30+
#include "src/core/NEON/wrapper/wrapper.h"
31+
3032
#include <limits>
3133
#include <type_traits>
3234

@@ -36,13 +38,18 @@ namespace cpu
3638
{
3739
namespace detail
3840
{
39-
template <typename ScalarType>
40-
uint32_t count_gt_block(const ScalarType *ptr, ScalarType threshold);
41+
template <typename ScalarType, typename VectorType>
42+
uint32_t count_gt_block(const ScalarType *ptr, VectorType threshold);
4143

4244
template <typename ScalarType>
4345
void topkv_neon_wrapper(
4446
const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &window)
4547
{
48+
constexpr auto bit_width = wrapper::traits::BitWidth::W128;
49+
50+
using TagType = typename wrapper::traits::neon_bitvector_tag_t<ScalarType, bit_width>;
51+
using VectorType = typename wrapper::traits::neon_bitvector_t<ScalarType, bit_width>;
52+
4653
const auto &pred_info = *predictions->info();
4754
const unsigned int C = pred_info.tensor_shape()[0];
4855

@@ -66,7 +73,8 @@ void topkv_neon_wrapper(
6673
const ScalarType *base =
6774
reinterpret_cast<const ScalarType *>(predictions->ptr_to_element(Coordinates{0, n}));
6875

69-
const ScalarType thr = base[t];
76+
const ScalarType thr = base[t];
77+
const VectorType thr_vec = wrapper::vdup_n(thr, TagType{});
7078

7179
uint32_t rank = 0;
7280
unsigned int c = 0;
@@ -75,7 +83,7 @@ void topkv_neon_wrapper(
7583
// Vector loop with early-exit
7684
for (; c + vec_elems <= C; c += vec_elems)
7785
{
78-
rank += count_gt_block<ScalarType>(base + c, thr);
86+
rank += count_gt_block<ScalarType>(base + c, thr_vec);
7987
if (rank >= k)
8088
{
8189
// For large C and small K (e.g. QASYMM8, C=32000, K=3), the probability that the

src/cpu/kernels/topkv/generic/neon/integer.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ static inline uint32_t reduce_u32x4(uint32x4_t v)
4343
}
4444

4545
template <>
46-
uint32_t count_gt_block<int32_t>(const int32_t *ptr, int32_t threshold)
46+
uint32_t count_gt_block<int32_t>(const int32_t *ptr, int32x4_t thr_vec)
4747
{
48-
const int32x4_t v = vld1q_s32(ptr);
49-
const int32x4_t thr = vdupq_n_s32(threshold);
50-
const uint32x4_t m = vcgtq_s32(v, thr); // 0xFFFFFFFF / 0 per lane
51-
const uint32x4_t b = vshrq_n_u32(m, 31); // 0/1 per lane
48+
const int32x4_t v = vld1q_s32(ptr);
49+
const uint32x4_t m = vcgtq_s32(v, thr_vec); // 0xFFFFFFFF / 0 per lane
50+
const uint32x4_t b = vshrq_n_u32(m, 31); // 0/1 per lane
5251
return reduce_u32x4(b);
5352
}
5453

src/cpu/kernels/topkv/generic/neon/qasymm8.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ static inline uint32_t reduce_u8_to_count(uint8x16_t m)
4848
}
4949

5050
template <>
51-
uint32_t count_gt_block<uint8_t>(const uint8_t *ptr, uint8_t threshold)
51+
uint32_t count_gt_block<uint8_t>(const uint8_t *ptr, uint8x16_t thr_vec)
5252
{
53-
const uint8x16_t v = vld1q_u8(ptr);
54-
const uint8x16_t thr = vdupq_n_u8(threshold);
55-
const uint8x16_t m = vcgtq_u8(v, thr); // 0xFF / 0x00 bytes
53+
const uint8x16_t v = vld1q_u8(ptr);
54+
const uint8x16_t m = vcgtq_u8(v, thr_vec); // 0xFF / 0x00 bytes
5655
return reduce_u8_to_count(m);
5756
}
5857

src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,10 @@ static inline uint32_t reduce_u8_to_count(uint8x16_t m)
4646
}
4747

4848
template <>
49-
uint32_t count_gt_block<int8_t>(const int8_t *ptr, int8_t threshold)
49+
uint32_t count_gt_block<int8_t>(const int8_t *ptr, int8x16_t thr_vec)
5050
{
51-
const int8x16_t v = vld1q_s8(ptr);
52-
const int8x16_t thr = vdupq_n_s8(threshold);
53-
const uint8x16_t m = vcgtq_s8(v, thr); // returns uint8x16_t mask
51+
const int8x16_t v = vld1q_s8(ptr);
52+
const uint8x16_t m = vcgtq_s8(v, thr_vec); // returns uint8x16_t mask
5453
return reduce_u8_to_count(m);
5554
}
5655

tests/validation/reference/TopKV.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,21 @@ SimpleTensor<uint8_t> topkv(SimpleTensor<T> &predictions, SimpleTensor<uint32_t>
5050

5151
SimpleTensor<uint8_t> expected(TensorShape(N), DataType::U8);
5252

53-
const float eps = std::numeric_limits<float>::epsilon();
54-
5553
for (int i = 0; i < N; ++i)
5654
{
5755
// targets[i] (U32)
5856
const uint32_t target_class = targets[i];
5957

6058
// Read predictions[target_class, i] as T, then promote to float
61-
const T target_t = *reinterpret_cast<const T *>(predictions(Coordinates{target_class, i}));
62-
const float target_val = static_cast<float>(target_t);
59+
const T target = *reinterpret_cast<const T *>(predictions(Coordinates{target_class, i}));
60+
const T threshold = target;
6361

6462
unsigned int rank = 0;
6563
for (int c = 0; c < C; ++c)
6664
{
67-
const T vt = *reinterpret_cast<const T *>(predictions(Coordinates{c, i}));
68-
const float v = static_cast<float>(vt);
65+
const T v = *reinterpret_cast<const T *>(predictions(Coordinates{c, i}));
6966

70-
if ((v - target_val) > eps)
67+
if (v > threshold)
7168
{
7269
++rank;
7370
}

0 commit comments

Comments (0)