Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ if(NOT BUILD_CPU_ONLY)
"$<$<COMPILE_LANGUAGE:CUDA>:${CUVS_CUDA_FLAGS}>"
)
target_compile_features(jit_lto_kernel_usage_requirements INTERFACE cuda_std_20)
target_link_libraries(jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL)
target_link_libraries(
jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL cuco::cuco
)

block(PROPAGATE jit_lto_files)
set(jit_lto_files)
Expand Down
1 change: 1 addition & 0 deletions cpp/include/cuvs/detail/jit_lto/common_fragments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct tag_i8 {};
struct tag_u8 {};
struct tag_filter_none {};
struct tag_filter_bitset {};
struct tag_filter_bloom_filter {};
struct tag_filter_udf {};

struct tag_bitset_u32 {};
Expand Down
28 changes: 27 additions & 1 deletion cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ namespace filtering {
* @{
*/

enum class FilterType { None, Bitmap, Bitset, UDF };
enum class FilterType { None, Bitmap, Bitset, Bloom, UDF };

struct base_filter {
~base_filter() = default;
Expand Down Expand Up @@ -617,6 +617,32 @@ struct bitset_filter : public base_filter {
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
 * @brief Probabilistic CAGRA sample filter backed by one @c cuco bloom filter shared by all
 * queries.
 *
 * Typical usage: populate an owning @c cuco::bloom_filter on the host with the allowed dataset
 * row ids (bulk @c add()), take its @c ref(), copy that ref to device memory, and hand the
 * resulting device pointer to this struct as @c filter_data. The linked JIT-LTO fragment probes
 * that same filter for every query and candidate, much like @ref bitset_filter but with
 * probabilistic membership tests.
 *
 * A bloom filter never produces false negatives: a row id that was inserted always tests as
 * present. False positives are possible, so highly selective predicates may still need a bitset
 * or UDF when exact filtering is required.
 */
struct bloom_filter : public base_filter {
  /**
   * Opaque device pointer to the bloom-filter payload. It is forwarded unchanged to the CAGRA
   * sample-filter payload; the device-side probe casts it to
   * @c detail::bloom_filter_data_t<Key>* and treats @c nullptr as "accept every candidate".
   * NOTE(review): the probe defaults @c Key to @c std::uint32_t -- confirm the filter was built
   * with the same key type.
   */
  void* filter_data{nullptr};

  /**
   * Expected fraction of rows rejected by the filter; a negative value means "not specified".
   * CAGRA search consults it only when the search parameters carry a negative filtering rate:
   * a negative value here is then treated as 0, anything else is clamped to [0, 0.999].
   */
  float filtering_rate{-1.0f};

  /// Creates a filter with no payload attached and an unspecified filtering rate.
  bloom_filter() = default;

  /**
   * @param filter_data    device pointer to the bloom-filter payload (may be @c nullptr)
   * @param filtering_rate expected rejected fraction, or a negative value when unknown
   */
  explicit bloom_filter(void* filter_data, float filtering_rate = -1.0f)
    : filter_data{filter_data}, filtering_rate{filtering_rate}
  {
  }

  /// Identifies this filter as @c FilterType::Bloom.
  FilterType get_filter_type() const override { return FilterType::Bloom; }
};

/**
* @brief JIT-LTO user-defined filter predicate.
*
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,25 @@ void search(raft::resources const& res,
} catch (const std::bad_cast&) {
}

try {
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::bloom_filter&>(sample_filter_ref);
search_params params_copy = params;
if (params.filtering_rate < 0.0) {
const float min_filtering_rate = 0.0f;
const float max_filtering_rate = 0.999f;
params_copy.filtering_rate =
sample_filter.filtering_rate < 0.0f
? 0.0f
: std::min(std::max(sample_filter.filtering_rate, min_filtering_rate),
max_filtering_rate);
}
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, decltype(sample_filter_copy), OutputIdxT>(
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
} catch (const std::bad_cast&) {
}

try {
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::udf_filter&>(sample_filter_ref);
Expand Down
10 changes: 9 additions & 1 deletion cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ template <typename bitset_t, typename index_t>
struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter<bitset_t, index_t>>
: std::true_type {};

/// Compile-time trait used to dispatch on the bloom sample filter. The primary template covers
/// every type that is not the bloom filter and therefore reports @c false.
template <typename T>
struct is_bloom_filter : std::false_type {};

/// Specialization reporting @c true for @c cuvs::neighbors::filtering::bloom_filter itself.
template <>
struct is_bloom_filter<::cuvs::neighbors::filtering::bloom_filter> : std::true_type {};

template <typename T>
struct is_udf_filter : std::false_type {};

Expand Down Expand Up @@ -177,6 +183,8 @@ void fill_cagra_sample_filter(cagra_sample_filter<SourceIndexT>& out,
using DecayedFilter = std::decay_t<FilterT>;
if constexpr (is_bitset_filter<DecayedFilter>::value) {
out.filter_data = make_cagra_bitset_filter_payload<SourceIndexT>(filter, stream);
} else if constexpr (is_bloom_filter<DecayedFilter>::value) {
out.filter_data = filter.filter_data;
} else if constexpr (is_udf_filter<DecayedFilter>::value) {
out.filter_data = filter.filter_data;
}
Expand All @@ -199,7 +207,7 @@ template <typename FilterT>
void* cagra_filter_data_ptr(const FilterT& filter)
{
using DecayedFilter = std::decay_t<FilterT>;
if constexpr (is_udf_filter<DecayedFilter>::value) {
if constexpr (is_bloom_filter<DecayedFilter>::value || is_udf_filter<DecayedFilter>::value) {
return filter.filter_data;
} else if constexpr (requires { filter.filter; }) {
return cagra_filter_data_ptr(filter.filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#include "../../sample_filter_data.cuh"

#include <cuco/bloom_filter_ref.cuh>

#include <raft/core/bitset.cuh>

#include <cstdint>
Expand Down Expand Up @@ -38,4 +40,15 @@ __device__ bool sample_filter_bitset_impl(uint32_t /*query_id*/,
return view.test(node_id);
}

/**
 * @brief Device-side bloom-filter predicate for a single search candidate.
 *
 * @tparam SourceIndexT type of the candidate node id
 * @tparam Key          key type of the bloom filter; the node id is converted to it before probing
 *
 * @param node_id     candidate node id to test (the query id argument is unused)
 * @param filter_data pointer to a @c bloom_filter_data_t<Key>, or @c nullptr when no filter
 *                    payload is attached
 * @return @c true when @p filter_data is null, otherwise the result of the filter's
 *         @c contains() for the converted node id
 */
template <typename SourceIndexT, typename Key = std::uint32_t>
__device__ bool sample_filter_bloom_filter_impl(uint32_t /*query_id*/,
                                                SourceIndexT node_id,
                                                void* filter_data)
{
  if (filter_data != nullptr) {
    auto* payload       = static_cast<bloom_filter_data_t<Key>*>(filter_data);
    const Key probe_key = static_cast<Key>(node_id);
    return payload->filter.contains(probe_key);
  }

  // Without a payload there is nothing to probe, so the candidate is accepted.
  return true;
}

} // namespace cuvs::neighbors::detail
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"filter_name": ["none", "bitset"],
"filter_name": ["none", "bitset", "bloom_filter"],
"_bitset": [
{
"bitset_type": "uint32_t",
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace {
using data_t = @data_type@;
using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>;
using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::bloom_filter>;
using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::udf_filter>;

Expand All @@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t,
float,
cuvs::neighbors::filtering::none_sample_filter);
instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t);
instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t);
instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t);

} // namespace cuvs::neighbors::cagra::detail::multi_cta_search
3 changes: 3 additions & 0 deletions cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace {
using data_t = @data_type@;
using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>;
using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::bloom_filter>;
using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<
cuvs::neighbors::filtering::udf_filter>;

Expand All @@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t,
float,
cuvs::neighbors::filtering::none_sample_filter);
instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t);
instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t);
instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t);

} // namespace cuvs::neighbors::cagra::detail::single_cta_search
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter)
{
using DecayedFilter = std::decay_t<SampleFilterT>;
if constexpr (is_udf_filter<DecayedFilter>::value) {
return 3;
} else if constexpr (is_bloom_filter<DecayedFilter>::value) {
return 2;
} else if constexpr (is_bitset_filter<DecayedFilter>::value) {
return 1;
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ struct sample_filter_jit_tag {
using namespace cuvs::neighbors::filtering;
if constexpr (std::is_same_v<U, none_sample_filter>) {
return cuvs::neighbors::detail::tag_filter_none{};
} else if constexpr (is_bloom_filter<U>::value) {
return cuvs::neighbors::detail::tag_filter_bloom_filter{};
} else if constexpr (is_udf_filter<U>::value) {
return cuvs::neighbors::detail::tag_filter_udf{};
} else if constexpr (requires { std::declval<U>().filter; }) {
Expand All @@ -109,6 +111,8 @@ struct sample_filter_jit_tag {
std::is_same_v<std::decay_t<InnerFilter>,
bitset_filter<uint32_t, uint32_t>>) {
return cuvs::neighbors::detail::tag_filter_bitset{};
} else if constexpr (is_bloom_filter<std::decay_t<InnerFilter>>::value) {
return cuvs::neighbors::detail::tag_filter_bloom_filter{};
} else if constexpr (is_udf_filter<std::decay_t<InnerFilter>>::value) {
return cuvs::neighbors::detail::tag_filter_udf{};
} else {
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/neighbors/detail/sample_filter_data.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#pragma once

#include <cuco/bloom_filter.cuh>

#include <cstddef>
#include <cstdint>

Expand All @@ -21,4 +23,12 @@ struct bitset_filter_data_t {
SourceIndexT original_nbits{};
};

/**
 * @brief Payload wrapping the global cuco bloom filter ref that the linked @c sample_filter
 * probes in CAGRA JIT LTO.
 *
 * @tparam Key key type of the underlying @c cuco::bloom_filter
 */
template <typename Key = std::uint32_t>
struct bloom_filter_data_t {
  /// Ref type exposed by @c cuco::bloom_filter<Key>, taken with its default template arguments.
  using ref_type = typename cuco::bloom_filter<Key>::template ref_type<>;

  /// Value-initialized ref; the device-side predicate calls @c contains() on it.
  ref_type filter{};
};

} // namespace cuvs::neighbors::detail
10 changes: 10 additions & 0 deletions examples/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ find_package(Threads)
rapids_cpm_init()
set(BUILD_CUVS_C_LIBRARY OFF)
include(../cmake/thirdparty/get_cuvs.cmake)
include(${rapids-cmake-dir}/cpm/cuco.cmake)
rapids_cpm_cuco()

# -------------- compile tasks ----------------- #
add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu)
add_executable(CAGRA_EXAMPLE src/cagra_example.cu)
add_executable(CAGRA_FILTER_UDF_EXAMPLE src/cagra_filter_udf_example.cu)
add_executable(CAGRA_BLOOM_FILTER_EXAMPLE src/cagra_bloom_filter_example.cu)
add_executable(CAGRA_FILTER_BENCHMARK src/cagra_filter_benchmark.cu)
add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu)
add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu)
add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu)
Expand All @@ -48,6 +52,12 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:c
target_link_libraries(
CAGRA_FILTER_UDF_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:conda_env>
)
target_link_libraries(
CAGRA_BLOOM_FILTER_EXAMPLE PRIVATE cuvs::cuvs cuco::cuco $<TARGET_NAME_IF_EXISTS:conda_env>
)
target_link_libraries(
CAGRA_FILTER_BENCHMARK PRIVATE cuvs::cuvs cuco::cuco $<TARGET_NAME_IF_EXISTS:conda_env>
)
target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:conda_env>)
target_link_libraries(
CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $<TARGET_NAME_IF_EXISTS:conda_env> Threads::Threads
Expand Down
Loading
Loading