Skip to content

Commit d1fe74d

Browse files
committed
feat: Add SVE kernels for TopKV.
Resolves MLCE-1719 Change-Id: I7a0c7bd1154b9cb7f35c7fd1c3b8ad54698f8799 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
1 parent 3795941 commit d1fe74d

File tree

11 files changed

+524
-9
lines changed

11 files changed

+524
-9
lines changed

filelist.json

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,6 +2453,7 @@
24532453
]
24542454
}
24552455
},
2456+
24562457
"TopKV": {
24572458
"files": {
24582459
"common": [
@@ -2463,15 +2464,17 @@
24632464
"neon": {
24642465
"fp16": [ "src/cpu/kernels/topkv/generic/neon/fp16.cpp" ],
24652466
"fp32": [ "src/cpu/kernels/topkv/generic/neon/fp32.cpp" ],
2466-
"integer":["src/cpu/kernels/topkv/generic/neon/integer.cpp"],
2467-
"qasymm8": [
2468-
"src/cpu/kernels/topkv/generic/neon/qasymm8.cpp"
2469-
],
2470-
"qasymm8_signed": [
2471-
"src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp"
2472-
]
2467+
"integer": [ "src/cpu/kernels/topkv/generic/neon/integer.cpp" ],
2468+
"qasymm8": [ "src/cpu/kernels/topkv/generic/neon/qasymm8.cpp" ],
2469+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp" ]
2470+
},
2471+
"sve": {
2472+
"fp32": [ "src/cpu/kernels/topkv/generic/sve/fp32.cpp" ],
2473+
"fp16": [ "src/cpu/kernels/topkv/generic/sve/fp16.cpp" ],
2474+
"integer": [ "src/cpu/kernels/topkv/generic/sve/integer.cpp" ],
2475+
"qasymm8": [ "src/cpu/kernels/topkv/generic/sve/qasymm8.cpp" ],
2476+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp" ]
24732477
}
2474-
24752478
}
24762479
},
24772480
"Transpose": {

src/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,12 @@ filegroup(
395395
"cpu/kernels/scale/sve/qasymm8.cpp",
396396
"cpu/kernels/scale/sve/qasymm8_signed.cpp",
397397
"cpu/kernels/softmax/generic/sve/impl.cpp",
398-
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp"] +
398+
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp",
399+
"cpu/kernels/topkv/generic/sve/fp16.cpp",
400+
"cpu/kernels/topkv/generic/sve/fp32.cpp",
401+
"cpu/kernels/topkv/generic/sve/integer.cpp",
402+
"cpu/kernels/topkv/generic/sve/qasymm8.cpp",
403+
"cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp"] +
399404
glob(["**/*.h",
400405
"**/*.hpp",
401406
"**/*.inl"]),

src/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ target_sources(
365365
cpu/kernels/scale/sve/qasymm8_signed.cpp
366366
cpu/kernels/softmax/generic/sve/impl.cpp
367367
cpu/kernels/softmax/generic/sve/impl_bf16.cpp
368+
cpu/kernels/topkv/generic/sve/fp16.cpp
369+
cpu/kernels/topkv/generic/sve/fp32.cpp
370+
cpu/kernels/topkv/generic/sve/integer.cpp
371+
cpu/kernels/topkv/generic/sve/qasymm8.cpp
372+
cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp
368373
)
369374

370375
target_sources(

src/cpu/kernels/CpuTopKVKernel.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,42 @@ namespace
4343
{
4444

4545
static const std::vector<CpuTopKVKernel::TopKVKernel> available_kernels = {
46+
47+
{"sve_fp16_topkv",
48+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
49+
{ return (data.dt == DataType::F16) && data.isa.fp16 && data.isa.sve; },
50+
REGISTER_FP16_SVE(arm_compute::cpu::topkv_fp16_sve)},
51+
52+
{"sve_fp32_topkv",
53+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32) && data.isa.sve; },
54+
REGISTER_FP32_SVE(arm_compute::cpu::topkv_fp32_sve)},
55+
56+
{"sve_qasymm8_topkv",
57+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8) && data.isa.sve; },
58+
REGISTER_QASYMM8_SVE(arm_compute::cpu::topkv_qasymm8_sve)},
59+
60+
{"sve_qasymm8_signed_topkv",
61+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
62+
{ return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve; },
63+
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::topkv_qasymm8_signed_sve)},
64+
65+
{"sve_s32_topkv",
66+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32) && data.isa.sve; },
67+
REGISTER_INTEGER_SVE(arm_compute::cpu::topkv_s32_sve)},
68+
4669
{"neon_s32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32); },
4770
REGISTER_INTEGER_NEON(arm_compute::cpu::topkv_s32_neon)},
71+
4872
{"neon_fp32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32); },
4973
REGISTER_FP32_NEON(arm_compute::cpu::topkv_fp32_neon)},
74+
5075
{"neon_fp16_topkv",
5176
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F16) && data.isa.fp16; },
5277
REGISTER_FP16_NEON(arm_compute::cpu::topkv_fp16_neon)},
78+
5379
{"neon_qu8_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8); },
5480
REGISTER_QASYMM8_NEON(arm_compute::cpu::topkv_qasymm8_neon)},
81+
5582
{"neon_qs8_topkv",
5683
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8_SIGNED); },
5784
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::topkv_qasymm8_signed_neon)}};
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
#include <cstdint>
30+
31+
namespace arm_compute
32+
{
33+
namespace cpu
34+
{
35+
namespace detail
36+
{
37+
38+
template <>
39+
inline uint32_t vector_length<float16_t>()
40+
{
41+
return static_cast<uint32_t>(svcnth());
42+
}
43+
44+
template <>
45+
inline uint32_t count_gt_block<float16_t>(const float16_t *ptr, svfloat16_t thr_vec, uint32_t block_elems)
46+
{
47+
const svbool_t pg = svwhilelt_b16(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems));
48+
const svfloat16_t v = svld1_f16(pg, ptr);
49+
const svbool_t gt = svcmpgt_f16(pg, v, thr_vec);
50+
51+
return static_cast<uint32_t>(svcntp_b16(svptrue_b16(), gt));
52+
}
53+
} // namespace detail
54+
55+
void topkv_fp16_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
56+
{
57+
detail::topkv_sve_wrapper<float16_t>(predictions, targets, out, k, win);
58+
}
59+
60+
// Force instantiation into this TU
61+
template void
62+
detail::topkv_sve_wrapper<float16_t>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
63+
64+
} // namespace cpu
65+
} // namespace arm_compute
66+
67+
#endif // __ARM_FEATURE_SVE
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
#include <cstdint>
30+
31+
namespace arm_compute
32+
{
33+
namespace cpu
34+
{
35+
namespace detail
36+
{
37+
38+
template <>
39+
inline uint32_t vector_length<float>()
40+
{
41+
return static_cast<uint32_t>(svcntw());
42+
}
43+
44+
template <>
45+
inline uint32_t count_gt_block<float>(const float *ptr, svfloat32_t thr_vec, uint32_t block_elems)
46+
{
47+
const svbool_t pg = svwhilelt_b32(static_cast<uint64_t>(0), static_cast<uint64_t>(block_elems));
48+
const svfloat32_t v = svld1_f32(pg, ptr);
49+
const svbool_t gt = svcmpgt_f32(pg, v, thr_vec);
50+
return static_cast<uint32_t>(svcntp_b32(svptrue_b32(), gt));
51+
}
52+
53+
} // namespace detail
54+
55+
void topkv_fp32_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
56+
{
57+
detail::topkv_sve_wrapper<float>(predictions, targets, out, k, win);
58+
}
59+
60+
// Force instantiation into this TU
61+
template void detail::topkv_sve_wrapper<float>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
62+
63+
} // namespace cpu
64+
} // namespace arm_compute
65+
66+
#endif // __ARM_FEATURE_SVE
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#ifndef ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
25+
#define ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
26+
27+
#include "arm_compute/core/Coordinates.h"
28+
#include "arm_compute/core/Error.h"
29+
#include "arm_compute/core/Helpers.h"
30+
#include "arm_compute/core/ITensor.h"
31+
#include "arm_compute/core/Types.h"
32+
#include "arm_compute/core/Window.h"
33+
34+
#include "src/core/NEON/wrapper/wrapper.h"
35+
36+
#include <cstdint>
37+
#include <cstring>
38+
39+
#if defined(__ARM_FEATURE_SVE)
40+
#include <arm_sve.h>
41+
#endif // defined(__ARM_FEATURE_SVE)
42+
43+
namespace arm_compute
44+
{
45+
namespace cpu
46+
{
47+
namespace detail
48+
{
49+
50+
/*
51+
* Type-specific hooks (declared here, defined in each cpp).
52+
*
53+
* - vector_length<Scalar>()
54+
* Return the SVE vector length in elements for Scalar (no clamping).
55+
*
56+
* - count_gt_block<Scalar>(ptr, thr_vec, block_elems)
57+
* Count how many elements in [ptr, ptr + block_elems) are > threshold.
58+
* Tail-safe via predicate. block_elems is always <= vector_length<Scalar>().
59+
*
60+
 * Each hook is specialized in the type-specific .cpp file that contains the SVE intrinsics
61+
* (e.g., qasymm8.cpp, qasymm8_signed.cpp, fp16.cpp, fp32.cpp, integer.cpp).
62+
*/
63+
64+
// Hook: SVE vector length expressed in elements of type Scalar.
// Only declared here; each element type provides its own specialization.
template <typename Scalar>
uint32_t vector_length();

// Hook: number of elements in [ptr, ptr + block_elems) that are greater than the threshold held in thr_vec.
// Only declared here; each element type provides its own specialization.
template <typename Scalar, typename ThresholdVector>
uint32_t count_gt_block(const Scalar *ptr, ThresholdVector thr_vec, uint32_t block_elems);
69+
70+
// ----------------------------------------------------------------------------
71+
// Generic wrapper (type-agnostic) - uses the above hooks.
72+
// Semantics (matching TopKV tests you showed):
73+
// - predictions is N x C
74+
// - window iterates across output elements (classes) => id.x() == class index c
75+
// - for each class c, targets[c] gives the sample index t
76+
// - scan across N samples and compute rank (#samples with value > predictions[t])
77+
// - output is U8 boolean: (rank < k)
78+
// ----------------------------------------------------------------------------
79+
template <typename Scalar>
80+
inline void
81+
topkv_sve_wrapper(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &window)
82+
{
83+
ARM_COMPUTE_ERROR_ON_NULLPTR(predictions, targets, out);
84+
ARM_COMPUTE_ERROR_ON(k == 0);
85+
86+
const ITensorInfo *pred_info = predictions->info();
87+
const uint32_t N = pred_info->dimension(0); // samples
88+
const uint32_t C = pred_info->dimension(1); // classes
89+
90+
const uint32_t vl = vector_length<Scalar>(); // cache once per kernel invocation
91+
92+
Iterator tgt_it(targets, window);
93+
Iterator out_it(out, window);
94+
95+
execute_window_loop(
96+
window,
97+
[&](const Coordinates &id)
98+
{
99+
const uint32_t c = static_cast<uint32_t>(id.x()); // class index
100+
ARM_COMPUTE_ERROR_ON(c >= C);
101+
102+
uint32_t t = {*reinterpret_cast<uint32_t *>(tgt_it.ptr())};
103+
ARM_COMPUTE_ERROR_ON(t >= N);
104+
105+
const Scalar *col_ptr = reinterpret_cast<const Scalar *>(predictions->ptr_to_element(Coordinates(0, c)));
106+
ARM_COMPUTE_ERROR_ON(col_ptr == nullptr);
107+
108+
const Scalar thr = col_ptr[t];
109+
const auto thr_vec = wrapper::svdup_n(thr);
110+
111+
uint32_t rank = 0;
112+
uint32_t idx = 0;
113+
114+
while (idx < N)
115+
{
116+
const uint32_t remaining = N - idx;
117+
const uint32_t bw = (remaining < vl) ? remaining : vl;
118+
119+
rank += count_gt_block<Scalar>(col_ptr + idx, thr_vec, bw);
120+
121+
if (rank >= k)
122+
{
123+
break;
124+
}
125+
126+
idx += bw;
127+
}
128+
129+
*reinterpret_cast<uint8_t *>(out_it.ptr()) = static_cast<uint8_t>(rank < k);
130+
},
131+
tgt_it, out_it);
132+
}
133+
134+
} // namespace detail
135+
} // namespace cpu
136+
} // namespace arm_compute
137+
138+
#endif // ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H

0 commit comments

Comments
 (0)