Skip to content

Commit ccb6750

Browse files
committed
fix: improve is_cross_lane_128 implementation
Rewrite is_cross_lane_128 in a more procedural style using C++14 constexpr features with a temporary array and for loop, replacing the recursive template implementation. The function correctly checks for cross-lane operations on 128-bit lanes.
1 parent a7497a9 commit ccb6750

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

include/xsimd/arch/common/xsimd_common_swizzle.hpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,57 @@ namespace xsimd
167167
return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
168168
}
169169

170+
// 128-bit lane aware cross_impl: checks per 128-bit lane
171+
template <std::size_t I,
172+
std::size_t N,
173+
std::size_t LaneElems,
174+
typename U,
175+
U... Vs>
176+
struct cross_impl128
177+
{
178+
static constexpr std::size_t Vi = static_cast<std::size_t>(get_at<U, I, Vs...>::value);
179+
static constexpr bool curr = ((I / LaneElems) != (static_cast<std::size_t>(Vi) / LaneElems));
180+
static constexpr bool next = cross_impl128<I + 1, N, LaneElems, U, Vs...>::value;
181+
static constexpr bool value = curr || next;
182+
};
183+
template <std::size_t N, std::size_t LaneElems, typename U, U... Vs>
184+
struct cross_impl128<N, N, LaneElems, U, Vs...>
185+
{
186+
static constexpr bool value = false;
187+
};
188+
189+
template <typename ElemT, typename U, U... Vs>
190+
XSIMD_INLINE constexpr bool is_cross_lane_128() noexcept
191+
{
192+
static_assert(std::is_integral<U>::value, "swizzle mask values must be integral");
193+
static_assert(sizeof...(Vs) >= 1, "Need at least one lane");
194+
constexpr std::size_t N = sizeof...(Vs);
195+
constexpr std::size_t lane_elems = 16 / sizeof(ElemT);
196+
constexpr U values[] = { Vs... };
197+
for (std::size_t i = 0; i < N; ++i)
198+
{
199+
std::size_t elem_lane = i / lane_elems;
200+
std::size_t target_lane = static_cast<std::size_t>(values[i]) / lane_elems;
201+
if (elem_lane != target_lane)
202+
return true;
203+
}
204+
return false;
205+
}
206+
207+
// overload accepting an element type first to compute 128-bit lane size
208+
template <typename ElemT, typename U, U... Vs>
209+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
210+
{
211+
return is_cross_lane_128<ElemT, U, Vs...>();
212+
}
213+
214+
// convenience overload taking element type then integer non-type parameter pack
215+
template <typename ElemT, std::size_t... Vs>
216+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
217+
{
218+
return is_cross_lane_128<ElemT, std::size_t, Vs...>();
219+
}
220+
170221
template <typename T, T... Vs>
171222
XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
172223
template <typename T, T... Vs>
@@ -184,7 +235,11 @@ namespace xsimd
184235
template <typename T, class A, T... Vs>
185236
XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept { return detail::is_only_from_hi<T, Vs...>(); }
186237
template <typename T, class A, T... Vs>
187-
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); }
238+
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept
239+
{
240+
static_assert(std::is_integral<T>::value, "swizzle mask values must be integral");
241+
return is_cross_lane_128<T, T, Vs...>();
242+
}
188243

189244
} // namespace detail
190245
} // namespace kernel

test/test_batch_manip.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,24 @@ namespace xsimd
5252
static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed");
5353
static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi");
5454

55-
static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing");
56-
static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing");
57-
static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
58-
static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
59-
static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
55+
static_assert(is_cross_lane<double, 0, 1, 0, 1>(), "dup-lo only → crossing");
56+
static_assert(is_cross_lane<double, 2, 3, 2, 3>(), "dup-hi only → crossing");
57+
static_assert(is_cross_lane<double, 0, 3, 3, 3>(), "one low + rest high → crossing");
58+
static_assert(!is_cross_lane<double, 1, 0, 2, 3>(), "mixed low/high → no crossing");
59+
static_assert(!is_cross_lane<double, 0, 1, 2, 3>(), "mixed low/high → no crossing");
60+
// 8-lane 128-bit lane checks (use double/int64 for 2-elements-per-128-bit lanes)
61+
static_assert(is_cross_lane<double, 3, 2, 1, 0, 7, 6, 5, 4>(), "8-lane 128-bit swap → crossing");
62+
static_assert(!is_cross_lane<double, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity 8-lane → no crossing");
63+
static_assert(is_cross_lane<std::uint64_t, 3, 2, 1, 0, 7, 6, 5, 4>(), "8-lane uint64_t swap → crossing");
64+
static_assert(is_cross_lane<std::int32_t, 4, 5, 6, 7, 0, 1, 2, 3>(), "8-lane int32_t swap → crossing");
65+
66+
// Additional compile-time checks for 16-element batches (e.g. float/int32)
67+
static_assert(is_cross_lane<float, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
68+
"16-lane 128-bit swap → crossing");
69+
static_assert(!is_cross_lane<float, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15>(),
70+
"identity 16-lane → no crossing");
71+
static_assert(is_cross_lane<std::uint32_t, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
72+
"16-lane uint32_t swap → crossing");
6073
}
6174
}
6275
}

0 commit comments

Comments
 (0)