@@ -545,41 +545,17 @@ class open_addressing_ref_impl {
545545
546546 // If the key is already in the container, return false
547547 if (eq_res == detail::equal_result::EQUAL) {
548- if constexpr (has_payload and sizeof (value_type) > 8 ) {
549- #if (__CUDA_ARCH__ >= 900)
550- if constexpr (not cuco::detail::is_packable<value_type>()) {
551- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
552- }
553- #else
554- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
555- #endif
556- }
548+ this ->maybe_wait_for_payload (slot_ptr);
557549 return {iterator{slot_ptr}, false };
558550 }
559551 if (eq_res == detail::equal_result::AVAILABLE) {
560552 switch (this ->attempt_insert_stable (slot_ptr, bucket_slots[i], val)) {
561553 case insert_result::SUCCESS: {
562- if constexpr (has_payload and sizeof (value_type) > 8 ) {
563- #if (__CUDA_ARCH__ >= 900)
564- if constexpr (not cuco::detail::is_packable<value_type>()) {
565- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
566- }
567- #else
568- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
569- #endif
570- }
554+ this ->maybe_wait_for_payload (slot_ptr);
571555 return {iterator{slot_ptr}, true };
572556 }
573557 case insert_result::DUPLICATE: {
574- if constexpr (has_payload and sizeof (value_type) > 8 ) {
575- #if (__CUDA_ARCH__ >= 900)
576- if constexpr (not cuco::detail::is_packable<value_type>()) {
577- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
578- }
579- #else
580- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
581- #endif
582- }
558+ this ->maybe_wait_for_payload (slot_ptr);
583559 return {iterator{slot_ptr}, false };
584560 }
585561 default : continue ;
@@ -647,17 +623,7 @@ class open_addressing_ref_impl {
647623 if (group_finds_equal) {
648624 auto const src_lane = __ffs (group_finds_equal) - 1 ;
649625 auto const res = group.shfl (reinterpret_cast <intptr_t >(slot_ptr), src_lane);
650- if (group.thread_rank () == src_lane) {
651- if constexpr (has_payload and sizeof (value_type) > 8 ) {
652- #if (__CUDA_ARCH__ >= 900)
653- if constexpr (not cuco::detail::is_packable<value_type>()) {
654- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
655- }
656- #else
657- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
658- #endif
659- }
660- }
626+ if (group.thread_rank () == src_lane) { this ->maybe_wait_for_payload (slot_ptr); }
661627 group.sync ();
662628 return {iterator{reinterpret_cast <value_type*>(res)}, false };
663629 }
@@ -673,32 +639,12 @@ class open_addressing_ref_impl {
673639
674640 switch (group.shfl (status, src_lane)) {
675641 case insert_result::SUCCESS: {
676- if (group.thread_rank () == src_lane) {
677- if constexpr (has_payload and sizeof (value_type) > 8 ) {
678- #if (__CUDA_ARCH__ >= 900)
679- if constexpr (not cuco::detail::is_packable<value_type>()) {
680- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
681- }
682- #else
683- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
684- #endif
685- }
686- }
642+ if (group.thread_rank () == src_lane) { this ->maybe_wait_for_payload (slot_ptr); }
687643 group.sync ();
688644 return {iterator{reinterpret_cast <value_type*>(res)}, true };
689645 }
690646 case insert_result::DUPLICATE: {
691- if (group.thread_rank () == src_lane) {
692- if constexpr (has_payload and sizeof (value_type) > 8 ) {
693- #if (__CUDA_ARCH__ >= 900)
694- if constexpr (not cuco::detail::is_packable<value_type>()) {
695- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
696- }
697- #else
698- this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
699- #endif
700- }
701- }
647+ if (group.thread_rank () == src_lane) { this ->maybe_wait_for_payload (slot_ptr); }
702648 group.sync ();
703649 return {iterator{reinterpret_cast <value_type*>(res)}, false };
704650 }
@@ -1973,6 +1919,34 @@ class open_addressing_ref_impl {
19731919 } while (cuco::detail::bitwise_compare (current, sentinel));
19741920 }
19751921
1922+ /* *
1923+ * @brief Conditionally spin-waits for the payload of a non-atomically inserted slot to become
1924+ * visible.
1925+ *
1926+ * For containers where the key and value are inserted by separate instructions
1927+ * (`cas_dependent_write` / `back_to_back_cas`), an observer thread may see the key before the
1928+ * payload. This helper spins until the payload is visible. For atomic single-CAS paths (slot
1929+ * size <= 8 bytes, or a packable slot on sm_90+ via `atom.cas.b128`), the payload is already
1930+ * visible and this is a no-op.
1931+ *
1932+ * @tparam SlotPtr Pointer-like type to a slot holding a `.second` payload member
1933+ *
1934+ * @param slot_ptr Pointer to the slot whose payload may need waiting on
1935+ */
1936+ template <typename SlotPtr>
1937+ __device__ void maybe_wait_for_payload (SlotPtr slot_ptr) noexcept
1938+ {
1939+ if constexpr (has_payload and sizeof (value_type) > 8 ) {
1940+ #if (__CUDA_ARCH__ >= 900)
1941+ if constexpr (not cuco::detail::is_packable<value_type>()) {
1942+ this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
1943+ }
1944+ #else
1945+ this ->wait_for_payload (slot_ptr->second , this ->empty_value_sentinel ());
1946+ #endif
1947+ }
1948+ }
1949+
19761950 // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper
19771951 value_type empty_slot_sentinel_; // /< Sentinel value indicating an empty slot
19781952 detail::equal_wrapper<key_type, key_equal, allows_duplicates>
0 commit comments