diff --git a/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh new file mode 100644 index 000000000..5bbbffc56 --- /dev/null +++ b/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh @@ -0,0 +1,2378 @@ +/* + * Copyright (c) 2023-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(CUCO_HAS_CUDA_BARRIER) +#include +#endif + +#include + +namespace cuco { +namespace detail { +namespace robin_hood { + +/// Three-way insert result enum +enum class insert_result : cuda::std::int8_t { CONTINUE = 0, SUCCESS = 1, DUPLICATE = 2 }; + +/** + * @brief Helper struct to store intermediate bucket probing results. + */ +struct bucket_probing_results { + detail::equal_result state_; ///< Equal result + cuda::std::int32_t intra_bucket_index_; ///< Intra-bucket index + + /** + * @brief Constructs bucket_probing_results. + * + * @param state The three way equality result + * @param index Intra-bucket index + */ + __device__ explicit constexpr bucket_probing_results(detail::equal_result state, + cuda::std::int32_t index) noexcept + : state_{state}, intra_bucket_index_{index} + { + } +}; + +/** + * @brief Robin Hood inverse primitive for the linear probing sequence. + * + * @note Recovers a resident's probe distance ("age") from the slot it occupies: how many probing + * steps the resident sits from its own home bucket. For the linear sequence this is a single + * subtract — `(slot_base - resident_home) / stride mod num_buckets`. This is the linear-only + * overload; a `double_hashing` variant would add its own overload here (a modular inverse of the + * resident's per-key step, or a stored age), which is the single place that change lands. + * + * @tparam BucketSize Size of the bucket + * @tparam CGSize Size of CUDA Cooperative Groups + * @tparam Hash Unary callable type + * @tparam ProbeKey Type of probing key + * @tparam Extent Type of extent + * + * @param scheme The underlying linear probing scheme (supplies the hash function) + * @param resident_key The key currently residing in the slot + * @param slot_index The slot index at which `resident_key` resides + * @param upper_bound Upper bound of the iteration + * @return The resident's probe distance, in probing steps + */ +template +[[nodiscard]] __host__ __device__ constexpr typename Extent::value_type probe_distance( + linear_probing const& scheme, + ProbeKey resident_key, + typename Extent::value_type slot_index, + Extent upper_bound) noexcept +{ + using size_type = typename Extent::value_type; + size_type constexpr stride = CGSize * BucketSize; + auto const bound = static_cast(upper_bound); + auto const hash = scheme.hash_function(); + + // Home bucket base of the resident, using the same alignment as `make_iterator`. + size_type const resident_home = + cuco::detail::sanitize_hash(hash(resident_key)) % (bound / stride) * stride; + + // Bucket-strided base of the slot the resident currently occupies. The per-lane `thread_rank` + // offset (which is < stride) is stripped by the floor division so that the distance is measured + // in whole probing steps, consistent with the forward sequence. + size_type const slot_base = (slot_index / stride) * stride; + + // (slot_base - resident_home) mod capacity, expressed in probing steps. + return static_cast((slot_base + bound - resident_home) % bound) / stride; +} + +/** + * @brief Common device non-owning "ref" implementation class. + * + * @note This class should NOT be used directly. + * + * @throw If the size of the given key type is larger than `cuco::open_addressing_max_key_size` + * @throw If the size of the given slot type is larger than `cuco::open_addressing_max_slot_size` + * @throw If the given key type doesn't have unique object representations, i.e., + * `cuco::is_bitwise_comparable_v == false` + * @throw If the given payload type doesn't have unique object representations, i.e., + * `cuco::is_bitwise_comparable_v == false` + * @throw If the probing scheme type is not inherited from `cuco::detail::probing_scheme_base` + * + * @tparam Key Type used for keys. Requires `sizeof(Key) <= cuco::open_addressing_max_key_size` and + * `cuco::is_bitwise_comparable_v` + * @tparam Scope The scope in which operations will be performed by individual threads. + * @tparam KeyEqual Binary callable type used to compare two keys for equality + * @tparam ProbingScheme Probing scheme (see `include/cuco/probing_scheme.cuh` for options) + * @tparam StorageRef Storage ref type. Its `value_type` must fit in + * `cuco::open_addressing_max_slot_size`; + * payloads, if present, must be 4 or 8 bytes (or 16 with sm_90+) and satisfy + * `cuco::is_bitwise_comparable_v` + * @tparam AllowsDuplicates Flag indicating whether duplicate keys are allowed or not + */ +template +class open_addressing_ref_impl + : private open_addressing_compatible { + using storage_value_type = typename StorageRef::value_type; + + /// Determines if the container is a key/value or key-only store + static constexpr auto has_payload = not cuda::std::is_same_v; + + /// Flag indicating whether duplicate keys are allowed or not + static constexpr auto allows_duplicates = AllowsDuplicates; + + // TODO: how to re-enable this check? + // static_assert(is_bucket_extent_v, + // "Extent is not a valid cuco::bucket_extent"); + + public: + using key_type = Key; ///< Key type + using probing_scheme_type = ProbingScheme; ///< Type of probing scheme + using hasher = typename probing_scheme_type::hasher; ///< Hash function type + using storage_ref_type = StorageRef; ///< Type of storage ref + using bucket_type = typename storage_ref_type::bucket_type; ///< Bucket type + using value_type = typename storage_ref_type::value_type; ///< Storage element type + using extent_type = typename storage_ref_type::extent_type; ///< Extent type + using size_type = typename storage_ref_type::size_type; ///< Probing scheme size type + using key_equal = KeyEqual; ///< Type of key equality binary callable + using iterator = typename storage_ref_type::iterator; ///< Slot iterator type + using const_iterator = typename storage_ref_type::const_iterator; ///< Const slot iterator type + + static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size + static constexpr auto bucket_size = + storage_ref_type::bucket_size; ///< Number of elements handled per bucket + static constexpr auto thread_scope = Scope; ///< CUDA thread scope + + // Robin Hood displacement swaps the in-flight pair into an *occupied* slot, which needs a single + // atomic CAS of the whole slot. That requires a packable slot: <= 8 bytes (atom.cas.b64), or + // padding-free and <= 16 bytes on an sm_90+ build (atom.cas.b128). A non-packable slot (e.g. a + // padded `pair`) would fall back to a split key/value CAS, which cannot move an + // occupied slot -- displacement would livelock. Reject it at compile time rather than hang. + static constexpr bool robin_hood_slot_is_single_cas = sizeof(value_type) <= 8 +#if defined(CUCO_HAS_128BIT_ATOMICS) + or cuco::detail::is_packable() +#endif + ; + static_assert(robin_hood_slot_is_single_cas, + "Robin Hood probing requires a single-CAS slot: the key+value must fit in 8 bytes, " + "or be packable (padding-free) and <= 16 bytes on an sm_90+ build. A padded slot " + "(e.g. pair) is unsupported -- displacement would livelock."); + + /** + * @brief Constructs open_addressing_ref_impl. + * + * @param empty_slot_sentinel Sentinel indicating an empty slot + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr open_addressing_ref_impl( + value_type empty_slot_sentinel, + key_equal const& predicate, + probing_scheme_type const& probing_scheme, + storage_ref_type storage_ref) noexcept + : empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{ + this->extract_key(empty_slot_sentinel), this->extract_key(empty_slot_sentinel), predicate}, + probing_scheme_{probing_scheme}, + storage_ref_{storage_ref} + { + } + + /** + * @brief Constructs open_addressing_ref_impl. + * + * @param empty_slot_sentinel Sentinel indicating an empty slot + * @param erased_key_sentinel Sentinel indicating an erased key + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr open_addressing_ref_impl( + value_type empty_slot_sentinel, + key_type erased_key_sentinel, + key_equal const& predicate, + probing_scheme_type const& probing_scheme, + storage_ref_type storage_ref) noexcept + : empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{this->extract_key(empty_slot_sentinel), erased_key_sentinel, predicate}, + probing_scheme_{probing_scheme}, + storage_ref_{storage_ref} + { + } + + /** + * @brief Gets the sentinel value used to represent an empty key slot. + * + * @return The sentinel value used to represent an empty key slot + */ + [[nodiscard]] __host__ __device__ constexpr key_type empty_key_sentinel() const noexcept + { + return this->predicate_.empty_sentinel_; + } + + /** + * @brief Gets the sentinel value used to represent an empty payload slot. + * + * @return The sentinel value used to represent an empty payload slot + */ + template > + [[nodiscard]] __host__ __device__ constexpr auto empty_value_sentinel() const noexcept + { + return this->extract_payload(this->empty_slot_sentinel()); + } + + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] __host__ __device__ constexpr key_type erased_key_sentinel() const noexcept + { + return this->predicate_.erased_sentinel_; + } + + /** + * @brief Gets the sentinel used to represent an empty slot. + * + * @return The sentinel value used to represent an empty slot + */ + [[nodiscard]] __host__ __device__ constexpr value_type empty_slot_sentinel() const noexcept + { + return empty_slot_sentinel_; + } + + /** + * @brief Returns the function that compares keys for equality. + * + * @return The key equality predicate + */ + [[nodiscard]] __host__ + __device__ constexpr detail::equal_wrapper + predicate() const noexcept + { + return this->predicate_; + } + + /** + * @brief Gets the key comparator. + * + * @return The comparator used to compare keys + */ + [[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept + { + return this->predicate().equal_; + } + + /** + * @brief Gets the probing scheme. + * + * @return The probing scheme used for the container + */ + [[nodiscard]] __host__ __device__ constexpr probing_scheme_type probing_scheme() const noexcept + { + return probing_scheme_; + } + + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept + { + return this->probing_scheme().hash_function(); + } + + /** + * @brief Gets the non-owning storage ref. + * + * @return The non-owning storage ref of the container + */ + [[nodiscard]] __host__ __device__ constexpr storage_ref_type storage_ref() const noexcept + { + return storage_ref_; + } + + /** + * @brief Gets the maximum number of elements the container can hold. + * + * @return The maximum number of elements the container can hold + */ + [[nodiscard]] __host__ __device__ constexpr auto capacity() const noexcept + { + return storage_ref_.capacity(); + } + + /** + * @brief Gets the bucket extent of the current storage. + * + * @return The bucket extent. + */ + [[nodiscard]] __host__ __device__ constexpr extent_type extent() const noexcept + { + return storage_ref_.extent(); + } + + /** + * @brief Returns an iterator to one past the last slot. + * + * @return An iterator to one past the last slot + */ + [[nodiscard]] __host__ __device__ constexpr iterator end() const noexcept + { + return storage_ref_.end(); + } + + /** + * @brief Returns an iterator to one past the last slot. + * + * @return An iterator to one past the last slot + */ + [[nodiscard]] __host__ __device__ constexpr iterator end() noexcept { return storage_ref_.end(); } + + /** + * @brief Makes a copy of the current device reference using non-owned memory. + * + * This function is intended to be used to create shared memory copies of small static data + * structures, although global memory can be used as well. + * + * @tparam CG The type of the cooperative thread group + * + * @param g The cooperative thread group used to copy the data structure + * @param memory_to_use Array large enough to support `capacity` elements. Object does not take + * the ownership of the memory + */ + template + __device__ void make_copy(CG g, value_type* const memory_to_use) const noexcept + { + auto const num_slots = this->capacity(); +#if defined(CUCO_HAS_CUDA_BARRIER) +#pragma nv_diagnostic push +// Disables `barrier` initialization warning. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ cuda::barrier barrier; +#pragma nv_diagnostic pop + if (g.thread_rank() == 0) { init(&barrier, g.size()); } + g.sync(); + + cuda::memcpy_async( + g, memory_to_use, this->storage_ref().data(), sizeof(value_type) * num_slots, barrier); + + barrier.arrive_and_wait(); +#else + value_type const* const slots_ptr = this->storage_ref().data(); + for (size_type i = g.thread_rank(); i < num_slots; i += g.size()) { + memory_to_use[i] = slots_ptr[i]; + } + g.sync(); +#endif + } + + /** + * @brief Initializes the container storage. + * + * @note This function synchronizes the group `tile`. + * + * @tparam CG The type of the cooperative thread group + * + * @param tile The cooperative thread group used to initialize the container + */ + template + __device__ constexpr void initialize(CG tile) noexcept + { + auto tid = tile.thread_rank(); + auto const extent = static_cast(this->extent()); + + auto* const slots_ptr = this->storage_ref().data(); + while (tid < extent) { + slots_ptr[tid] = this->empty_slot_sentinel(); + tid += tile.size(); + } + + tile.sync(); + } + + /** + * @brief Inserts an element. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The element to insert + * + * @return True if the given element is successfully inserted + */ + template + __device__ bool insert(Value value) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + bool retry = false; + + for (auto& slot_content : bucket_slots) { + auto const eq_res = this->predicate_.template operator()( + key, this->extract_key(slot_content)); + + if constexpr (not allows_duplicates) { + // If the key is already in the container, return false + if (eq_res == detail::equal_result::EQUAL) { return false; } + } + // Robin Hood claims only a true empty here; a tombstone carries an age and is handled as a + // resident by the displacement test below. Skipping it must gate the CAS (once claimed it + // is already consumed), so it is folded into this condition. + if (eq_res == detail::equal_result::AVAILABLE and not this->is_erased(slot_content)) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + switch (attempt_insert( + this->get_slot_ptr(*probing_iter, intra_bucket_index), slot_content, val)) { + case insert_result::DUPLICATE: { + if constexpr (allows_duplicates) { + [[fallthrough]]; + } else { + return false; + } + } + case insert_result::CONTINUE: { + // Retry on a lost CAS. Plain probing keeps scanning this (now stale) bucket; Robin + // Hood must re-read it instead, so the in-flight pair is re-evaluated against the new + // occupants -- otherwise it could be placed past a slot it should have displaced, + // breaking the invariant (and therefore lookups). + retry = true; + break; + } + case insert_result::SUCCESS: return true; + } + if (retry) { break; } // leave the scan to re-read the bucket + } + + // Robin Hood swap test. A resident "richer" than the in-flight pair (a smaller probe + // distance than our current probe step) is displaced: we swap our pair into its slot, adopt + // the evicted resident, and re-probe forward. A tombstone is treated as a resident too -- + // its age comes from its payload (`robin_hood_age`) -- but picking one up *consumes* it: we + // take the slot and are done, since there is nothing to carry forward. + if (eq_res == detail::equal_result::UNEQUAL or this->is_erased(slot_content)) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + auto const evicted_age = this->robin_hood_age( + slot_content, static_cast(*probing_iter + intra_bucket_index)); + if (evicted_age < probe_step) { + if (this->attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + slot_content, + val) == insert_result::SUCCESS) { + // Consuming a tombstone reuses its freed slot -- nothing to carry, so we are done. + if (this->is_erased(slot_content)) { return true; } + // Adopt the evicted pair and re-probe THIS bucket -- its bucket distance here is + // `evicted_age`, and it may belong in another slot of the same bucket: an empty + // one, or one holding an even-richer resident it can displace in turn. Re-reading + // the bucket (rather than advancing past it) is the within-bucket linear probe, + // i.e. the combined bucket+slot distance that makes displacement correct for + // bucket_size > 1. The `slot_distance` term cancels in every comparison, so it + // never appears here; it shows up only as this slot-by-slot continuation. + // `bit_cast` keeps the adoption valid for heterogeneous insert types + // (layout-compatible by contract; identity in the common case). + val = cuda::std::bit_cast(slot_content); + key = this->extract_key(val); + probe_step = evicted_age; + } + retry = + true; // re-read this bucket: re-probe with the victim, or re-evaluate a lost CAS + break; + } + } + } + + if (retry) { continue; } // re-probe (re-read this bucket, or move on after displacement) + ++probe_step; + + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Inserts an element. + * + * @tparam Value Input type which is convertible to 'value_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group insert + * @param value The element to insert + * + * @return True if the given element is successfully inserted + */ + template + __device__ bool insert(cooperative_groups::thread_block_tile group, + Value value) noexcept + { + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::AVAILABLE: { + // Robin Hood: only a true empty is AVAILABLE; a tombstone is a resident handled by + // the displacement scan below, so leave it UNEQUAL here. + bool empty_slot = not this->is_erased(bucket_slots[i()]); + if (empty_slot) { + result = bucket_probing_results{detail::equal_result::AVAILABLE, i()}; + } + break; + } + case detail::equal_result::EQUAL: { + if constexpr (!allows_duplicates) { + result = bucket_probing_results{detail::equal_result::EQUAL, i()}; + } + break; + } + default: break; + } + } + }); + return result; + }(); + + if constexpr (not allows_duplicates) { + // If the key is already in the container, return false + if (group.any(state == detail::equal_result::EQUAL)) { return false; } + } + + auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; + auto status = insert_result::CONTINUE; + if (group.thread_rank() == src_lane) { + if constexpr (SupportsErase) { + status = attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + bucket_slots[intra_bucket_index], + val); + } else { + status = attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + this->empty_slot_sentinel(), + val); + } + } + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: { + if constexpr (allows_duplicates) { + [[fallthrough]]; + } else { + return false; + } + } + default: continue; + } + } else { + // Robin Hood displacement: no match, no empty slot in this bucket. Displace the first + // resident in probe (lane) order that is richer than the in-flight pair, adopt it, and + // re-probe THIS bucket -- the victim may belong in another slot of it. The within-bucket + // linear probe (combined bucket+slot distance) is identical to the scalar path; the + // `slot_distance` term cancels, so the test is again `resident distance < probe_step`. + cuda::std::int32_t displace_idx = -1; + size_type evicted_age = 0; + cuda::static_for([&] __device__(auto i) { + if (displace_idx < 0) { + // `robin_hood_age` so a tombstone uses its payload-stored age: it is displaced (i.e. + // consumed) exactly when richer than the in-flight pair, like any other resident. + auto const age = + this->robin_hood_age(bucket_slots[i()], static_cast(*probing_iter + i())); + if (age < probe_step) { + displace_idx = i(); + evicted_age = age; + } + } + }); + + auto const group_displaceable = group.ballot(displace_idx >= 0); + if (group_displaceable) { + auto const src_lane = __ffs(group_displaceable) - 1; + auto status = insert_result::CONTINUE; + // Only `src_lane` reads `evicted` meaningfully; other lanes just need a valid value to + // feed the broadcast `shfl` below, so seed it with the empty-slot sentinel. + value_type evicted = this->empty_slot_sentinel(); + if (group.thread_rank() == src_lane) { + evicted = bucket_slots[displace_idx]; + status = attempt_insert(this->get_slot_ptr(*probing_iter, displace_idx), evicted, val); + } + if (group.shfl(status, src_lane) == insert_result::SUCCESS) { + // Consuming a tombstone reuses its freed slot -- nothing to carry, so we are done. + if (group.shfl(this->is_erased(evicted), src_lane)) { return true; } + // Broadcast the evicted pair and its probe distance from the winning lane, and adopt + // it on every lane (all lanes need the new in-flight pair for the next scan). + auto const new_key = group.shfl(this->extract_key(evicted), src_lane); + auto const new_age = group.shfl(evicted_age, src_lane); + value_type evicted_slot; + if constexpr (has_payload) { + auto const new_payload = group.shfl(this->extract_payload(evicted), src_lane); + evicted_slot = value_type{new_key, new_payload}; + } else { + evicted_slot = new_key; + } + val = cuda::std::bit_cast(evicted_slot); + key = this->extract_key(val); + probe_step = new_age; + } + continue; // success: re-probe this bucket with the victim; lost CAS: re-read it + } + // No displaceable resident: fall through to the shared advance below. + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + } + + /** + * @brief Inserts the given element into the container. + * + * @note This API returns a pair consisting of an iterator to the inserted element (or to the + * element that prevented the insertion) and a `bool` denoting whether the insertion took place or + * not. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The element to insert + * + * @return a pair consisting of an iterator to the element and a bool indicating whether the + * insertion is successful or not. + */ + template + __device__ cuda::std::pair insert_and_find(Value value) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert(sizeof(value_type) <= 8, + "insert_and_find is not supported for slot types larger than 8 bytes on " + "pre-Volta GPUs."); +#endif + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + // Robin Hood may displace the original key before the chain ends; remember the slot it landed + // in so we return an iterator to it (not to a later victim's slot). + value_type* placed_ptr = nullptr; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + bool retry = false; + + for (auto i = 0; i < bucket_size; ++i) { + auto const eq_res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i])); + auto* slot_ptr = this->get_slot_ptr(*probing_iter, i); + + // If the key is already in the container, return false + if (eq_res == detail::equal_result::EQUAL) { + this->maybe_wait_for_payload(slot_ptr); + return {iterator{slot_ptr}, false}; + } + // Robin Hood claims only a true empty here; a tombstone is handled as a resident by the + // displacement test below (see `insert`). + if (eq_res == detail::equal_result::AVAILABLE and not this->is_erased(bucket_slots[i])) { + switch (this->attempt_insert_stable(slot_ptr, bucket_slots[i], val)) { + case insert_result::SUCCESS: { + // The in-flight pair is placed in an empty slot, ending any displacement chain. The + // iterator to return is the original key's slot (captured on its first placement). + auto* result_ptr = slot_ptr; + if (placed_ptr != nullptr) { result_ptr = placed_ptr; } + this->maybe_wait_for_payload(result_ptr); + return {iterator{result_ptr}, true}; + } + case insert_result::DUPLICATE: { + this->maybe_wait_for_payload(slot_ptr); + return {iterator{slot_ptr}, false}; + } + case insert_result::CONTINUE: { + retry = true; + break; + } + } + if (retry) { break; } + } + + // Robin Hood swap test (see `insert` for the full rationale). A tombstone is a resident too + // (age from its payload); picking one up consumes it -- the in-flight pair lands there and + // we are done. + if (eq_res == detail::equal_result::UNEQUAL or this->is_erased(bucket_slots[i])) { + auto const evicted_age = + this->robin_hood_age(bucket_slots[i], static_cast(*probing_iter + i)); + if (evicted_age < probe_step) { + if (this->attempt_insert(slot_ptr, bucket_slots[i], val) == insert_result::SUCCESS) { + if (this->is_erased(bucket_slots[i])) { + // Consumed a tombstone: the in-flight pair is placed here; return the original + // key's slot (this one if it was never displaced). + auto* result_ptr = (placed_ptr != nullptr) ? placed_ptr : slot_ptr; + this->maybe_wait_for_payload(result_ptr); + return {iterator{result_ptr}, true}; + } + if (placed_ptr == nullptr) { placed_ptr = slot_ptr; } // original key's slot + val = cuda::std::bit_cast(bucket_slots[i]); + key = this->extract_key(val); + probe_step = evicted_age; + } + retry = true; + break; + } + } + } + + if (retry) { continue; } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } + }; + } + + /** + * @brief Inserts the given element into the container. + * + * @note This API returns a pair consisting of an iterator to the inserted element (or to the + * element that prevented the insertion) and a `bool` denoting whether the insertion took place or + * not. + * + * @tparam Value Input type which is convertible to 'value_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group insert_and_find + * @param value The element to insert + * + * @return a pair consisting of an iterator to the element and a bool indicating whether the + * insertion is successful or not. + */ + template + __device__ cuda::std::pair insert_and_find( + cooperative_groups::thread_block_tile group, Value value) noexcept + { +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert(sizeof(value_type) <= 8, + "insert_and_find is not supported for slot types larger than 8 bytes on " + "pre-Volta GPUs."); +#endif + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + // Robin Hood may displace the original key before the chain ends; remember (broadcast) the slot + // it first landed in so we return an iterator to it. 0 means "not yet placed". + intptr_t placed_ptr = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + // Robin Hood: a tombstone is a resident handled by the displacement scan below, not + // AVAILABLE, so leave it UNEQUAL here. + if (res == detail::equal_result::AVAILABLE and this->is_erased(bucket_slots[i()])) { + res = detail::equal_result::UNEQUAL; + } + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + auto* slot_ptr = this->get_slot_ptr(*probing_iter, intra_bucket_index); + + // If the key is already in the container, return false + auto const group_finds_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_finds_equal) { + auto const src_lane = __ffs(group_finds_equal) - 1; + auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(res)}, false}; + } + + auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; + auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + auto const status = [&, target_idx = intra_bucket_index]() { + if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } + return this->attempt_insert_stable(slot_ptr, bucket_slots[target_idx], val); + }(); + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: { + // The in-flight pair is placed in an empty slot, ending any displacement chain. Return + // the original key's slot (the first placement) if it was displaced earlier. + auto result = res; + if (placed_ptr != 0) { result = placed_ptr; } + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(result)}, true}; + } + case insert_result::DUPLICATE: { + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(res)}, false}; + } + default: continue; + } + } else { + // Robin Hood displacement (see CG `insert` for the full rationale). + cuda::std::int32_t displace_idx = -1; + size_type evicted_age = 0; + cuda::static_for([&] __device__(auto i) { + if (displace_idx < 0) { + // `robin_hood_age` so a tombstone uses its payload-stored age: it is displaced (i.e. + // consumed) exactly when richer than the in-flight pair, like any other resident. + auto const age = + this->robin_hood_age(bucket_slots[i()], static_cast(*probing_iter + i())); + if (age < probe_step) { + displace_idx = i(); + evicted_age = age; + } + } + }); + + auto const group_displaceable = group.ballot(displace_idx >= 0); + if (group_displaceable) { + auto const src_lane = __ffs(group_displaceable) - 1; + auto status = insert_result::CONTINUE; + value_type evicted = this->empty_slot_sentinel(); + intptr_t displaced = 0; + if (group.thread_rank() == src_lane) { + auto* dptr = this->get_slot_ptr(*probing_iter, displace_idx); + evicted = bucket_slots[displace_idx]; + status = attempt_insert(dptr, evicted, val); + displaced = reinterpret_cast(dptr); + } + if (group.shfl(status, src_lane) == insert_result::SUCCESS) { + if (placed_ptr == 0) { placed_ptr = group.shfl(displaced, src_lane); } + // Consumed a tombstone: the in-flight pair is placed in its slot; we are done. Return + // the original key's slot (`placed_ptr`, which is this slot if it was never + // displaced). + if (group.shfl(this->is_erased(evicted), src_lane)) { + if (group.thread_rank() == src_lane) { + this->maybe_wait_for_payload(reinterpret_cast(displaced)); + } + group.sync(); + return {iterator{reinterpret_cast(placed_ptr)}, true}; + } + auto const new_key = group.shfl(this->extract_key(evicted), src_lane); + auto const new_age = group.shfl(evicted_age, src_lane); + value_type evicted_slot; + if constexpr (has_payload) { + auto const new_payload = group.shfl(this->extract_payload(evicted), src_lane); + evicted_slot = value_type{new_key, new_payload}; + } else { + evicted_slot = new_key; + } + val = cuda::std::bit_cast(evicted_slot); + key = this->extract_key(val); + probe_step = new_age; + } + continue; + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } + } + } + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(ProbeKey key) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto& slot_content : bucket_slots) { + auto const eq_res = + this->predicate_.template operator()(key, this->extract_key(slot_content)); + + // Key doesn't exist, return false + if (eq_res == detail::equal_result::EMPTY) { return false; } + // Key exists, return true if successfully deleted + if (eq_res == detail::equal_result::EQUAL) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + // Robin Hood records the erased key's age in the tombstone payload (1a); other schemes + // use the plain erased sentinel. + value_type erased = this->robin_hood_erased_sentinel( + slot_content, static_cast(*probing_iter + intra_bucket_index)); + switch (attempt_insert_stable( + this->get_slot_ptr(*probing_iter, intra_bucket_index), slot_content, erased)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } + } + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group erase + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile group, + ProbeKey key) noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_contains_equal) { + auto const src_lane = __ffs(group_contains_equal) - 1; + auto status = insert_result::CONTINUE; + if (group.thread_rank() == src_lane) { + // Robin Hood records the erased key's age in the tombstone payload (1a); other schemes + // use the plain erased sentinel. + value_type erased = this->robin_hood_erased_sentinel( + bucket_slots[intra_bucket_index], + static_cast(*probing_iter + intra_bucket_index)); + status = attempt_insert_stable(this->get_slot_ptr(*probing_iter, intra_bucket_index), + bucket_slots[intra_bucket_index], + erased); + } + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } + + // Key doesn't exist, return false + if (group.any(state == detail::equal_result::EMPTY)) { return false; } + + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Indicates whether the probe key `key` was inserted into the container. + * + * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns + * false. + * + * @tparam ProbeKey Probe key type + * + * @param key The key to search for + * + * @return A boolean indicating whether the probe key is present + */ + template + [[nodiscard]] __device__ bool contains(ProbeKey key) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto i = 0; i < bucket_size; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::UNEQUAL: continue; + case detail::equal_result::EMPTY: return false; + case detail::equal_result::EQUAL: return true; + } + } + // Robin Hood: a resident richer than us proves the key is absent. + if (this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step)) { return false; } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Indicates whether the probe key `key` was inserted into the container. + * + * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns + * false. + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group contains + * @param key The key to search for + * + * @return A boolean indicating whether the probe key is present + */ + template + [[nodiscard]] __device__ bool contains( + cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const state = [&]() { + auto res = detail::equal_result::UNEQUAL; + for (auto i = 0; i < bucket_size; ++i) { + res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i])); + if (res != detail::equal_result::UNEQUAL) { return res; } + } + return res; + }(); + + if (group.any(state == detail::equal_result::EQUAL)) { return true; } + if (group.any(state == detail::equal_result::EMPTY)) { return false; } + + // Robin Hood: a resident richer than us (in any lane's bucket) proves the key is absent. + if (group.any(this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step))) { + return false; + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Finds an element in the container with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ iterator find(ProbeKey key) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto i = 0; i < bucket_size; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + return this->end(); + } + case detail::equal_result::EQUAL: { + return iterator{this->get_slot_ptr(*probing_iter, i)}; + } + default: continue; + } + } + // Robin Hood: a resident richer than us proves the key is absent. + if (this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step)) { + return this->end(); + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } + } + } + + /** + * @brief Finds an element in the container with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ iterator + find(cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + // Find a match for the probe key, thus return an iterator to the entry + auto const group_finds_match = group.ballot(state == detail::equal_result::EQUAL); + if (group_finds_match) { + auto const src_lane = __ffs(group_finds_match) - 1; + auto const res = group.shfl( + reinterpret_cast(this->get_slot_ptr(*probing_iter, intra_bucket_index)), + src_lane); + return iterator{reinterpret_cast(res)}; + } + + // Find an empty slot, meaning that the probe key isn't present in the container + if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); } + + // Robin Hood: a resident richer than us (in any lane's bucket) proves the key is absent. + if (group.any(this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step))) { + return this->end(); + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } + } + } + + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type count(ProbeKey key) const noexcept + { + if constexpr (not allows_duplicates) { + return static_cast(this->contains(key)); + } else { + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type count = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + cuda::std::int32_t equals[bucket_size] = {0}; + bool empty_found = false; + + cuda::static_for([&] __device__(auto i) { + auto const result = predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + equals[i()] = (result == detail::equal_result::EQUAL); + if (result == detail::equal_result::EMPTY) { empty_found = true; } + }); + + count += thrust::reduce(thrust::seq, equals, equals + bucket_size); + + if (empty_found) { return count; } + + ++probing_iter; + if (*probing_iter == init_idx) { return count; } + } + } + } + + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type + count(cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type count = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + cuda::std::int32_t equals[bucket_size] = {0}; + bool empty_found = false; + + cuda::static_for([&] __device__(auto i) { + auto const result = + predicate_.template operator()(key, this->extract_key(bucket_slots[i()])); + equals[i()] = (result == detail::equal_result::EQUAL); + if (result == detail::equal_result::EMPTY) { empty_found = true; } + }); + + count += thrust::reduce(thrust::seq, equals, equals + bucket_size); + + if (group.any(empty_found)) { return count; } + + ++probing_iter; + if (*probing_iter == init_idx) { return count; } + } + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = false; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + auto const always_true_stencil = cuda::constant_iterator(true); + auto const identity_predicate = cuda::std::identity{}; + this->retrieve_impl(block, + input_probe_begin, + n, + always_true_stencil, + identity_predicate, + output_probe, + output_match, + atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * If a key `k` has no matches in the container, then `{key, empty_slot_sentinel}` will be added + * to the output sequence. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_outer(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = true; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + auto const always_true_stencil = cuda::constant_iterator(true); + auto const identity_predicate = cuda::std::identity{}; + this->retrieve_impl(block, + input_probe_begin, + n, + always_true_stencil, + identity_predicate, + output_probe, + output_match, + atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)` if `pred` of the corresponding stencil returns true. + * + * If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true, + * copies `k` to `output_probe` and associated slot contents to `output_match`, + * respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` + * and argument type is convertible from `std::iterator_traits::value_type` + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_if(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + StencilIt stencil, + Predicate pred, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = false; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + this->retrieve_impl( + block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * If `IsOuter == true` and a key `k` has no matches in the container, then `{key, + * empty_slot_sentinel}` will be added to the output sequence. + * + * @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` + * and argument type is convertible from `std::iterator_traits::value_type` + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe Beginning of the input sequence of keys + * @param n Number of input keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_impl(cooperative_groups::thread_block const& block, + InputProbeIt input_probe, + cuco::detail::index_type n, + StencilIt stencil, + Predicate pred, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + namespace cg = cooperative_groups; + + if (n == 0) { return; } + + using probe_type = typename cuda::std::iterator_traits::value_type; + + // tuning parameter + auto constexpr buffer_multiplier = 1; + static_assert(buffer_multiplier > 0); + + auto constexpr probing_tile_size = cg_size; + auto constexpr flushing_tile_size = cuco::detail::warp_size(); + static_assert(flushing_tile_size >= probing_tile_size); + + auto constexpr num_flushing_tiles = BlockSize / flushing_tile_size; + auto constexpr max_matches_per_step = flushing_tile_size * bucket_size; + auto constexpr buffer_size = buffer_multiplier * max_matches_per_step + flushing_tile_size; + + auto const flushing_tile = cg::tiled_partition(block); + auto const probing_tile = cg::tiled_partition(block); + + auto const flushing_tile_id = flushing_tile.meta_group_rank(); + auto const stride = probing_tile.meta_group_size(); + auto idx = probing_tile.meta_group_rank(); + + __shared__ cuco::pair buffers[num_flushing_tiles][buffer_size]; + __shared__ cuda::std::int32_t counters[num_flushing_tiles]; + + if (flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + flushing_tile.sync(); + + auto flush_buffers = [&](auto tile) { + size_type offset = 0; + auto const count = counters[flushing_tile_id]; + auto const rank = tile.thread_rank(); + if (rank == 0) { offset = atomic_counter.fetch_add(count, cuda::memory_order_relaxed); } + offset = tile.shfl(offset, 0); + + // flush_buffers + for (auto i = rank; i < count; i += tile.size()) { + *(output_probe + offset + i) = buffers[flushing_tile_id][i].first; + *(output_match + offset + i) = buffers[flushing_tile_id][i].second; + } + }; + + while (flushing_tile.any(idx < n)) { + bool active_flag = idx < n and pred(*(stencil + idx)); + auto const active_flushing_tile = + cg::binary_partition(flushing_tile, active_flag); + + if (active_flag) { + // perform probing + // make sure the flushing_tile is converged at this point to get a coalesced load + auto const probe_key = *(input_probe + idx); + + auto probing_iter = probing_scheme_.template make_iterator( + probing_tile, probe_key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + bool running = true; + [[maybe_unused]] bool found_match = false; + + bool equals[bucket_size]; + cuda::std::uint32_t exists[bucket_size]; + + while (active_flushing_tile.any(running)) { + if (running) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + cuda::static_for([&] __device__(auto i) { + equals[i()] = false; + if (running) { + // inspect slot content + switch (this->predicate_.template operator()( + probe_key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::EMPTY: { + running = false; + break; + } + case detail::equal_result::EQUAL: { + if constexpr (!AllowsDuplicates) { running = false; } + equals[i()] = true; + break; + } + default: { + break; + } + } + } + }); + + probing_tile.sync(); + running = probing_tile.all(running); + cuda::static_for( + [&](auto i) { exists[i()] = probing_tile.ballot(equals[i()]); }); + + // Fill the buffer if any matching keys are found + auto const lane_id = probing_tile.thread_rank(); + if (thrust::any_of(thrust::seq, exists, exists + bucket_size, cuda::std::identity{})) { + if constexpr (IsOuter) { found_match = true; } + + cuda::std::int32_t num_matches[bucket_size]; + + cuda::static_for( + [&](auto i) { num_matches[i()] = __popc(exists[i()]); }); + + cuda::std::int32_t output_idx; + if (lane_id == 0) { + auto const total_matches = + thrust::reduce(thrust::seq, num_matches, num_matches + bucket_size); + auto ref = cuda::atomic_ref{ + counters[flushing_tile_id]}; + output_idx = ref.fetch_add(total_matches, cuda::memory_order_relaxed); + } + output_idx = probing_tile.shfl(output_idx, 0); + + cuda::std::int32_t matches_offset = 0; + cuda::static_for([&] __device__(auto i) { + if (equals[i()]) { + auto const lane_offset = + detail::count_least_significant_bits(exists[i()], lane_id); + buffers[flushing_tile_id][output_idx + matches_offset + lane_offset] = { + probe_key, bucket_slots[i()]}; + } + matches_offset += num_matches[i()]; + }); + } + // Special handling for outer cases where no match is found + if constexpr (IsOuter) { + if (!running) { + if (!found_match and lane_id == 0) { + auto ref = cuda::atomic_ref{ + counters[flushing_tile_id]}; + auto const output_idx = ref.fetch_add(1, cuda::memory_order_relaxed); + buffers[flushing_tile_id][output_idx] = {probe_key, this->empty_slot_sentinel()}; + } + } + } + } // if running + + active_flushing_tile.sync(); + // if the buffer has not enough empty slots for the next iteration + if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) { + flush_buffers(active_flushing_tile); + active_flushing_tile.sync(); + + // reset buffer counter + if (active_flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + active_flushing_tile.sync(); + } + + // onto the next probing bucket + ++probing_iter; + if (*probing_iter == init_idx) { running = false; } + } // while running + } // if active_flag + + // onto the next key + idx += stride; + } + + flushing_tile.sync(); + // entire flusing_tile has finished; flush remaining elements + if (counters[flushing_tile_id] > 0) { flush_buffers(flushing_tile); } + } + + /** + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + */ + template + __device__ void for_each(ProbeKey key, CallbackOp&& callback_op) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + bool should_return = false; + cuda::static_for([&] __device__(auto i) { + if (!should_return) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::EMPTY: { + should_return = true; + break; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i()]); + break; + } + default: break; + } + } + }); + if (should_return) { return; } + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile group, + ProbeKey key, + CallbackOp&& callback_op) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + for (cuda::std::int32_t i = 0; i < bucket_size and !empty; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i]); + continue; + } + default: { + continue; + } + } + } + if (group.any(empty)) { return; } + + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief Applies the function object `callback_op` to the copy of every slot in the container + * with key equivalent to the probe key and can additionally perform work that requires + * synchronizing the Cooperative Group performing this operation. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `bucket_size * cg_size`. The functor will be called right + * after the current probing bucket has been traversed. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * @tparam SyncOp Type of function object which accepts the current `group` object + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + * @param sync_op Function that is allowed to synchronize `group` inbetween probing buckets + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile group, + ProbeKey key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + for (cuda::std::int32_t i = 0; i < bucket_size and !empty; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i]); + continue; + } + default: { + continue; + } + } + } + sync_op(group); + if (group.any(empty)) { return; } + + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief Gets a pointer to the slot at the given probing index and intra-bucket index. + * + * @param probing_idx The current probing index + * @param intra_bucket_idx The index within the bucket (0 for flat storage) + * @return Pointer to the slot + */ + __device__ value_type* get_slot_ptr(size_type probing_idx, + cuda::std::int32_t intra_bucket_idx) const noexcept + { + return storage_ref_.data() + probing_idx + intra_bucket_idx; + } + + /** + * @brief Determines whether the Robin Hood invariant proves the probe key absent at the current + * probe step. + * + * @note Only meaningful for Robin Hood probing. The key is proven absent when the bucket holds a + * resident that is "richer" than the probe key — i.e. whose own probe distance is smaller than + * the probe key's probe distance at the current step (`probe_step`). Such a resident would have + * been displaced on insertion if the probe key lived here, so the probe key cannot be present. + * + * @note Behavior is only well-defined when every slot in the bucket is occupied (the callers + * reach this check only after ruling out empty and matching slots), since probe distance is + * meaningless for an empty slot. + * + * @tparam BucketSlots Bucket slot array type + * + * @param bucket_slots The slots of the bucket currently being probed + * @param bucket_base The slot index of the first slot in the bucket + * @param probe_step The probe key's own probe distance at the current step + * + * @return True if some resident in the bucket is richer than the probe key + */ + template + [[nodiscard]] __device__ bool robin_hood_proves_absent(BucketSlots const& bucket_slots, + size_type bucket_base, + size_type probe_step) const noexcept + { + bool richer = false; + cuda::static_for([&](auto i) { + auto const resident_age = + this->robin_hood_age(bucket_slots[i()], static_cast(bucket_base + i())); + if (resident_age < probe_step) { richer = true; } + }); + return richer; + } + + /** + * @brief Whether `slot` holds a tombstone (erased marker). + * + * @note Returns false when erase is disabled (the erased and empty sentinels coincide, so no slot + * is a tombstone) -- this keeps the test correct even for empty slots. + * + * @param slot The slot to test + * + * @return True if `slot` is an erased tombstone + */ + [[nodiscard]] __device__ bool is_erased(value_type const& slot) const noexcept + { + return not cuco::detail::bitwise_compare(this->erased_key_sentinel(), + this->empty_key_sentinel()) and + cuco::detail::bitwise_compare(this->extract_key(slot), this->erased_key_sentinel()); + } + + /** + * @brief Robin Hood probe distance ("age") of an occupied slot. + * + * A live key's age is its `probe_distance`. A Robin Hood tombstone keeps the age of the key it + * replaced in its payload (the original key is gone and cannot be rehashed; see `erase`), so it + * is read back here -- a tombstone then participates in every Robin Hood comparison exactly like + * the resident it stood in for. + * + * @param slot The (occupied) slot + * @param slot_index The slot's index + * + * @return The slot's probe distance + */ + [[nodiscard]] __device__ size_type robin_hood_age(value_type const& slot, + size_type slot_index) const noexcept + { + if constexpr (has_payload) { + if (this->is_erased(slot)) { return static_cast(this->extract_payload(slot)); } + } + return robin_hood::probe_distance( + probing_scheme_, this->extract_key(slot), slot_index, storage_ref_.extent()); + } + + /** + * @brief The Robin Hood tombstone for erasing the live key currently in `slot` at `slot_index`. + * + * The erased key's age is stashed in the payload (1a) so the tombstone keeps its place in the + * Robin Hood ordering (the original key is gone and cannot be rehashed). Other probing schemes + * use the plain `erased_slot_sentinel()` and never call this. + * + * @param slot The slot's current (live) contents + * @param slot_index The slot's index + * + * @return The value to CAS into the slot to erase it + */ + [[nodiscard]] __device__ value_type + robin_hood_erased_sentinel(value_type const& slot, size_type slot_index) const noexcept + { + static_assert(has_payload, + "Robin Hood erase requires a mapped payload to store the tombstone age"); + auto const age = robin_hood::probe_distance( + probing_scheme_, this->extract_key(slot), slot_index, storage_ref_.extent()); + return cuco::pair{this->erased_key_sentinel(), + static_castempty_value_sentinel())>(age)}; + } + + /** + * @brief Extracts the key from a given value type. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The key + */ + template + [[nodiscard]] __host__ __device__ constexpr auto extract_key(Value value) const noexcept + { + if constexpr (has_payload) { + return thrust::raw_reference_cast(value).first; + } else { + return thrust::raw_reference_cast(value); + } + } + + /** + * @brief Extracts the payload from a given value type. + * + * @note This function is only available if `this->has_payload == true` + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The payload + */ + template > + [[nodiscard]] __host__ __device__ constexpr auto extract_payload(Value value) const noexcept + { + return thrust::raw_reference_cast(value).second; + } + + /** + * @brief Converts the given type to the container's native `value_type`. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr value_type native_value(T value) const noexcept + { + if constexpr (has_payload) { + return {static_cast(this->extract_key(value)), this->extract_payload(value)}; + } else { + return static_cast(value); + } + } + + /** + * @brief Converts the given type to the container's native `value_type` while maintaining the + * heterogeneous key type. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr auto heterogeneous_value(T value) const noexcept + { + if constexpr (has_payload and not cuda::std::is_same_v) { + using mapped_type = decltype(this->empty_value_sentinel()); + if constexpr (cuco::detail::is_cuda_std_pair_like::value) { + return cuco::pair{cuda::std::get<0>(value), + static_cast(cuda::std::get<1>(value))}; + } else { + // hail mary (convert using .first/.second members) + return cuco::pair{thrust::raw_reference_cast(value.first), + static_cast(value.second)}; + } + } else { + return thrust::raw_reference_cast(value); + } + } + + /** + * @brief Gets the sentinel used to represent an erased slot. + * + * @return The sentinel value used to represent an erased slot + */ + [[nodiscard]] __device__ constexpr value_type erased_slot_sentinel() const noexcept + { + if constexpr (has_payload) { + return cuco::pair{this->erased_key_sentinel(), this->empty_value_sentinel()}; + } else { + return this->erased_key_sentinel(); + } + } + + /** + * @brief Inserts the specified element with one single CAS operation. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address, + value_type expected, + Value desired) noexcept + { + using packed_type = cuco::detail::packed_t; + + auto* slot_ptr = reinterpret_cast(address); + auto* expected_ptr = reinterpret_cast(&expected); + auto* desired_ptr = reinterpret_cast(&desired); + + auto slot_ref = cuda::atomic_ref{*slot_ptr}; + + auto const success = + slot_ref.compare_exchange_strong(*expected_ptr, *desired_ptr, cuda::memory_order_relaxed); + + if (success) { + return insert_result::SUCCESS; + } else { + return this->predicate_.equal_to(this->extract_key(desired), this->extract_key(expected)) == + detail::equal_result::EQUAL + ? insert_result::DUPLICATE + : insert_result::CONTINUE; + } + } + + /** + * @brief Inserts the specified element with two back-to-back CAS operations. + * + * @note This CAS can be used exclusively for `cuco::op::insert` operations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address, + value_type expected, + Value desired) noexcept + { + using mapped_type = cuda::std::decay_tempty_value_sentinel())>; + + auto expected_key = expected.first; + auto expected_payload = this->empty_value_sentinel(); + + cuda::atomic_ref key_ref(address->first); + cuda::atomic_ref payload_ref(address->second); + + auto const key_cas_success = key_ref.compare_exchange_strong( + expected_key, static_cast(desired.first), cuda::memory_order_relaxed); + auto payload_cas_success = payload_ref.compare_exchange_strong( + expected_payload, desired.second, cuda::memory_order_relaxed); + + // if key success + if (key_cas_success) { + while (not payload_cas_success) { + payload_cas_success = + payload_ref.compare_exchange_strong(expected_payload = this->empty_value_sentinel(), + desired.second, + cuda::memory_order_relaxed); + } + return insert_result::SUCCESS; + } else if (payload_cas_success) { + // This is insert-specific, cannot for `erase` operations + payload_ref.store(this->empty_value_sentinel(), cuda::memory_order_relaxed); + } + + // Our key was already present in the slot, so our key is a duplicate + // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare + if (this->predicate_.equal_to(desired.first, expected_key) == detail::equal_result::EQUAL) { + return insert_result::DUPLICATE; + } + + return insert_result::CONTINUE; + } + + /** + * @brief Inserts the specified element with CAS-dependent write operations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* address, + value_type expected, + Value desired) noexcept + { + using mapped_type = cuda::std::decay_tempty_value_sentinel())>; + + cuda::atomic_ref key_ref(address->first); + auto expected_key = expected.first; + auto const success = key_ref.compare_exchange_strong( + expected_key, static_cast(desired.first), cuda::memory_order_relaxed); + + // if key success + if (success) { + cuda::atomic_ref payload_ref(address->second); + payload_ref.store(desired.second, cuda::memory_order_relaxed); + return insert_result::SUCCESS; + } + + // Our key was already present in the slot, so our key is a duplicate + // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare + if (this->predicate_.equal_to(desired.first, expected_key) == detail::equal_result::EQUAL) { + return insert_result::DUPLICATE; + } + + return insert_result::CONTINUE; + } + + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert(value_type* address, + value_type expected, + Value desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } +#if (__CUDA_ARCH__ >= 900) + else if constexpr (cuco::detail::is_packable()) { + return packed_cas(address, expected, desired); + } +#endif + else if constexpr (has_payload) { +#if (__CUDA_ARCH__ < 700) + return cas_dependent_write(address, expected, desired); +#else + return back_to_back_cas(address, expected, desired); +#endif + } else { + static_assert(cuco::dependent_false, + "No valid atomic CAS path: 16-byte key in a key-only container must be " + "packable (have unique object representations) and target sm_90+."); + } + } + + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @note `stable` indicates that the payload will only be updated once from the sentinel value to + * the desired value, meaning there can be no ABA situations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert_stable(value_type* address, + value_type expected, + Value desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } +#if (__CUDA_ARCH__ >= 900) + else if constexpr (cuco::detail::is_packable()) { + return packed_cas(address, expected, desired); + } +#endif + else if constexpr (has_payload) { + return cas_dependent_write(address, expected, desired); + } else { + static_assert(cuco::dependent_false, + "No valid atomic CAS path: 16-byte key in a key-only container must be " + "packable (have unique object representations) and target sm_90+."); + } + } + + /** + * @brief Waits until the slot payload has been updated + * + * @note The function will return once the slot payload is no longer equal to the sentinel + * value. + * + * @tparam T Map slot type + * + * @param slot The target slot to check payload with + * @param sentinel The slot sentinel value + */ + template + __device__ void wait_for_payload(T& slot, T sentinel) const noexcept + { + auto ref = cuda::atomic_ref{slot}; + T current; + // TODO exponential backoff strategy + do { + current = ref.load(cuda::std::memory_order_relaxed); + } while (cuco::detail::bitwise_compare(current, sentinel)); + } + + /** + * @brief Conditionally spin-waits for the payload of a non-atomically inserted slot to become + * visible. + * + * For containers where the key and value are inserted by separate instructions + * (`cas_dependent_write` / `back_to_back_cas`), an observer thread may see the key before the + * payload. This helper spins until the payload is visible. For atomic single-CAS paths (slot + * size <= 8 bytes, or a packable slot on sm_90+ via `atom.cas.b128`), the payload is already + * visible and this is a no-op. + * + * @tparam SlotPtr Pointer-like type to a slot holding a `.second` payload member + * + * @param slot_ptr Pointer to the slot whose payload may need waiting on + */ + template + __device__ void maybe_wait_for_payload(SlotPtr slot_ptr) noexcept + { + if constexpr (has_payload and sizeof(value_type) > 8) { +#if (__CUDA_ARCH__ >= 900) + if constexpr (not cuco::detail::is_packable()) { + this->wait_for_payload(slot_ptr->second, this->empty_value_sentinel()); + } +#else + this->wait_for_payload(slot_ptr->second, this->empty_value_sentinel()); +#endif + } + } + + // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper + value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot + detail::equal_wrapper + predicate_; ///< Key equality binary callable + probing_scheme_type probing_scheme_; ///< Probing scheme + storage_ref_type storage_ref_; ///< Slot storage ref +}; + +} // namespace robin_hood +} // namespace detail +} // namespace cuco diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 660a7fcbd..1c9691d2b 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -74,7 +75,12 @@ class static_map_ref static constexpr auto allows_duplicates = false; /// Implementation type - using impl_type = detail:: + // + // HARD-WIRE (experimental Robin Hood PR): static_map is routed through the Robin Hood engine + // instead of the generic open-addressing engine. This single line is what makes static_map use + // Robin Hood probing; downstream front-end work replaces it with a proper backend-selection + // mechanism. See robin_hood_refactor_plan.md. + using impl_type = detail::robin_hood:: open_addressing_ref_impl; public: diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09c31ee35..d63209f9a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -93,7 +93,8 @@ ConfigureTest(STATIC_MAP_TEST static_map/stream_test.cu static_map/rehash_test.cu static_map/retrieve_test.cu - static_map/retrieve_if_test.cu) + static_map/retrieve_if_test.cu + static_map/robin_hood_test.cu) ################################################################################################### # - dynamic_map tests ----------------------------------------------------------------------------- diff --git a/tests/static_map/contains_test.cu b/tests/static_map/contains_test.cu index 0b3604528..a68e8d132 100644 --- a/tests/static_map/contains_test.cu +++ b/tests/static_map/contains_test.cu @@ -124,19 +124,23 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are kept commented for + // reference / easy re-enable: double_hashing (no RH probe_distance), padded >8B slots + // (int32/int64, int64/int32), and >16B slots. 16B packable slots (int64/int64) need 128-bit + // atomics, so they sit under the #if. + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) #endif ) { diff --git a/tests/static_map/duplicate_keys_test.cu b/tests/static_map/duplicate_keys_test.cu index 9a3422f0b..d4a7ea22a 100644 --- a/tests/static_map/duplicate_keys_test.cu +++ b/tests/static_map/duplicate_keys_test.cu @@ -38,27 +38,31 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are kept commented for + // reference / easy re-enable: double_hashing (no RH probe_distance), padded >8B slots + // (int32/int64, int64/int32), and >16B slots. 16B packable slots (int64/int64) need 128-bit + // atomics, so they sit under the #if. + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + // (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) #endif ) { diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index 8bf09abe2..283ee6710 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -87,27 +87,31 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are kept commented for + // reference / easy re-enable: double_hashing (no RH probe_distance), padded >8B slots + // (int32/int64, int64/int32), and >16B slots. 16B packable slots (int64/int64) need 128-bit + // atomics, so they sit under the #if. + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + // (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) #endif ) { diff --git a/tests/static_map/find_test.cu b/tests/static_map/find_test.cu index 5d9376309..400b7a688 100644 --- a/tests/static_map/find_test.cu +++ b/tests/static_map/find_test.cu @@ -138,37 +138,41 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int8_t, int8_t, cuco::test::probe_sequence::double_hashing, 1), - (int8_t, int8_t, cuco::test::probe_sequence::double_hashing, 2), - (int8_t, int16_t, cuco::test::probe_sequence::double_hashing, 2), - (int16_t, int16_t, cuco::test::probe_sequence::double_hashing, 1), - (int16_t, int16_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are kept commented for + // reference / easy re-enable: double_hashing (no RH probe_distance), padded >8B slots + // (int32/int64, int64/int32), and >16B slots. 16B packable slots (int64/int64) need 128-bit + // atomics, so they sit under the #if. + // (int8_t, int8_t, cuco::test::probe_sequence::double_hashing, 1), + // (int8_t, int8_t, cuco::test::probe_sequence::double_hashing, 2), + // (int8_t, int16_t, cuco::test::probe_sequence::double_hashing, 2), + // (int16_t, int16_t, cuco::test::probe_sequence::double_hashing, 1), + // (int16_t, int16_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int8_t, int8_t, cuco::test::probe_sequence::linear_probing, 1), (int8_t, int8_t, cuco::test::probe_sequence::linear_probing, 2), (int8_t, int16_t, cuco::test::probe_sequence::linear_probing, 2), (int16_t, int16_t, cuco::test::probe_sequence::linear_probing, 1), (int16_t, int16_t, cuco::test::probe_sequence::linear_probing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + // (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) #endif ) { diff --git a/tests/static_map/for_each_test.cu b/tests/static_map/for_each_test.cu index 591ad83b9..a7c6c98c6 100644 --- a/tests/static_map/for_each_test.cu +++ b/tests/static_map/for_each_test.cu @@ -85,28 +85,30 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only; unsupported variants (double_hashing, + // padded/oversized slots) are commented; 16B int64/int64 needs 128-bit atomics. + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + // (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), // padded 12B slot + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), // padded 12B slot +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), // padded 12B slot +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), // padded 12B slot #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) #endif + // (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), // 32B slot + // (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), // oversized slot + // (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) // oversized slot ) { constexpr size_type num_keys{100}; diff --git a/tests/static_map/hash_test.cu b/tests/static_map/hash_test.cu index b4f529418..62a7f5e3c 100644 --- a/tests/static_map/hash_test.cu +++ b/tests/static_map/hash_test.cu @@ -31,7 +31,7 @@ using size_type = std::size_t; template void test_hash_function() { - using Value = int64_t; + using Value = Key; constexpr size_type num_keys{400}; @@ -63,15 +63,17 @@ void test_hash_function() REQUIRE(cuco::test::all_of(d_keys_exist.begin(), d_keys_exist.end(), cuda::std::identity{})); } +// Robin Hood is linear-probing + single-CAS only; unsupported variants (double_hashing, +// padded/oversized slots) are commented; 16B int64/int64 needs 128-bit atomics. TEMPLATE_TEST_CASE_SIG("static_map hash tests", "", ((typename Key)), - (int32_t), - (int64_t) + (int32_t) #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t) + (int64_t) #endif + // (__int128_t) // 32B slot: oversized for single-CAS Robin Hood ) { test_hash_function>(); diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index 3b8b0d023..677b3785b 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -29,113 +29,126 @@ #include -// insert key type -template -struct key_pair { - T a; - T b; - - __host__ __device__ key_pair() {} - __host__ __device__ key_pair(T x) : a{x}, b{x} {} - - // Device equality operator is mandatory due to libcudacxx bug: - // https://github.com/NVIDIA/libcudacxx/issues/223 - __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } - - __device__ explicit operator T() const noexcept { return a; } -}; - -// probe key type -template -struct key_triplet { - T a; - T b; - T c; - - __host__ __device__ key_triplet() {} - __host__ __device__ key_triplet(T x) : a{x}, b{x}, c{x} {} - - // Device equality operator is mandatory due to libcudacxx bug: - // https://github.com/NVIDIA/libcudacxx/issues/223 - __device__ bool operator==(key_triplet const& other) const - { - return a == other.a and b == other.b and c == other.c; - } -}; - -// User-defined device hasher -struct custom_hasher { - template - __device__ uint32_t operator()(CustomKey const& k) const - { - return k.a; - }; -}; - -// User-defined device key equality, Slot key always on the right-hand side -struct custom_key_equal { - template - __device__ bool operator()(InputKey const& lhs, SlotKey const& rhs) const - { - return lhs.a == rhs; - } -}; - -TEMPLATE_TEST_CASE_SIG("static_map heterogeneous lookup tests", - "", - ((typename T, int CGSize), T, CGSize), -#if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 and - // up - (int64_t, 1), - (int64_t, 2), -#endif -#if defined(CUCO_HAS_128BIT_ATOMICS) - (__int128_t, 1), - (__int128_t, 2), -#endif - - (int32_t, 1), - (int32_t, 2)) -{ - using Key = T; - using Value = T; - using InsertKey = key_pair; - using ProbeKey = key_triplet; - using probe_type = cuco::double_hashing; - - auto const sentinel_key = Key{-1}; - auto const sentinel_value = Value{-1}; - - constexpr std::size_t num = 100; - constexpr std::size_t capacity = num * 2; - auto const probe = probe_type{custom_hasher{}, custom_hasher{}}; - - auto my_map = cuco::static_map{capacity, - cuco::empty_key{sentinel_key}, - cuco::empty_value{sentinel_value}, - custom_key_equal{}, - probe}; - - auto insert_pairs = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>( - [] __device__(auto i) { return cuco::pair(i, i); })); - auto probe_keys = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type([] __device__(auto i) { return ProbeKey{i}; })); - - SECTION("All inserted keys-value pairs should be contained") - { - thrust::device_vector contained(num); - my_map.insert(insert_pairs, insert_pairs + num); - my_map.contains(probe_keys, probe_keys + num, contained.begin()); - REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), cuda::std::identity{})); - } - - SECTION("Non-inserted keys-value pairs should not be contained") - { - thrust::device_vector contained(num); - my_map.contains(probe_keys, probe_keys + num, contained.begin()); - REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), cuda::std::identity{})); - } -} +// Disabled under the experimental hard-wired Robin Hood static_map. +// This test does a HETEROGENEOUS insert (InsertKey = key_pair, but the stored Key = T), which +// the Robin Hood engine cannot support as-is: (1) robin_hood_age re-hashes the STORED key (T), but +// this test's custom_hasher only accepts the input key type (returns `k.a`) and rejects a plain T; +// (2) the displacement victim-adoption uses cuda::std::bit_cast(slot), which +// requires sizeof(in-flight insert value) == sizeof(slot), but pair, T> != pair. +// The pre-refactor code compiled this only because heterogeneous_lookup used double_hashing (the +// generic path), never the Robin Hood engine. Re-enabling needs real RH heterogeneous-insert +// support (native_value narrowing + a hasher accepting the stored key type). See +// stuff_to_raise_in_pr.md. +// +// // insert key type +// template +// struct key_pair { +// T a; +// T b; +// +// __host__ __device__ key_pair() {} +// __host__ __device__ key_pair(T x) : a{x}, b{x} {} +// +// // Device equality operator is mandatory due to libcudacxx bug: +// // https://github.com/NVIDIA/libcudacxx/issues/223 +// __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; +// } +// +// __device__ explicit operator T() const noexcept { return a; } +// }; +// +// // probe key type +// template +// struct key_triplet { +// T a; +// T b; +// T c; +// +// __host__ __device__ key_triplet() {} +// __host__ __device__ key_triplet(T x) : a{x}, b{x}, c{x} {} +// +// // Device equality operator is mandatory due to libcudacxx bug: +// // https://github.com/NVIDIA/libcudacxx/issues/223 +// __device__ bool operator==(key_triplet const& other) const +// { +// return a == other.a and b == other.b and c == other.c; +// } +// }; +// +// // User-defined device hasher +// struct custom_hasher { +// template +// __device__ uint32_t operator()(CustomKey const& k) const +// { +// return k.a; +// }; +// }; +// +// // User-defined device key equality, Slot key always on the right-hand side +// struct custom_key_equal { +// template +// __device__ bool operator()(InputKey const& lhs, SlotKey const& rhs) const +// { +// return lhs.a == rhs; +// } +// }; +// +// TEMPLATE_TEST_CASE_SIG("static_map heterogeneous lookup tests", +// "", +// ((typename T, int CGSize), T, CGSize), +// #if defined(CUCO_HAS_INDEPENDENT_THREADS) // Key type larger than 8B only supported for sm_70 +// and +// // up +// (int64_t, 1), +// (int64_t, 2), +// #endif +// #if defined(CUCO_HAS_128BIT_ATOMICS) +// (__int128_t, 1), +// (__int128_t, 2), +// #endif +// +// (int32_t, 1), +// (int32_t, 2)) +// { +// using Key = T; +// using Value = T; +// using InsertKey = key_pair; +// using ProbeKey = key_triplet; +// using probe_type = cuco::double_hashing; +// +// auto const sentinel_key = Key{-1}; +// auto const sentinel_value = Value{-1}; +// +// constexpr std::size_t num = 100; +// constexpr std::size_t capacity = num * 2; +// auto const probe = probe_type{custom_hasher{}, custom_hasher{}}; +// +// auto my_map = cuco::static_map{capacity, +// cuco::empty_key{sentinel_key}, +// cuco::empty_value{sentinel_value}, +// custom_key_equal{}, +// probe}; +// +// auto insert_pairs = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>( +// [] __device__(auto i) { return cuco::pair(i, i); })); +// auto probe_keys = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type([] __device__(auto i) { return ProbeKey{i}; })); +// +// SECTION("All inserted keys-value pairs should be contained") +// { +// thrust::device_vector contained(num); +// my_map.insert(insert_pairs, insert_pairs + num); +// my_map.contains(probe_keys, probe_keys + num, contained.begin()); +// REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), cuda::std::identity{})); +// } +// +// SECTION("Non-inserted keys-value pairs should not be contained") +// { +// thrust::device_vector contained(num); +// my_map.contains(probe_keys, probe_keys + num, contained.begin()); +// REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), cuda::std::identity{})); +// } +// } diff --git a/tests/static_map/insert_and_find_test.cu b/tests/static_map/insert_and_find_test.cu index 17665eb2b..eda9bf146 100644 --- a/tests/static_map/insert_and_find_test.cu +++ b/tests/static_map/insert_and_find_test.cu @@ -27,81 +27,87 @@ #include -using size_type = std::size_t; - -TEMPLATE_TEST_CASE_SIG( - "static_map insert_and_find tests", - "", - ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), - Key, - Value, - Probe, - CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) -#if defined(CUCO_HAS_128BIT_ATOMICS) - , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) -#endif -) -{ -#if !defined(CUCO_HAS_INDEPENDENT_THREADS) - if constexpr (cuco::detail::is_packable>()) -#endif - { - using probe = std::conditional_t< - Probe == cuco::test::probe_sequence::linear_probing, - cuco::linear_probing>, - cuco::double_hashing, cuco::murmurhash3_32>>; - - constexpr size_type num_keys{400}; - - auto map = cuco::static_map, - cuda::thread_scope_device, - cuda::std::equal_to, - probe, - cuco::cuda_allocator, - cuco::storage<2>>{ - num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; - - auto pairs_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>( - [] __device__(auto i) { return cuco::pair{i, 1}; })); - - thrust::device_vector found1(num_keys); - thrust::device_vector found2(num_keys); - - thrust::device_vector inserted(num_keys); - - // insert first time, fills inserted with true - map.insert_and_find(pairs_begin, pairs_begin + num_keys, found1.begin(), inserted.begin()); - REQUIRE(cuco::test::all_of(inserted.begin(), inserted.end(), cuda::std::identity{})); - - // insert second time, fills inserted with false as keys already in map - map.insert_and_find(pairs_begin, pairs_begin + num_keys, found2.begin(), inserted.begin()); - REQUIRE(cuco::test::none_of(inserted.begin(), inserted.end(), cuda::std::identity{})); - - // both found1 and found2 should be same, as keys will be referring to same slot - REQUIRE( - cuco::test::equal(found1.begin(), found1.end(), found2.begin(), cuda::std::equal_to{})); - } -} +// Disabled under the experimental hard-wired Robin Hood static_map. +// insert_and_find is excluded under Robin Hood: its returned iterator can dangle once a later +// insert displaces the key (pointer instability). Re-enable (and trim slot types / probe scheme as +// needed) once Robin Hood support is generalized. See robin_hood_refactor_plan.md. +// +// using size_type = std::size_t; +// +// TEMPLATE_TEST_CASE_SIG( +// "static_map insert_and_find tests", +// "", +// ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), +// Key, +// Value, +// Probe, +// CGSize), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// #if defined(CUCO_HAS_128BIT_ATOMICS) +// , +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) +// #endif +// ) +// { +// #if !defined(CUCO_HAS_INDEPENDENT_THREADS) +// if constexpr (cuco::detail::is_packable>()) +// #endif +// { +// using probe = std::conditional_t< +// Probe == cuco::test::probe_sequence::linear_probing, +// cuco::linear_probing>, +// cuco::double_hashing, cuco::murmurhash3_32>>; +// +// constexpr size_type num_keys{400}; +// +// auto map = cuco::static_map, +// cuda::thread_scope_device, +// cuda::std::equal_to, +// probe, +// cuco::cuda_allocator, +// cuco::storage<2>>{ +// num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; +// +// auto pairs_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>( +// [] __device__(auto i) { return cuco::pair{i, 1}; })); +// +// thrust::device_vector found1(num_keys); +// thrust::device_vector found2(num_keys); +// +// thrust::device_vector inserted(num_keys); +// +// // insert first time, fills inserted with true +// map.insert_and_find(pairs_begin, pairs_begin + num_keys, found1.begin(), inserted.begin()); +// REQUIRE(cuco::test::all_of(inserted.begin(), inserted.end(), cuda::std::identity{})); +// +// // insert second time, fills inserted with false as keys already in map +// map.insert_and_find(pairs_begin, pairs_begin + num_keys, found2.begin(), inserted.begin()); +// REQUIRE(cuco::test::none_of(inserted.begin(), inserted.end(), cuda::std::identity{})); +// +// // both found1 and found2 should be same, as keys will be referring to same slot +// REQUIRE( +// cuco::test::equal(found1.begin(), found1.end(), found2.begin(), +// cuda::std::equal_to{})); +// } +// } diff --git a/tests/static_map/insert_or_apply_test.cu b/tests/static_map/insert_or_apply_test.cu index 0a4d07ea3..481500b6f 100644 --- a/tests/static_map/insert_or_apply_test.cu +++ b/tests/static_map/insert_or_apply_test.cu @@ -30,261 +30,267 @@ #include -using size_type = std::size_t; - -template -void test_insert_or_apply(Map& map, size_type num_keys, size_type num_unique_keys, Init init) -{ - REQUIRE((num_keys % num_unique_keys) == 0); - - using Key = typename Map::key_type; - using Value = typename Map::mapped_type; - - // Insert pairs - auto pairs_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>([num_unique_keys] __device__(auto i) { - return cuco::pair{i % num_unique_keys, 1}; - })); - - auto constexpr plus_op = cuco::reduce::plus{}; - if constexpr (HasInit) { - map.insert_or_apply(pairs_begin, pairs_begin + num_keys, init, plus_op); - } else { - map.insert_or_apply(pairs_begin, pairs_begin + num_keys, plus_op); - } - - REQUIRE(map.size() == num_unique_keys); - - thrust::device_vector d_keys(num_unique_keys); - thrust::device_vector d_values(num_unique_keys); - map.retrieve_all(d_keys.begin(), d_values.begin()); - - REQUIRE(cuco::test::equal(d_values.begin(), - d_values.end(), - cuda::make_constant_iterator(num_keys / num_unique_keys), - cuda::std::equal_to{})); -} - -template -void test_insert_or_apply_shmem(Map& map, size_type num_keys, size_type num_unique_keys, Init init) -{ - REQUIRE((num_keys % num_unique_keys) == 0); - - using Key = typename Map::key_type; - using Value = typename Map::mapped_type; - - using KeyEqual = typename Map::key_equal; - using ProbingScheme = typename Map::probing_scheme_type; - using Allocator = typename Map::allocator_type; - auto constexpr cg_size = Map::cg_size; - - int32_t constexpr shmem_block_size = 1024; - - using shmem_size_type = int32_t; - - shmem_size_type constexpr cardinality_threshold = shmem_block_size; - shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size; - float constexpr load_factor = 0.7; - shmem_size_type constexpr shared_map_size = - static_cast((1.0 / load_factor) * shared_map_num_elements); - - using extent_type = cuco::extent; - using shared_map_type = cuco::static_map>; - - using shared_map_ref_type = typename shared_map_type::template ref_type<>; - auto constexpr valid_extent = - cuco::make_valid_extent(extent_type{}); - - // Insert pairs - auto pairs_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>([num_unique_keys] __device__(auto i) { - return cuco::pair{i % num_unique_keys, 1}; - })); - - auto const shmem_grid_size = cuco::detail::grid_size(num_keys, cg_size, 1, shmem_block_size); - - cuda::stream_ref stream{cudaStream_t{nullptr}}; - - // launch the shmem kernel - cuco::detail::static_map_ns:: - insert_or_apply_shmem - <<>>(pairs_begin, - num_keys, - init, - cuco::reduce::plus{}, - map.ref(cuco::op::insert_or_apply), - valid_extent); - - REQUIRE(map.size() == num_unique_keys); - - thrust::device_vector d_keys(num_unique_keys); - thrust::device_vector d_values(num_unique_keys); - map.retrieve_all(d_keys.begin(), d_values.begin()); - - REQUIRE(cuco::test::equal(d_values.begin(), - d_values.end(), - cuda::make_constant_iterator(num_keys / num_unique_keys), - cuda::std::equal_to{})); -} - -/* -TEMPLATE_TEST_CASE_SIG( - "static_map insert_or_apply tests", - "", - ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), - Key, - Value, - Probe, - CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) -#if defined(CUCO_HAS_128BIT_ATOMICS) - , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) -#endif -) -{ - constexpr size_type num_keys{10'000}; - constexpr size_type num_unique_keys{100}; - - using probe = std::conditional_t< - Probe == cuco::test::probe_sequence::linear_probing, - cuco::linear_probing>, - cuco::double_hashing, cuco::murmurhash3_32>>; - - using map_type = cuco::static_map, - cuda::thread_scope_device, - cuda::std::equal_to, - probe, - cuco::cuda_allocator, - cuco::storage<2>>; - - SECTION("sentinel equals init; has_init = true") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(0)); - } - SECTION("sentinel equals init; has_init = false") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(0)); - } - SECTION("sentinel not equals init; has_init = true") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(-1)); - } - SECTION("sentinel not equals init; has_init = false") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(-1)); - } -} - -TEMPLATE_TEST_CASE_SIG("static_map insert_or_apply all unique keys tests", - "", - ((typename Key)), - (int32_t), - (int64_t) -#if defined(CUCO_HAS_128BIT_ATOMICS) - , - (__int128_t) -#endif -) -{ - using Value = Key; - - constexpr size_type num_keys = 100; - - using map_type = cuco::static_map, - cuda::thread_scope_device, - cuda::std::equal_to, - cuco::linear_probing<2, cuco::murmurhash3_32>, - cuco::cuda_allocator, - cuco::storage<2>>; - - SECTION("sentinel equals init; has_init = true") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_keys, static_cast(0)); - } - SECTION("sentinel equals init; has_init = false") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_keys, static_cast(0)); - } - SECTION("sentinel not equals init; has_init = true") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_keys, static_cast(-1)); - } - SECTION("sentinel not equals init; has_init = false") - { - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply(map, num_keys, num_keys, static_cast(-1)); - } -} -*/ - -TEMPLATE_TEST_CASE_SIG( - "static_map insert_or_apply shared memory", "", ((typename Key)), (int32_t), (int64_t)) -{ - using Value = Key; - - using map_type = cuco::static_map, - cuda::thread_scope_device, - cuda::std::equal_to, - cuco::linear_probing<1, cuco::murmurhash3_32>, - cuco::cuda_allocator, - cuco::storage<2>>; - - SECTION("duplicate keys") - { - constexpr size_type num_keys = 10'000; - constexpr size_type num_unique_keys = 100; - - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply_shmem(map, num_keys, num_unique_keys, static_cast(0)); - } - - SECTION("unique keys") - { - constexpr size_type num_keys = 10'000; - constexpr size_type num_unique_keys = num_keys; - - auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; - test_insert_or_apply_shmem(map, num_keys, num_unique_keys, static_cast(0)); - } -} +// Disabled under the experimental hard-wired Robin Hood static_map. +// insert_or_apply is excluded under Robin Hood: same value-only-atomic-vs-displacement race as +// insert_or_assign. Re-enable (and trim slot types / probe scheme as needed) once Robin Hood +// support is generalized. See robin_hood_refactor_plan.md. +// +// using size_type = std::size_t; +// +// template +// void test_insert_or_apply(Map& map, size_type num_keys, size_type num_unique_keys, Init init) +// { +// REQUIRE((num_keys % num_unique_keys) == 0); +// +// using Key = typename Map::key_type; +// using Value = typename Map::mapped_type; +// +// // Insert pairs +// auto pairs_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>([num_unique_keys] __device__(auto i) { +// return cuco::pair{i % num_unique_keys, 1}; +// })); +// +// auto constexpr plus_op = cuco::reduce::plus{}; +// if constexpr (HasInit) { +// map.insert_or_apply(pairs_begin, pairs_begin + num_keys, init, plus_op); +// } else { +// map.insert_or_apply(pairs_begin, pairs_begin + num_keys, plus_op); +// } +// +// REQUIRE(map.size() == num_unique_keys); +// +// thrust::device_vector d_keys(num_unique_keys); +// thrust::device_vector d_values(num_unique_keys); +// map.retrieve_all(d_keys.begin(), d_values.begin()); +// +// REQUIRE(cuco::test::equal(d_values.begin(), +// d_values.end(), +// cuda::make_constant_iterator(num_keys / num_unique_keys), +// cuda::std::equal_to{})); +// } +// +// template +// void test_insert_or_apply_shmem(Map& map, size_type num_keys, size_type num_unique_keys, Init +// init) +// { +// REQUIRE((num_keys % num_unique_keys) == 0); +// +// using Key = typename Map::key_type; +// using Value = typename Map::mapped_type; +// +// using KeyEqual = typename Map::key_equal; +// using ProbingScheme = typename Map::probing_scheme_type; +// using Allocator = typename Map::allocator_type; +// auto constexpr cg_size = Map::cg_size; +// +// int32_t constexpr shmem_block_size = 1024; +// +// using shmem_size_type = int32_t; +// +// shmem_size_type constexpr cardinality_threshold = shmem_block_size; +// shmem_size_type constexpr shared_map_num_elements = cardinality_threshold + shmem_block_size; +// float constexpr load_factor = 0.7; +// shmem_size_type constexpr shared_map_size = +// static_cast((1.0 / load_factor) * shared_map_num_elements); +// +// using extent_type = cuco::extent; +// using shared_map_type = cuco::static_map>; +// +// using shared_map_ref_type = typename shared_map_type::template ref_type<>; +// auto constexpr valid_extent = +// cuco::make_valid_extent(extent_type{}); +// +// // Insert pairs +// auto pairs_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>([num_unique_keys] __device__(auto i) { +// return cuco::pair{i % num_unique_keys, 1}; +// })); +// +// auto const shmem_grid_size = cuco::detail::grid_size(num_keys, cg_size, 1, shmem_block_size); +// +// cuda::stream_ref stream{cudaStream_t{nullptr}}; +// +// // launch the shmem kernel +// cuco::detail::static_map_ns:: +// insert_or_apply_shmem +// <<>>(pairs_begin, +// num_keys, +// init, +// cuco::reduce::plus{}, +// map.ref(cuco::op::insert_or_apply), +// valid_extent); +// +// REQUIRE(map.size() == num_unique_keys); +// +// thrust::device_vector d_keys(num_unique_keys); +// thrust::device_vector d_values(num_unique_keys); +// map.retrieve_all(d_keys.begin(), d_values.begin()); +// +// REQUIRE(cuco::test::equal(d_values.begin(), +// d_values.end(), +// cuda::make_constant_iterator(num_keys / num_unique_keys), +// cuda::std::equal_to{})); +// } +// +// /* +// TEMPLATE_TEST_CASE_SIG( +// "static_map insert_or_apply tests", +// "", +// ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), +// Key, +// Value, +// Probe, +// CGSize), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// #if defined(CUCO_HAS_128BIT_ATOMICS) +// , +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) +// #endif +// ) +// { +// constexpr size_type num_keys{10'000}; +// constexpr size_type num_unique_keys{100}; +// +// using probe = std::conditional_t< +// Probe == cuco::test::probe_sequence::linear_probing, +// cuco::linear_probing>, +// cuco::double_hashing, cuco::murmurhash3_32>>; +// +// using map_type = cuco::static_map, +// cuda::thread_scope_device, +// cuda::std::equal_to, +// probe, +// cuco::cuda_allocator, +// cuco::storage<2>>; +// +// SECTION("sentinel equals init; has_init = true") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(0)); +// } +// SECTION("sentinel equals init; has_init = false") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(0)); +// } +// SECTION("sentinel not equals init; has_init = true") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(-1)); +// } +// SECTION("sentinel not equals init; has_init = false") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_unique_keys, static_cast(-1)); +// } +// } +// +// TEMPLATE_TEST_CASE_SIG("static_map insert_or_apply all unique keys tests", +// "", +// ((typename Key)), +// (int32_t), +// (int64_t) +// #if defined(CUCO_HAS_128BIT_ATOMICS) +// , +// (__int128_t) +// #endif +// ) +// { +// using Value = Key; +// +// constexpr size_type num_keys = 100; +// +// using map_type = cuco::static_map, +// cuda::thread_scope_device, +// cuda::std::equal_to, +// cuco::linear_probing<2, cuco::murmurhash3_32>, +// cuco::cuda_allocator, +// cuco::storage<2>>; +// +// SECTION("sentinel equals init; has_init = true") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_keys, static_cast(0)); +// } +// SECTION("sentinel equals init; has_init = false") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_keys, static_cast(0)); +// } +// SECTION("sentinel not equals init; has_init = true") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_keys, static_cast(-1)); +// } +// SECTION("sentinel not equals init; has_init = false") +// { +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply(map, num_keys, num_keys, static_cast(-1)); +// } +// } +// */ +// +// TEMPLATE_TEST_CASE_SIG( +// "static_map insert_or_apply shared memory", "", ((typename Key)), (int32_t), (int64_t)) +// { +// using Value = Key; +// +// using map_type = cuco::static_map, +// cuda::thread_scope_device, +// cuda::std::equal_to, +// cuco::linear_probing<1, cuco::murmurhash3_32>, +// cuco::cuda_allocator, +// cuco::storage<2>>; +// +// SECTION("duplicate keys") +// { +// constexpr size_type num_keys = 10'000; +// constexpr size_type num_unique_keys = 100; +// +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply_shmem(map, num_keys, num_unique_keys, static_cast(0)); +// } +// +// SECTION("unique keys") +// { +// constexpr size_type num_keys = 10'000; +// constexpr size_type num_unique_keys = num_keys; +// +// auto map = map_type{num_keys, cuco::empty_key{-1}, cuco::empty_value{0}}; +// test_insert_or_apply_shmem(map, num_keys, num_unique_keys, static_cast(0)); +// } +// } diff --git a/tests/static_map/insert_or_assign_test.cu b/tests/static_map/insert_or_assign_test.cu index 2cccd62ad..24d688d6d 100644 --- a/tests/static_map/insert_or_assign_test.cu +++ b/tests/static_map/insert_or_assign_test.cu @@ -28,96 +28,102 @@ #include -using size_type = std::size_t; - -template -void test_insert_or_assign(Map& map, size_type num_keys) -{ - using Key = typename Map::key_type; - using Value = typename Map::mapped_type; - - // Insert pairs - auto pairs_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>( - [] __device__(auto i) { return cuco::pair{i, i}; })); - - auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys); - REQUIRE(initial_size == num_keys); // all keys should be inserted - - // Query pairs have the same keys but different payloads - auto query_pairs_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type>( - [] __device__(auto i) { return cuco::pair(i, i * 2); })); - - map.insert_or_assign(query_pairs_begin, query_pairs_begin + num_keys); - - auto const updated_size = map.size(); - // all keys are present in the map so the size shouldn't change - REQUIRE(updated_size == initial_size); - - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); - map.retrieve_all(d_keys.begin(), d_values.begin()); - - auto gold_values_begin = cuda::make_transform_iterator( - cuda::counting_iterator(0), - cuda::proclaim_return_type([] __device__(auto i) { return i * 2; })); - - thrust::sort(thrust::device, d_values.begin(), d_values.end()); - REQUIRE(cuco::test::equal( - d_values.begin(), d_values.end(), gold_values_begin, cuda::std::equal_to{})); -} - -TEMPLATE_TEST_CASE_SIG( - "static_map insert_or_assign tests", - "", - ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), - Key, - Value, - Probe, - CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) -#if defined(CUCO_HAS_128BIT_ATOMICS) - , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) -#endif -) -{ - constexpr size_type num_keys{400}; - - using probe = std::conditional_t< - Probe == cuco::test::probe_sequence::linear_probing, - cuco::linear_probing>, - cuco::double_hashing, cuco::murmurhash3_32>>; - - auto map = cuco::static_map, - cuda::thread_scope_device, - cuda::std::equal_to, - probe, - cuco::cuda_allocator, - cuco::storage<2>>{ - num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; - - test_insert_or_assign(map, num_keys); -} +// Disabled under the experimental hard-wired Robin Hood static_map. +// insert_or_assign is excluded under Robin Hood: a value-only-atomic update races RH displacement; +// it needs a whole-slot-CAS update API (see stuff_to_raise_in_pr.md). Re-enable (and trim slot +// types / probe scheme as needed) once Robin Hood support is generalized. See +// robin_hood_refactor_plan.md. +// +// using size_type = std::size_t; +// +// template +// void test_insert_or_assign(Map& map, size_type num_keys) +// { +// using Key = typename Map::key_type; +// using Value = typename Map::mapped_type; +// +// // Insert pairs +// auto pairs_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>( +// [] __device__(auto i) { return cuco::pair{i, i}; })); +// +// auto const initial_size = map.insert(pairs_begin, pairs_begin + num_keys); +// REQUIRE(initial_size == num_keys); // all keys should be inserted +// +// // Query pairs have the same keys but different payloads +// auto query_pairs_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type>( +// [] __device__(auto i) { return cuco::pair(i, i * 2); })); +// +// map.insert_or_assign(query_pairs_begin, query_pairs_begin + num_keys); +// +// auto const updated_size = map.size(); +// // all keys are present in the map so the size shouldn't change +// REQUIRE(updated_size == initial_size); +// +// thrust::device_vector d_keys(num_keys); +// thrust::device_vector d_values(num_keys); +// map.retrieve_all(d_keys.begin(), d_values.begin()); +// +// auto gold_values_begin = cuda::make_transform_iterator( +// cuda::counting_iterator(0), +// cuda::proclaim_return_type([] __device__(auto i) { return i * 2; })); +// +// thrust::sort(thrust::device, d_values.begin(), d_values.end()); +// REQUIRE(cuco::test::equal( +// d_values.begin(), d_values.end(), gold_values_begin, cuda::std::equal_to{})); +// } +// +// TEMPLATE_TEST_CASE_SIG( +// "static_map insert_or_assign tests", +// "", +// ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), +// Key, +// Value, +// Probe, +// CGSize), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), +// (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) +// #if defined(CUCO_HAS_128BIT_ATOMICS) +// , +// (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), +// (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), +// (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) +// #endif +// ) +// { +// constexpr size_type num_keys{400}; +// +// using probe = std::conditional_t< +// Probe == cuco::test::probe_sequence::linear_probing, +// cuco::linear_probing>, +// cuco::double_hashing, cuco::murmurhash3_32>>; +// +// auto map = cuco::static_map, +// cuda::thread_scope_device, +// cuda::std::equal_to, +// probe, +// cuco::cuda_allocator, +// cuco::storage<2>>{ +// num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; +// +// test_insert_or_assign(map, num_keys); +// } diff --git a/tests/static_map/key_sentinel_test.cu b/tests/static_map/key_sentinel_test.cu index edfa499b6..ee75feba0 100644 --- a/tests/static_map/key_sentinel_test.cu +++ b/tests/static_map/key_sentinel_test.cu @@ -36,11 +36,16 @@ struct custom_equals { TEMPLATE_TEST_CASE_SIG("static_map key sentinel tests", "", ((typename T), T), - (int32_t), - (int64_t) -#if defined(CUCO_HAS_128BIT_ATOMICS) + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are + // kept commented for reference / easy re-enable: padded >8B slots and >16B + // slots. 16B packable slots (int64) need 128-bit atomics, so they sit under + // the #if. + (int32_t) +#if defined(CUCO_HAS_128BIT_ATOMICS) // int64/int64 is a 16B slot -> single-CAS only with 128-bit + // atomics , - (__int128_t) + (int64_t) +// (__int128_t) #endif ) { diff --git a/tests/static_map/rehash_test.cu b/tests/static_map/rehash_test.cu index d4b7ffa17..7db8117c5 100644 --- a/tests/static_map/rehash_test.cu +++ b/tests/static_map/rehash_test.cu @@ -23,8 +23,9 @@ TEST_CASE("static_map rehash test", "") { - using key_type = int; - using mapped_type = long; + using key_type = int; + // int (not long): Robin Hood needs a single-CAS slot, and pair is a padded 16B slot. + using mapped_type = int; constexpr std::size_t num_keys{400}; constexpr std::size_t num_erased_keys{100}; diff --git a/tests/static_map/retrieve_if_test.cu b/tests/static_map/retrieve_if_test.cu index fe4441822..ad65699ed 100644 --- a/tests/static_map/retrieve_if_test.cu +++ b/tests/static_map/retrieve_if_test.cu @@ -107,15 +107,18 @@ __global__ void test_retrieve_if_all_true_kernel( *atomic_counter); } -TEMPLATE_TEST_CASE_SIG("static_map retrieve_if", - "", - ((typename Key, typename T), Key, T), - (int32_t, int32_t), - (int64_t, int64_t) +// Robin Hood is linear-probing + single-CAS only; unsupported variants (double_hashing, +// padded/oversized slots) are commented; 16B int64/int64 needs 128-bit atomics. +TEMPLATE_TEST_CASE_SIG( + "static_map retrieve_if", + "", + ((typename Key, typename T), Key, T), + (int32_t, int32_t) #if defined(CUCO_HAS_128BIT_ATOMICS) - , - (__int128_t, __int128_t) + , + (int64_t, int64_t) #endif + // (__int128_t, __int128_t) // 32B slot: oversized for single-CAS Robin Hood ) { constexpr size_type num_keys{400}; diff --git a/tests/static_map/retrieve_test.cu b/tests/static_map/retrieve_test.cu index a45dcb6e1..6e3f28fe3 100644 --- a/tests/static_map/retrieve_test.cu +++ b/tests/static_map/retrieve_test.cu @@ -105,28 +105,30 @@ TEMPLATE_TEST_CASE_SIG( Value, Probe, CGSize), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // Robin Hood is linear-probing + single-CAS only; unsupported variants (double_hashing, + // padded/oversized slots) are commented; 16B int64/int64 needs 128-bit atomics. + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), + // (int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2), + // (int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2), (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), - (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), - (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) + // (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), // padded 12B slot + (int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2) +// (int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2), // padded 12B slot +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1), // padded 12B slot +// (int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2), // padded 12B slot #if defined(CUCO_HAS_128BIT_ATOMICS) , - (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), - (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2) #endif + // (__int128_t, __int128_t, cuco::test::probe_sequence::double_hashing, 2), // 32B slot + // (__int128_t, int64_t, cuco::test::probe_sequence::double_hashing, 1), // oversized slot + // (int32_t, __int128_t, cuco::test::probe_sequence::linear_probing, 2) // oversized slot ) { constexpr size_type num_keys{1'000}; diff --git a/tests/static_map/robin_hood_test.cu b/tests/static_map/robin_hood_test.cu new file mode 100644 index 000000000..aea8786ab --- /dev/null +++ b/tests/static_map/robin_hood_test.cu @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2024-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +// Per-probe-step Robin Hood layout check. The unit is the *stride group* of `cg_size * bucket_size` +// contiguous slots that one probing step examines -- a single bucket for scalar probing, the whole +// cooperative-group window for CG probing. Within a stride group the slot order is free (the probe +// step distance is identical for every slot in it, so the intra-group offset cancels in all +// comparisons), so the invariant is only meaningful *between* groups. For each occupied group `g` +// (with predecessor `pg`), the resident probe-step distances ("ages") must satisfy: +// +// (1) Contiguity. If `g` holds any overflowed resident (distance >= 1), `pg` must be full -- +// otherwise that resident would have stopped in `pg`'s free slot instead of probing past it. +// (2) Balance. No resident of `pg` may be more than one probing step *richer* than the poorest +// resident of `g` (`min_age(pg) >= max_age(g) - 1`) -- otherwise the poorest resident of `g` +// should have displaced it. This is the property that distinguishes Robin Hood from plain +// linear probing, and (via condition 1) it inductively forces the whole home-to-position run +// to be full. +// +// `probe_distance` is reused here -- it is exercised independently by the probe-distance test +// below, so a bug in *insert* (a layout that violates the invariant) is still caught. +template +__global__ void robin_hood_invariant_kernel(Ref ref, int* violations) +{ + using size_type = typename Ref::size_type; + constexpr int bs = Ref::bucket_size; + constexpr int stride = Ref::cg_size * Ref::bucket_size; + auto const storage_ref = ref.storage_ref(); + auto const slots = storage_ref.data(); + auto const num_groups = storage_ref.capacity() / stride; + auto const extent = storage_ref.extent(); + auto const empty_key = ref.empty_key_sentinel(); + auto const erased_key = ref.erased_key_sentinel(); + auto const scheme = ref.probing_scheme(); + + for (size_type g = blockIdx.x * blockDim.x + threadIdx.x; g < num_groups; + g += gridDim.x * blockDim.x) { + int occupied_g = 0; + size_type max_age_g = 0; + for (int s = 0; s < stride; ++s) { + auto const slot = slots[g * stride + s]; + if (slot.first != empty_key) { // tombstones count as residents (erase enabled => != empty) + ++occupied_g; + // A tombstone keeps its age in its payload; a live key's age is its probe distance. + auto const age = (slot.first == erased_key) + ? static_cast(slot.second) + : cuco::detail::robin_hood::probe_distance( + scheme, slot.first, static_cast(g * stride + s), extent); + if (age > max_age_g) { max_age_g = age; } + } + } + if (occupied_g == 0) { continue; } + + size_type const pg = (g + num_groups - 1) % num_groups; + int occupied_p = 0; + size_type min_age_p = 0; + for (int s = 0; s < stride; ++s) { + auto const slot = slots[pg * stride + s]; + if (slot.first != empty_key) { + auto const age = (slot.first == erased_key) + ? static_cast(slot.second) + : cuco::detail::robin_hood::probe_distance( + scheme, slot.first, static_cast(pg * stride + s), extent); + if (occupied_p == 0 || age < min_age_p) { min_age_p = age; } + ++occupied_p; + } + } + + if (max_age_g >= 1 && occupied_p < stride) { atomicAdd(violations, 1); } // (1) + if (occupied_p > 0 && min_age_p + 1 < max_age_g) { atomicAdd(violations, 1); } // (2) + } +} + +// Asserts that a populated Robin Hood `map` satisfies the per-bucket layout invariant above. The +// `find` ref reaches the live storage pointer, the probing scheme, and the sentinels that the +// kernel needs. +template +void check_robin_hood_invariant(Map& map) +{ + auto const ref = map.ref(cuco::op::find); + + thrust::device_vector d_violations(1, 0); + auto constexpr block_size = 128; + auto const grid_size = (map.capacity() + block_size - 1) / block_size; + robin_hood_invariant_kernel<<>>( + ref, thrust::raw_pointer_cast(d_violations.data())); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + + REQUIRE(d_violations[0] == 0); +} + +// Walks each lane's probe iterator and records, at every step, the probe distance reported for the +// slot that lane is visiting. Because the resident under test is `key` itself, the slot visited at +// step `i` is at probe distance `i` from `key`'s home -- for every lane. Recording one column per +// lane lets the host check both that `probe_distance` inverts `make_iterator` and that it strips +// the per-lane intra-stride offset (so all lanes at a given step agree). +template +__global__ void generate_cg_probe_distance_sequence(Key key, + Extent upper_bound, + std::size_t seq_length, + OutputIt out_seq) +{ + auto constexpr cg_size = ProbingScheme::cg_size; + + auto const tid = blockIdx.x * blockDim.x + threadIdx.x; + auto probing_scheme = ProbingScheme{}; + + if (tid < cg_size) { + auto const tile = + cooperative_groups::tiled_partition( + cooperative_groups::this_thread_block()); + + auto iter = probing_scheme.template make_iterator(tile, key, upper_bound); + + for (std::size_t i = 0; i < seq_length; ++i) { + out_seq[i * cg_size + tile.thread_rank()] = + cuco::detail::robin_hood::probe_distance( + probing_scheme, key, *iter, upper_bound); + ++iter; + } + } +} + +} // namespace + +TEMPLATE_TEST_CASE_SIG( + "static_map robin_hood probe_distance inverts make_iterator", + "", + ((typename Key, int32_t CGSize, int32_t BucketSize), Key, CGSize, BucketSize), + (int32_t, 1, 1), + (int32_t, 4, 1), + (int32_t, 8, 1), + (int32_t, 8, 2), + (int64_t, 4, 1), + (int64_t, 8, 2)) +{ + // Robin Hood wraps a linear probe sequence; `probe_distance` is its inverse. For `key`'s own + // probe sequence, the slot visited at step `i` must report probe distance `i`. + using probe = cuco::linear_probing>; + + // A deliberately small capacity, so the probe sequence wraps around the table within + // `seq_length` steps -- this exercises the modular-subtraction (wrap) path in `probe_distance`. + auto const upper_bound = + cuco::make_valid_extent>(cuco::extent{64}); + + // Probe distance is measured in whole probing steps and lives in `[0, num_buckets)`, where one + // step spans the full `cg_size * bucket_size` stride. Taking `seq_length` past `num_buckets` + // guarantees the walk wraps at least once. + auto const capacity = static_cast(upper_bound); + auto const num_buckets = capacity / (CGSize * BucketSize); + auto const seq_length = num_buckets + 3; + constexpr Key key{42}; + + thrust::device_vector distances(seq_length * CGSize); + generate_cg_probe_distance_sequence + <<<1, CGSize>>>(key, upper_bound, seq_length, distances.begin()); + CUCO_CUDA_TRY(cudaDeviceSynchronize()); + + // Under wrap, the slot visited at step `i` sits at probe distance `i mod num_buckets`. + for (std::size_t i = 0; i < seq_length; ++i) { + for (std::int32_t r = 0; r < CGSize; ++r) { + REQUIRE(distances[i * CGSize + r] == i % num_buckets); + } + } +} + +TEMPLATE_TEST_CASE_SIG( + "static_map robin_hood high-load-factor invariant", + "", + ((typename Key, typename Value, int CGSize, int BucketSize), Key, Value, CGSize, BucketSize), + (int32_t, int32_t, 1, 1), + (int32_t, int32_t, 1, 2) +#if defined(CUCO_HAS_128BIT_ATOMICS) + , + (int64_t, int64_t, 1, 1), + (int64_t, int64_t, 1, 2) +#endif +) +{ + // Robin Hood is most meaningfully exercised when the table is nearly full: the displacement + // chains are long and the layout invariant becomes load-bearing. Size the table for ~95% load. + using size_type = std::int32_t; + + constexpr size_type num_keys = 100'000; + + using extent_type = cuco::extent; + using probe = cuco::linear_probing>; + using map_type = cuco::static_map, + probe, + cuco::cuda_allocator, + cuco::storage>; + + // High load factor: size the table for ~95% occupancy so Robin Hood is exercised near-full. + auto map = + map_type{extent_type{num_keys}, 0.95, cuco::empty_key{-1}, cuco::empty_value{-1}}; + + auto keys_begin = cuda::counting_iterator{0}; + auto pairs_begin = cuda::make_transform_iterator( + cuda::make_counting_iterator(0), + cuda::proclaim_return_type>( + [] __device__(auto i) { return cuco::pair{i, i}; })); + + map.insert(pairs_begin, pairs_begin + num_keys); + REQUIRE(map.size() == num_keys); + + // The hand-built Robin Hood layout must be structurally valid after a near-full insert. + check_robin_hood_invariant(map); + + // Every inserted unique key must be found and contained. + thrust::device_vector d_contained(num_keys); + map.contains(keys_begin, keys_begin + num_keys, d_contained.begin()); + REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), cuda::std::identity{})); + + thrust::device_vector d_values(num_keys); + map.find(keys_begin, keys_begin + num_keys, d_values.begin()); + auto zip = thrust::make_zip_iterator(cuda::std::tuple{d_values.begin(), keys_begin}); + REQUIRE(cuco::test::all_of( + zip, zip + num_keys, cuda::proclaim_return_type([] __device__(auto const& p) { + return cuda::std::get<0>(p) == cuda::std::get<1>(p); + }))); +} diff --git a/tests/static_map/shared_memory_test.cu b/tests/static_map/shared_memory_test.cu index 465ffbe55..a7fa749c5 100644 --- a/tests/static_map/shared_memory_test.cu +++ b/tests/static_map/shared_memory_test.cu @@ -70,13 +70,18 @@ __global__ void shared_memory_test_kernel(Ref* maps, TEMPLATE_TEST_CASE_SIG("static_map shared memory tests", "", ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t) -#if defined(CUCO_HAS_128BIT_ATOMICS) + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are + // kept commented for reference / easy re-enable: padded >8B slots + // (int32/int64, int64/int32) and >16B slots. 16B packable slots + // (int64/int64) need 128-bit atomics, so they sit under the #if. + (int32_t, int32_t) +// (int32_t, int64_t), +// (int64_t, int32_t), +#if defined(CUCO_HAS_128BIT_ATOMICS) // int64/int64 is a 16B slot -> single-CAS only with 128-bit + // atomics , - (__int128_t, __int128_t) + (int64_t, int64_t) +// (__int128_t, __int128_t) #endif ) { diff --git a/tests/static_map/stream_test.cu b/tests/static_map/stream_test.cu index 6c1701d98..99731493b 100644 --- a/tests/static_map/stream_test.cu +++ b/tests/static_map/stream_test.cu @@ -32,13 +32,18 @@ TEMPLATE_TEST_CASE_SIG("static_map: unique sequence of keys on given stream", "", ((typename Key, typename Value), Key, Value), - (int32_t, int32_t), - (int32_t, int64_t), - (int64_t, int32_t), - (int64_t, int64_t) -#if defined(CUCO_HAS_128BIT_ATOMICS) + // Robin Hood is linear-probing + single-CAS only. Unsupported variants are + // kept commented for reference / easy re-enable: padded >8B slots + // (int32/int64, int64/int32) and >16B slots. 16B packable slots + // (int64/int64) need 128-bit atomics, so they sit under the #if. + (int32_t, int32_t) +// (int32_t, int64_t), +// (int64_t, int32_t), +#if defined(CUCO_HAS_128BIT_ATOMICS) // int64/int64 is a 16B slot -> single-CAS only with 128-bit + // atomics , - (__int128_t, __int128_t) + (int64_t, int64_t) +// (__int128_t, __int128_t) #endif ) {