@@ -406,16 +406,15 @@ struct MediumKernelPlanSize {
406406 }
407407};
408408
409- template <int kUnpackedPerKernelLimit_ >
410409struct MediumKernelOptions {
411410 /// An indicative limit on the number of values unpacked by the kernel.
412411 /// This is a heuristic setting: other constraints such as alignment may not always make
413412 /// small values feasible. Must be a power of two.
414- static constexpr int kUnpackedPerKernelLimit = kUnpackedPerKernelLimit_ ;
413+ int unpacked_per_kernel_limit_ ;
415414};
416415
417- template < typename KernelOptions>
418- constexpr MediumKernelPlanSize BuildMediumPlanSize ( const KernelShape& shape ) {
416+ constexpr MediumKernelPlanSize BuildMediumPlanSize ( const KernelShape& shape,
417+ const MediumKernelOptions& options ) {
419418 const int shifts_per_swizzle =
420419 shape.unpacked_byte_size () / shape.packed_max_spread_bytes ();
421420
@@ -424,7 +423,7 @@ constexpr MediumKernelPlanSize BuildMediumPlanSize(const KernelShape& shape) {
424423 // Using `unpacked_per_kernel_limit` to influence the number of swizzles per read.
425424 const auto packed_per_read_for_offset = [&](int bit_offset) -> int {
426425 const int best = (shape.simd_bit_size () - bit_offset) / shape.packed_bit_size ();
427- const int limit = KernelOptions:: kUnpackedPerKernelLimit ;
426+ const int limit = options. unpacked_per_kernel_limit_ ;
428427 return (best > limit) && (limit > 0 ) ? limit : best;
429428 };
430429
@@ -472,12 +471,12 @@ constexpr int reduced_bytes_per_read(int bits_per_read, int simd_byte_size) {
472471 return simd_byte_size;
473472}
474473
475- template <typename KernelTraits, typename KernelOptions >
474+ template <typename KernelTraits, MediumKernelOptions kOptions >
476475struct MediumKernelPlan {
477476 using Traits = KernelTraits;
478477 using uint_type = typename Traits::uint_type;
479478 static constexpr auto kShape = Traits::kShape ;
480- static constexpr auto kPlanSize = BuildMediumPlanSize<KernelOptions> (kShape );
479+ static constexpr auto kPlanSize = BuildMediumPlanSize(kShape , kOptions );
481480
482481 using ReadsPerKernel = std::array<int , kPlanSize .reads_per_kernel()>;
483482
@@ -523,9 +522,9 @@ constexpr Arr BuildConstantArray(typename Arr::value_type val) {
523522 return out;
524523}
525524
526- template <typename KernelTraits, typename KernelOptions >
525+ template <typename KernelTraits, MediumKernelOptions kOptions >
527526constexpr auto BuildMediumPlan () {
528- using Plan = MediumKernelPlan<KernelTraits, KernelOptions >;
527+ using Plan = MediumKernelPlan<KernelTraits, kOptions >;
529528 constexpr auto kShape = Plan::kShape ;
530529 constexpr auto kPlanSize = Plan::kPlanSize ;
531530 static_assert (kShape .is_medium ());
@@ -584,9 +583,9 @@ xsimd::batch<uint8_t, Arch> load_bytes(const uint8_t* in) {
584583 return simd_bytes::load_unaligned (in);
585584}
586585
587- template <typename KernelTraits, typename KernelOptions = MediumKernelOptions< 32 > >
586+ template <typename KernelTraits, MediumKernelOptions kOptions >
588587struct MediumKernel {
589- static constexpr auto kPlan = BuildMediumPlan<KernelTraits, KernelOptions >();
588+ static constexpr auto kPlan = BuildMediumPlan<KernelTraits, kOptions >();
590589 static constexpr auto kPlanSize = kPlan .kPlanSize ;
591590 static constexpr auto kShape = kPlan .kShape ;
592591 using Traits = typename decltype (kPlan )::Traits;
@@ -855,16 +854,19 @@ constexpr bool LargeShouldUseUint16 =
855854// A ``std::enable_if`` that works on MSVC
856855template <typename Traits>
857856constexpr auto KernelDispatchImpl () {
857+ constexpr MediumKernelOptions kMedKernelOpts = {.unpacked_per_kernel_limit_ = 32 };
858858 if constexpr (Traits::kShape .is_medium ()) {
859859 if constexpr (MediumShouldUseUint32<Traits>) {
860- using Kernel32 = MediumKernel<KernelTraitsWithUnpack<Traits, uint32_t >>;
860+ using Kernel32 =
861+ MediumKernel<KernelTraitsWithUnpack<Traits, uint32_t >, kMedKernelOpts >;
861862 return ForwardToKernel<Traits, Kernel32>{};
862863 } else {
863- return MediumKernel<Traits>{};
864+ return MediumKernel<Traits, kMedKernelOpts >{};
864865 }
865866 } else if constexpr (Traits::kShape .is_large ()) {
866867 if constexpr (LargeShouldUseUint16<Traits>) {
867- using Kernel16 = MediumKernel<KernelTraitsWithUnpack<Traits, uint16_t >>;
868+ using Kernel16 =
869+ MediumKernel<KernelTraitsWithUnpack<Traits, uint16_t >, kMedKernelOpts >;
868870 return ForwardToKernel<Traits, Kernel16>{};
869871 } else {
870872 return LargeKernel<Traits>{};
0 commit comments