// NEON-optimized sum.IntList_out for ExecuTorch.
//
// The portable kernel hands everything to a generic parallel_for + MapReduce
// plan, which the compiler may or may not auto-vectorize. This kernel handles
// the common single-dim, same-dtype, contiguous cases with explicit NEON
// intrinsics, uses a scalar loop for single-dim dtypes it has no NEON path
// for, and falls back to the portable-style MapReduce plan for everything
// else (multi-dim reductions, dtype conversion, complex types).
//
// The signature matches kernels/portable/cpu/op_sum.cpp::sum_dim_out.
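//
// A minimal registration sketch (assumed wiring; the exact YAML file and
// kernel_name depend on how this repo runs the ExecuTorch kernel codegen):
//
//   - op: sum.IntList_out
//     kernels:
//       - arg_meta: null
//         kernel_name: custom::sum_dim_out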

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h>

#ifdef __aarch64__
#include <arm_neon.h>
#endif

#include <cmath>
#include <cstring>
#include <optional>

namespace custom {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::optional;
using torch::executor::KernelRuntimeContext;

namespace {

// Normalize a possibly-negative dim into [0, ndim).
inline int64_t normalize_dim(int64_t dim, int64_t ndim) {
  return dim < 0 ? dim + ndim : dim;
}

// Compute (outer_size, reduce_size, inner_size) around a single reduction dim.
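// E.g. (hypothetical shapes) an input of shape [2, 3, 4] reduced over dim=1
// gives outer_size = 2, reduce_size = 3, inner_size = 4: the output has
// 2 * 4 = 8 elements, each the sum of 3 inputs.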
inline void compute_partition_sizes(
    const Tensor& in,
    int64_t dim,
    int64_t& outer_size,
    int64_t& reduce_size,
    int64_t& inner_size) {
  outer_size = 1;
  reduce_size = in.size(dim);
  inner_size = 1;
  for (int64_t d = 0; d < dim; ++d) {
    outer_size *= in.size(d);
  }
  for (int64_t d = dim + 1; d < in.dim(); ++d) {
    inner_size *= in.size(d);
  }
}

#ifdef __aarch64__

// Sum along the innermost (contiguous) dim: one reduction per outer row.
// in  : shape [outer_size, reduce_size]
// out : shape [outer_size]
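// E.g. (hypothetical sizes) a [128, 1024] float input reduced over dim 1
// becomes 128 independent sums, each over 1024 contiguous floats.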
template <typename T>
inline void sum_innermost(
    const T* in, T* out, int64_t outer_size, int64_t reduce_size);

template <>
inline void sum_innermost<float>(
    const float* in, float* out, int64_t outer_size, int64_t reduce_size) {
  for (int64_t i = 0; i < outer_size; ++i) {
    const float* row = in + i * reduce_size;
    float32x4_t acc = vdupq_n_f32(0.0f);
    int64_t j = 0;
    for (; j + 15 < reduce_size; j += 16) {
      acc = vaddq_f32(acc, vld1q_f32(row + j));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 4));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 8));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 12));
    }
    for (; j + 3 < reduce_size; j += 4) {
      acc = vaddq_f32(acc, vld1q_f32(row + j));
    }
    float sum = vaddvq_f32(acc);
    for (; j < reduce_size; ++j) {
      sum += row[j];
    }
    out[i] = sum;
  }
}

template <>
inline void sum_innermost<__fp16>(
    const __fp16* in, __fp16* out, int64_t outer_size, int64_t reduce_size) {
  // Accumulate in fp32 for precision; long reductions in pure fp16 lose
  // significant trailing digits.
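  // (fp16 has an 11-bit significand, so e.g. 2048.0 + 1.0 rounds back to
  // 2048.0; a long running sum of small values simply stops growing.)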
  for (int64_t i = 0; i < outer_size; ++i) {
    const __fp16* row = in + i * reduce_size;
    float32x4_t acc = vdupq_n_f32(0.0f);
    int64_t j = 0;
    for (; j + 7 < reduce_size; j += 8) {
      float16x8_t v = vld1q_f16(row + j);
      acc = vaddq_f32(acc, vcvt_f32_f16(vget_low_f16(v)));
      acc = vaddq_f32(acc, vcvt_f32_f16(vget_high_f16(v)));
    }
    float sum = vaddvq_f32(acc);
    for (; j < reduce_size; ++j) {
      sum += static_cast<float>(row[j]);
    }
    out[i] = static_cast<__fp16>(sum);
  }
}

// Sum along a non-innermost dim: reduce over `reduce_size` elements spaced
// `inner_size` apart, for each (outer, inner) pair.
// in  : shape [outer_size, reduce_size, inner_size]
// out : shape [outer_size, inner_size]
// Vectorized over the contiguous `inner` axis.
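// E.g. (hypothetical shapes) summing a [2, 3, 4] input over dim=1 computes
// out[o][j] = in[o][0][j] + in[o][1][j] + in[o][2][j]; each NEON iteration
// handles four (fp32) or eight (fp16) adjacent j positions at once.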
template <typename T>
inline void sum_strided(
    const T* in,
    T* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size);

template <>
inline void sum_strided<float>(
    const float* in,
    float* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const float* in_o = in + o * outer_stride;
    float* out_o = out + o * inner_size;
    int64_t j = 0;
    for (; j + 3 < inner_size; j += 4) {
      float32x4_t acc = vdupq_n_f32(0.0f);
      for (int64_t k = 0; k < reduce_size; ++k) {
        acc = vaddq_f32(acc, vld1q_f32(in_o + k * reduce_stride + j));
      }
      vst1q_f32(out_o + j, acc);
    }
    for (; j < inner_size; ++j) {
      float sum = 0.0f;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += in_o[k * reduce_stride + j];
      }
      out_o[j] = sum;
    }
  }
}

template <>
inline void sum_strided<__fp16>(
    const __fp16* in,
    __fp16* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const __fp16* in_o = in + o * outer_stride;
    __fp16* out_o = out + o * inner_size;
    int64_t j = 0;
    // Accumulate in fp32 for precision, 8 contiguous inner positions per iter.
    for (; j + 7 < inner_size; j += 8) {
      float32x4_t acc_lo = vdupq_n_f32(0.0f);
      float32x4_t acc_hi = vdupq_n_f32(0.0f);
      for (int64_t k = 0; k < reduce_size; ++k) {
        float16x8_t v = vld1q_f16(in_o + k * reduce_stride + j);
        acc_lo = vaddq_f32(acc_lo, vcvt_f32_f16(vget_low_f16(v)));
        acc_hi = vaddq_f32(acc_hi, vcvt_f32_f16(vget_high_f16(v)));
      }
      vst1q_f16(
          out_o + j,
          vcombine_f16(vcvt_f16_f32(acc_lo), vcvt_f16_f32(acc_hi)));
    }
    for (; j < inner_size; ++j) {
      float sum = 0.0f;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += static_cast<float>(in_o[k * reduce_stride + j]);
      }
      out_o[j] = static_cast<__fp16>(sum);
    }
  }
}

#endif // __aarch64__

// Scalar fallback: equivalent semantics to the NEON paths, used on non-arm64
// hosts and for dtypes we don't special-case.
template <typename T_IN, typename T_OUT>
inline void sum_scalar(
    const T_IN* in,
    T_OUT* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const T_IN* in_o = in + o * outer_stride;
    T_OUT* out_o = out + o * inner_size;
    for (int64_t j = 0; j < inner_size; ++j) {
      double sum = 0.0;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += static_cast<double>(in_o[k * reduce_stride + j]);
      }
      out_o[j] = static_cast<T_OUT>(sum);
    }
  }
}

} // namespace

Tensor& sum_dim_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::check_reduction_args(in, dim_list, keepdim, dtype, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::resize_reduction_out(in, dim_list, keepdim, out) ==
          torch::executor::Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::tensors_have_same_dim_order(in, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::tensor_is_default_dim_order(in),
      InvalidArgument,
      out);

  if (in.numel() == 0) {
    if (out.numel() > 0) {
      std::memset(out.mutable_data_ptr(), 0, out.nbytes());
    }
    return out;
  }

  // Fast-path only the common case: single reduction dim, matching dtype,
  // non-complex type, contiguous input. Dtypes without a NEON specialization
  // take the scalar loop below; anything that fails these checks falls
  // through to the generic MapReduce plan at the bottom of the function.
  const bool fast_eligible = dim_list.has_value() &&
      dim_list.value().size() == 1 &&
      in.scalar_type() == out.scalar_type() &&
      !executorch::runtime::isComplexType(in.scalar_type()) &&
      torch::executor::tensor_is_contiguous(in);

  if (fast_eligible) {
    const int64_t dim = normalize_dim(dim_list.value()[0], in.dim());
    int64_t outer_size, reduce_size, inner_size;
    compute_partition_sizes(in, dim, outer_size, reduce_size, inner_size);

#ifdef __aarch64__
    if (in.scalar_type() == ScalarType::Float) {
      const float* ip = in.const_data_ptr<float>();
      float* op = out.mutable_data_ptr<float>();
      if (inner_size == 1) {
        sum_innermost<float>(ip, op, outer_size, reduce_size);
      } else {
        sum_strided<float>(ip, op, outer_size, reduce_size, inner_size);
      }
      return out;
    } else if (in.scalar_type() == ScalarType::Half) {
      static_assert(sizeof(__fp16) == 2, "expected __fp16 == 2 bytes");
      const __fp16* ip =
          reinterpret_cast<const __fp16*>(in.const_data_ptr<uint16_t>());
      __fp16* op =
          reinterpret_cast<__fp16*>(out.mutable_data_ptr<uint16_t>());
      if (inner_size == 1) {
        sum_innermost<__fp16>(ip, op, outer_size, reduce_size);
      } else {
        sum_strided<__fp16>(ip, op, outer_size, reduce_size, inner_size);
      }
      return out;
    }
#endif

    // Scalar fast path for dtypes we don't have NEON for, or non-arm64 hosts.
    // @lint-ignore CLANGTIDY facebook-hte-CArray
    static constexpr const char op_name[] = "sum.IntList_out";
    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
      const CTYPE* ip = in.const_data_ptr<CTYPE>();
      CTYPE* op = out.mutable_data_ptr<CTYPE>();
      sum_scalar<CTYPE, CTYPE>(ip, op, outer_size, reduce_size, inner_size);
    });
    return out;
  }

  // Generic fallback: multi-dim reduction, dtype conversion, or any other
  // case the fast path rejects. Semantically matches the portable kernel.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "sum.IntList_out";
  std::optional<torch::executor::MapReduceOverDimListPlan> plan;
  plan.emplace(in, dim_list);
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
    ET_SWITCH_REALHBBF16_TYPES(
        out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
          CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
          const bool success =
              torch::executor::
                  parallel_for_each_reduce_over_dim_list_output_index(
                      in,
                      dim_list,
                      out,
                      [&](const auto begin, const auto end) {
                        for (const auto out_ix :
                             c10::irange(begin, end)) {
                          out_data[out_ix] =
                              plan->execute<CTYPE_IN, CTYPE_OUT>(
                                  [](CTYPE_IN v) {
                                    return static_cast<CTYPE_OUT>(v);
                                  },
                                  [](CTYPE_OUT outv, CTYPE_OUT acc) {
                                    return acc + outv;
                                  },
                                  out_ix);
                        }
                      });
          ET_KERNEL_CHECK_MSG(
              ctx, success, Internal, , "parallel_for failed");
        });
  });
  return out;
}

} // namespace native
} // namespace custom