diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..b3f6e23ca9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -367,7 +367,9 @@ if(NOT BUILD_CPU_ONLY) "$<$:${CUVS_CUDA_FLAGS}>" ) target_compile_features(jit_lto_kernel_usage_requirements INTERFACE cuda_std_20) - target_link_libraries(jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL) + target_link_libraries( + jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL cuco::cuco + ) block(PROPAGATE jit_lto_files) set(jit_lto_files) diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index ef2a8e6002..3ae94a4b13 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -14,6 +14,7 @@ struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; struct tag_filter_bitset {}; +struct tag_filter_bloom_filter {}; struct tag_filter_udf {}; struct tag_bitset_u32 {}; diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 2fd804f115..da9b812f60 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -497,7 +497,7 @@ namespace filtering { * @{ */ -enum class FilterType { None, Bitmap, Bitset, UDF }; +enum class FilterType { None, Bitmap, Bitset, Bloom, UDF }; struct base_filter { ~base_filter() = default; @@ -617,6 +617,32 @@ struct bitset_filter : public base_filter { void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; +/** + * @brief Filter CAGRA candidates with a global @c cuco bloom filter over the index. + * + * Build the filter once on the host with bulk @c add() over the allowed dataset row ids, obtain a + * @c ref() from the owning @c cuco::bloom_filter, copy that ref to device memory, and pass the + * device pointer as @c filter_data. The linked JIT-LTO fragment probes the same filter for every + * query and candidate, similar to @ref bitset_filter but with probabilistic membership tests. + * + * Bloom filters have no false negatives: if a row was inserted, @c contains returns @c true. False + * positives are possible, so highly selective predicates may still need a bitset or UDF for exact + * filtering. + */ +struct bloom_filter : public base_filter { + void* filter_data{nullptr}; + float filtering_rate{-1.0f}; + + bloom_filter() = default; + + explicit bloom_filter(void* filter_data, float filtering_rate = -1.0f) + : filter_data(filter_data), filtering_rate(filtering_rate) + { + } + + FilterType get_filter_type() const override { return FilterType::Bloom; } +}; + /** * @brief JIT-LTO user-defined filter predicate. * diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index ee87c2c0ab..af8331a9f1 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -385,6 +385,25 @@ void search(raft::resources const& res, } catch (const std::bad_cast&) { } + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + search_params params_copy = params; + if (params.filtering_rate < 0.0) { + const float min_filtering_rate = 0.0f; + const float max_filtering_rate = 0.999f; + params_copy.filtering_rate = + sample_filter.filtering_rate < 0.0f + ? 0.0f + : std::min(std::max(sample_filter.filtering_rate, min_filtering_rate), + max_filtering_rate); + } + auto sample_filter_copy = sample_filter; + return search_with_filtering( + res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); + } catch (const std::bad_cast&) { + } + try { auto& sample_filter = dynamic_cast(sample_filter_ref); diff --git a/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp index 88f30f7745..aa927215ef 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp +++ b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp @@ -139,6 +139,12 @@ template struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter> : std::true_type {}; +template +struct is_bloom_filter : std::false_type {}; + +template <> +struct is_bloom_filter<::cuvs::neighbors::filtering::bloom_filter> : std::true_type {}; + template struct is_udf_filter : std::false_type {}; @@ -177,6 +183,8 @@ void fill_cagra_sample_filter(cagra_sample_filter& out, using DecayedFilter = std::decay_t; if constexpr (is_bitset_filter::value) { out.filter_data = make_cagra_bitset_filter_payload(filter, stream); + } else if constexpr (is_bloom_filter::value) { + out.filter_data = filter.filter_data; } else if constexpr (is_udf_filter::value) { out.filter_data = filter.filter_data; } @@ -199,7 +207,7 @@ template void* cagra_filter_data_ptr(const FilterT& filter) { using DecayedFilter = std::decay_t; - if constexpr (is_udf_filter::value) { + if constexpr (is_bloom_filter::value || is_udf_filter::value) { return filter.filter_data; } else if constexpr (requires { filter.filter; }) { return cagra_filter_data_ptr(filter.filter); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh index d01f58166d..1b3b3825f2 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh @@ -9,6 +9,8 @@ #include "../../sample_filter_data.cuh" +#include + #include #include @@ -38,4 +40,15 @@ __device__ bool sample_filter_bitset_impl(uint32_t /*query_id*/, return view.test(node_id); } +template +__device__ bool sample_filter_bloom_filter_impl(uint32_t /*query_id*/, + SourceIndexT node_id, + void* filter_data) +{ + if (filter_data == nullptr) { return true; } + + auto* data = static_cast*>(filter_data); + return data->filter.contains(static_cast(node_id)); +} + } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json index 0136587b48..b58f56ceb6 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json @@ -1,5 +1,5 @@ { - "filter_name": ["none", "bitset"], + "filter_name": ["none", "bitset", "bloom_filter"], "_bitset": [ { "bitset_type": "uint32_t", diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in index 7c642fe406..e60af91046 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bloom_filter>; using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::udf_filter>; @@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t); instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in index 4616a9652b..c9eac33d44 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using bloom_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bloom_filter>; using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::udf_filter>; @@ -22,6 +24,7 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, bloom_filter_t); instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index a8731d06b5..28d660e2e1 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -80,6 +80,8 @@ std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter) { using DecayedFilter = std::decay_t; if constexpr (is_udf_filter::value) { + return 3; + } else if constexpr (is_bloom_filter::value) { return 2; } else if constexpr (is_bitset_filter::value) { return 1; diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index e5157ffa6a..dd118757e4 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -100,6 +100,8 @@ struct sample_filter_jit_tag { using namespace cuvs::neighbors::filtering; if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (is_bloom_filter::value) { + return cuvs::neighbors::detail::tag_filter_bloom_filter{}; } else if constexpr (is_udf_filter::value) { return cuvs::neighbors::detail::tag_filter_udf{}; } else if constexpr (requires { std::declval().filter; }) { @@ -109,6 +111,8 @@ struct sample_filter_jit_tag { std::is_same_v, bitset_filter>) { return cuvs::neighbors::detail::tag_filter_bitset{}; + } else if constexpr (is_bloom_filter>::value) { + return cuvs::neighbors::detail::tag_filter_bloom_filter{}; } else if constexpr (is_udf_filter>::value) { return cuvs::neighbors::detail::tag_filter_udf{}; } else { diff --git a/cpp/src/neighbors/detail/sample_filter_data.cuh b/cpp/src/neighbors/detail/sample_filter_data.cuh index 4c99ca1e3a..3f2e412681 100644 --- a/cpp/src/neighbors/detail/sample_filter_data.cuh +++ b/cpp/src/neighbors/detail/sample_filter_data.cuh @@ -5,6 +5,8 @@ #pragma once +#include + #include #include @@ -21,4 +23,12 @@ struct bitset_filter_data_t { SourceIndexT original_nbits{}; }; +/// Global cuco bloom filter ref for linked @c sample_filter in CAGRA JIT LTO. +template +struct bloom_filter_data_t { + using ref_type = typename cuco::bloom_filter::ref_type<>; + + ref_type filter{}; +}; + } // namespace cuvs::neighbors::detail diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index d63ddbdb71..07573d40dd 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -27,11 +27,15 @@ find_package(Threads) rapids_cpm_init() set(BUILD_CUVS_C_LIBRARY OFF) include(../cmake/thirdparty/get_cuvs.cmake) +include(${rapids-cmake-dir}/cpm/cuco.cmake) +rapids_cpm_cuco() # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) add_executable(CAGRA_FILTER_UDF_EXAMPLE src/cagra_filter_udf_example.cu) +add_executable(CAGRA_BLOOM_FILTER_EXAMPLE src/cagra_bloom_filter_example.cu) +add_executable(CAGRA_FILTER_BENCHMARK src/cagra_filter_benchmark.cu) add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu) add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu) @@ -48,6 +52,12 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $ ) +target_link_libraries( + CAGRA_BLOOM_FILTER_EXAMPLE PRIVATE cuvs::cuvs cuco::cuco $ +) +target_link_libraries( + CAGRA_FILTER_BENCHMARK PRIVATE cuvs::cuvs cuco::cuco $ +) target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $ Threads::Threads diff --git a/examples/cpp/plot_filter_benchmark.py b/examples/cpp/plot_filter_benchmark.py new file mode 100755 index 0000000000..8840c5e9fc --- /dev/null +++ b/examples/cpp/plot_filter_benchmark.py @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""Plot CAGRA filter benchmark CSV (bitset vs bloom_filter line charts).""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +FILTER_LABELS = { + "bitset": "Bitset", + "bloom_filter": "Bloom", +} + +FILTER_ORDER = ["Bitset", "Bloom"] +FILTER_COLORS = {"Bitset": "#0173b2", "Bloom": "#de8f05"} +FILTER_MARKERS = {"Bitset": "o", "Bloom": "s"} + +SEARCH_LABELS = { + 10000: "10k queries", + 25000: "25k queries", +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate line charts from cagra_filter_benchmark CSV output.", + ) + parser.add_argument( + "csv", + nargs="?", + default="filter_bench.csv", + help="Path to benchmark CSV (default: filter_bench.csv)", + ) + parser.add_argument( + "-o", + "--output-dir", + default="filter_bench_plots", + help="Directory for PNG output (default: filter_bench_plots)", + ) + parser.add_argument( + "--with-recall", + action="store_true", + help="Also plot recall charts when non-NaN values are present", + ) + return parser.parse_args() + + +def load_csv(path: Path) -> pd.DataFrame: + df = pd.read_csv(path) + df["recall"] = pd.to_numeric(df["recall"], errors="coerce") + df["filter_label"] = ( + df["filter_type"].map(FILTER_LABELS).fillna(df["filter_type"]) + ) + df["search_label"] = ( + df["search_n_rows"] + .map(SEARCH_LABELS) + .fillna(df["search_n_rows"].astype(str) + " queries") + ) + return df + + +def save_figure(fig: plt.Figure, path: Path) -> None: + fig.subplots_adjust(left=0.11, bottom=0.08, right=0.98, top=0.94) + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_panel_bars( + ax: plt.Axes, + panel: pd.DataFrame, + x_col: str, + y_col: str, +) -> list: + """Grouped bars for two filters at each x value — visible even when latencies are close.""" + x_values = sorted(panel[x_col].unique()) + x_idx = np.arange(len(x_values)) + bar_width = 0.36 + handles = [] + + for series_idx, label in enumerate(FILTER_ORDER): + heights = [] + for x_val in x_values: + row = panel[ + (panel["filter_label"] == label) & (panel[x_col] == x_val) + ] + heights.append(row[y_col].iloc[0] if not row.empty else 0.0) + offset = (series_idx - 0.5) * bar_width + bars = ax.bar( + x_idx + offset, + heights, + width=bar_width, + label=label, + color=FILTER_COLORS[label], + edgecolor="black", + linewidth=0.6, + alpha=0.9, + ) + handles.append(bars[0]) + + ax.set_xticks(x_idx) + if x_col == "build_n_rows": + ax.set_xticklabels([f"{int(v):,}" for v in x_values]) + else: + ax.set_xticklabels( + [ + str(int(v)) if float(v).is_integer() else str(v) + for v in x_values + ] + ) + return handles + + +def plot_panel_lines( + ax: plt.Axes, + panel: pd.DataFrame, + x_col: str, + y_col: str, + log_x: bool = False, +) -> list: + """Draw bitset vs bloom as two explicit series (no seaborn hue/style mashup).""" + handles = [] + x_values = sorted(panel[x_col].unique()) + if log_x: + ax.set_xscale("log") + ax.set_xticks(x_values) + ax.set_xticklabels([f"{int(v):,}" for v in x_values]) + ax.minorticks_off() + + for label in FILTER_ORDER: + series = panel[panel["filter_label"] == label].sort_values(x_col) + if series.empty: + continue + (line,) = ax.plot( + series[x_col], + series[y_col], + label=label, + color=FILTER_COLORS[label], + marker=FILTER_MARKERS[label], + markersize=7, + linewidth=2.0, + markerfacecolor="white", + markeredgewidth=1.5, + ) + handles.append(line) + + return handles + + +def add_row_band_label( + fig: plt.Figure, axes: np.ndarray, row_idx: int, text: str +) -> None: + ax = axes[row_idx, 0] + pos = ax.get_position() + fig.text( + 0.02, + (pos.y0 + pos.y1) / 2, + text, + ha="left", + va="center", + rotation=90, + fontsize=10, + fontweight="bold", + ) + + +def plot_latency_grid_for_valid_pct( + df: pd.DataFrame, + out_dir: Path, + valid_pct: int, +) -> None: + """One PNG per valid_pct: 6×3 grid. + + Row bands (top to bottom): 10k queries × dims 128/512/1024, then 25k × dims. + Columns: k = 64 / 256 / 1024. + Each panel has only two lines (bitset vs bloom) over index size. + """ + sub = df[df["valid_pct"] == valid_pct].copy() + if sub.empty: + print(f"valid_pct={valid_pct}%: no rows, skipping") + return + + k_values = sorted(sub["k"].unique()) + col_values = sorted(sub["build_n_cols"].unique()) + search_values = sorted(sub["search_n_rows"].unique()) + + row_specs = [ + (search, dims) for search in search_values for dims in col_values + ] + n_rows = len(row_specs) + n_cols = len(k_values) + + fig, axes = plt.subplots( + n_rows, + n_cols, + figsize=(4.0 * n_cols + 0.8, 2.4 * n_rows), + squeeze=False, + ) + + legend_handles = None + for row_idx, (search_n_rows, build_n_cols) in enumerate(row_specs): + for col_idx, k in enumerate(k_values): + ax = axes[row_idx, col_idx] + panel = sub[ + (sub["build_n_cols"] == build_n_cols) + & (sub["k"] == k) + & (sub["search_n_rows"] == search_n_rows) + ] + if panel.empty: + ax.set_visible(False) + continue + + handles = plot_panel_bars( + ax, + panel, + x_col="build_n_rows", + y_col="avg_latency_per_query_ms", + ) + if legend_handles is None and handles: + legend_handles = handles + + ax.set_title(f"k={k}", fontsize=10) + ax.set_xlabel("index rows") + if col_idx == 0: + ax.set_ylabel(f"dims={build_n_cols}\nlatency / query (ms)") + else: + ax.set_ylabel("") + + for search_idx, search_n_rows in enumerate(search_values): + band_row = search_idx * len(col_values) + add_row_band_label( + fig, + axes, + band_row, + SEARCH_LABELS.get(search_n_rows, f"{search_n_rows:,} queries"), + ) + + if legend_handles: + fig.legend( + legend_handles, + FILTER_ORDER, + loc="lower center", + bbox_to_anchor=(0.5, -0.02), + ncol=2, + frameon=False, + fontsize=11, + ) + + fig.suptitle( + f"Per-query search latency — {valid_pct}% rows valid", + fontsize=13, + y=1.01, + ) + path = out_dir / f"latency_valid_{valid_pct}pct.png" + save_figure(fig, path) + print(f"wrote {path}") + + +def plot_all_valid_pct_grids(df: pd.DataFrame, out_dir: Path) -> None: + for valid_pct in sorted(df["valid_pct"].unique()): + plot_latency_grid_for_valid_pct(df, out_dir, int(valid_pct)) + + +def plot_valid_pct_overview( + df: pd.DataFrame, + out_dir: Path, + build_n_rows: int, + search_n_rows: int, +) -> None: + """valid_pct on x-axis at one build/search point; 3×3 grid (dims × k).""" + sub = df[ + (df["build_n_rows"] == build_n_rows) + & (df["search_n_rows"] == search_n_rows) + ].copy() + if sub.empty: + print( + "valid_pct overview: no rows for selected build/search slice, skipping" + ) + return + + k_values = sorted(sub["k"].unique()) + col_values = sorted(sub["build_n_cols"].unique()) + + fig, axes = plt.subplots( + len(col_values), + len(k_values), + figsize=(4.0 * len(k_values), 2.8 * len(col_values)), + squeeze=False, + ) + + legend_handles = None + for row_idx, build_n_cols in enumerate(col_values): + for col_idx, k in enumerate(k_values): + ax = axes[row_idx, col_idx] + panel = sub[ + (sub["build_n_cols"] == build_n_cols) & (sub["k"] == k) + ] + if panel.empty: + ax.set_visible(False) + continue + + handles = plot_panel_bars( + ax, + panel, + x_col="valid_pct", + y_col="avg_latency_per_query_ms", + ) + if legend_handles is None and handles: + legend_handles = handles + + ax.set_title(f"k={k}", fontsize=10) + ax.set_xlabel("valid rows (%)") + if col_idx == 0: + ax.set_ylabel(f"dims={build_n_cols}\nlatency / query (ms)") + else: + ax.set_ylabel("") + + if legend_handles: + fig.legend( + legend_handles, + FILTER_ORDER, + loc="lower center", + bbox_to_anchor=(0.5, -0.02), + ncol=2, + frameon=False, + fontsize=11, + ) + + fig.suptitle( + f"Latency vs filter selectivity — " + f"build={build_n_rows:,} rows search={search_n_rows:,} queries", + fontsize=12, + y=1.01, + ) + path = out_dir / "overview_valid_pct_sweep.png" + save_figure(fig, path) + print(f"wrote {path}") + + +def plot_recall_grid_for_valid_pct( + df: pd.DataFrame, + out_dir: Path, + valid_pct: int, +) -> None: + sub = df[(df["valid_pct"] == valid_pct) & df["recall"].notna()].copy() + if sub.empty: + return + + k_values = sorted(sub["k"].unique()) + col_values = sorted(sub["build_n_cols"].unique()) + search_values = sorted(sub["search_n_rows"].unique()) + row_specs = [ + (search, dims) for search in search_values for dims in col_values + ] + + fig, axes = plt.subplots( + len(row_specs), + len(k_values), + figsize=(4.0 * len(k_values) + 0.8, 2.4 * len(row_specs)), + squeeze=False, + ) + + legend_handles = None + for row_idx, (search_n_rows, build_n_cols) in enumerate(row_specs): + for col_idx, k in enumerate(k_values): + ax = axes[row_idx, col_idx] + panel = sub[ + (sub["build_n_cols"] == build_n_cols) + & (sub["k"] == k) + & (sub["search_n_rows"] == search_n_rows) + ] + if panel.empty: + ax.set_visible(False) + continue + + handles = plot_panel_bars( + ax, + panel, + x_col="build_n_rows", + y_col="recall", + ) + if legend_handles is None and handles: + legend_handles = handles + + ax.set_ylim(0.0, 1.0) + ax.set_title(f"k={k}", fontsize=10) + ax.set_xlabel("index rows") + if col_idx == 0: + ax.set_ylabel(f"dims={build_n_cols}\nrecall@k") + else: + ax.set_ylabel("") + + for search_idx, search_n_rows in enumerate(search_values): + band_row = search_idx * len(col_values) + add_row_band_label( + fig, + axes, + band_row, + SEARCH_LABELS.get(search_n_rows, f"{search_n_rows:,} queries"), + ) + + if legend_handles: + fig.legend( + legend_handles, + FILTER_ORDER, + loc="lower center", + bbox_to_anchor=(0.5, -0.02), + ncol=2, + frameon=False, + fontsize=11, + ) + + fig.suptitle(f"Recall@k — {valid_pct}% rows valid", fontsize=13, y=1.01) + path = out_dir / f"recall_valid_{valid_pct}pct.png" + save_figure(fig, path) + print(f"wrote {path}") + + +def main() -> None: + args = parse_args() + csv_path = Path(args.csv) + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + if not csv_path.is_file(): + raise SystemExit(f"CSV not found: {csv_path}") + + df = load_csv(csv_path) + print(f"loaded {len(df)} rows from {csv_path}") + + sns.set_theme(style="whitegrid", context="notebook") + + plot_all_valid_pct_grids(df, out_dir) + + default_build = int(df["build_n_rows"].min()) + default_search = int(df["search_n_rows"].min()) + plot_valid_pct_overview(df, out_dir, default_build, default_search) + + if args.with_recall and df["recall"].notna().any(): + for valid_pct in sorted(df["valid_pct"].unique()): + plot_recall_grid_for_valid_pct(df, out_dir, int(valid_pct)) + elif args.with_recall: + print( + "recall charts skipped: CSV has no recall values (run benchmark with --ground-truth)" + ) + else: + print("recall charts skipped (default; pass --with-recall to enable)") + + +if __name__ == "__main__": + main() diff --git a/examples/cpp/src/cagra_bloom_filter_example.cu b/examples/cpp/src/cagra_bloom_filter_example.cu new file mode 100644 index 0000000000..7d16989bbb --- /dev/null +++ b/examples/cpp/src/cagra_bloom_filter_example.cu @@ -0,0 +1,171 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace { + +constexpr int64_t n_rows = 4096; +constexpr int64_t n_dim = 32; +constexpr int64_t n_queries = 4; +constexpr int64_t k = 8; +constexpr int sub_filters = 256; + +using key_type = std::uint32_t; +using filter_type = cuco::bloom_filter; +using ref_type = filter_type::ref_type<>; + +// Layout must match cuvs::neighbors::detail::bloom_filter_data_t in the JIT fragment. +struct bloom_payload { + ref_type filter; +}; + +// Global index filter: even row ids are valid candidates (same rule for every query). +bool is_valid_row(key_type source_id) { return (source_id % 2) == 0; } + +std::vector copy_neighbors_to_host( + raft::device_resources const& res, + raft::device_matrix_view neighbors) +{ + std::vector host(neighbors.size()); + raft::copy( + host.data(), neighbors.data_handle(), host.size(), raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + return host; +} + +} // namespace + +int main() +{ + raft::device_resources res; + auto stream = raft::resource::get_cuda_stream(res); + + rmm::mr::pool_memory_resource pool_mr(rmm::mr::get_current_device_resource_ref(), + 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(pool_mr); + + auto dataset = raft::make_device_matrix(res, n_rows, n_dim); + auto queries = raft::make_device_matrix(res, n_queries, n_dim); + + raft::random::RngState rng(1234ULL); + raft::random::uniform(res, rng, dataset.data_handle(), dataset.size(), -1.0f, 1.0f); + raft::random::uniform(res, rng, queries.data_handle(), queries.size(), -1.0f, 1.0f); + + cuvs::neighbors::cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; + index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params( + index_params.intermediate_graph_degree); + + std::cout << "Building CAGRA index" << std::endl; + auto index = + cuvs::neighbors::cagra::build(res, index_params, raft::make_const_mdspan(dataset.view())); + + // Build one global bloom filter over the index: bulk-insert every valid row id once. + std::vector valid_ids_host; + valid_ids_host.reserve(static_cast(n_rows / 2)); + for (int64_t i = 0; i < n_rows; ++i) { + if (is_valid_row(static_cast(i))) { + valid_ids_host.push_back(static_cast(i)); + } + } + + rmm::device_uvector valid_ids_device(valid_ids_host.size(), stream); + raft::copy(valid_ids_device.data(), valid_ids_host.data(), valid_ids_host.size(), stream); + + filter_type allowed_rows{sub_filters}; + allowed_rows.add_async( + valid_ids_device.data(), valid_ids_device.data() + valid_ids_device.size(), stream); + raft::resource::sync_stream(res); + + std::cout << "Inserted " << valid_ids_host.size() + << " valid row ids into global bloom filter via bulk add_async" << std::endl; + + // Copy the owning filter's device ref into a payload the JIT fragment can probe. + auto payload_device = raft::make_device_vector(res, 1); + bloom_payload host_payload{allowed_rows.ref()}; + raft::copy(payload_device.data_handle(), &host_payload, 1, stream); + raft::resource::sync_stream(res); + + auto neighbors = raft::make_device_matrix(res, n_queries, k); + auto distances = raft::make_device_matrix(res, n_queries, k); + + cuvs::neighbors::cagra::search_params search_params; + search_params.algo = cuvs::neighbors::cagra::search_algo::MULTI_CTA; + search_params.itopk_size = 128; + search_params.max_queries = n_queries; + search_params.thread_block_size = 256; + + // ~50% of rows are rejected by the global even-id predicate. + auto filter = cuvs::neighbors::filtering::bloom_filter(payload_device.data_handle(), 0.5f); + + cuvs::neighbors::cagra::search(res, + search_params, + index, + raft::make_const_mdspan(queries.view()), + neighbors.view(), + distances.view(), + filter); + + auto host_neighbors = copy_neighbors_to_host(res, neighbors.view()); + + std::cout << "bloom_filter first query neighbors:"; + for (int64_t i = 0; i < k; ++i) { + std::cout << " " << host_neighbors[static_cast(i)]; + } + std::cout << std::endl; + + // Validate with cuco's bulk contains API over the returned neighbors. + rmm::device_uvector neighbor_ids_device(host_neighbors.size(), stream); + rmm::device_uvector bloom_hits_device(host_neighbors.size(), stream); + raft::copy(neighbor_ids_device.data(), host_neighbors.data(), host_neighbors.size(), stream); + allowed_rows.contains_async(neighbor_ids_device.data(), + neighbor_ids_device.data() + neighbor_ids_device.size(), + bloom_hits_device.data(), + stream); + raft::resource::sync_stream(res); + + std::vector bloom_hits_host(bloom_hits_device.size()); + raft::copy(bloom_hits_host.data(), bloom_hits_device.data(), bloom_hits_host.size(), stream); + raft::resource::sync_stream(res); + + for (size_t i = 0; i < host_neighbors.size(); ++i) { + auto source_id = host_neighbors[i]; + if (source_id >= static_cast(n_rows)) { + std::cerr << "bloom_filter produced out-of-range source_id=" << source_id << std::endl; + return 1; + } + if (bloom_hits_host[i] == 0) { + std::cerr << "bloom_filter rejected source_id=" << source_id + << " but global bloom filter bulk contains says absent" << std::endl; + return 1; + } + if (!is_valid_row(source_id)) { + std::cerr << "bloom_filter allowed invalid source_id=" << source_id + << " (unexpected bloom false positive)" << std::endl; + return 1; + } + } + + std::cout << "CAGRA bloom filter example produced valid filtered neighbors." << std::endl; + return 0; +} diff --git a/examples/cpp/src/cagra_filter_benchmark.cu b/examples/cpp/src/cagra_filter_benchmark.cu new file mode 100644 index 0000000000..2c72c22ddd --- /dev/null +++ b/examples/cpp/src/cagra_filter_benchmark.cu @@ -0,0 +1,537 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr int k_warmup_runs = 1; +constexpr int k_timed_runs = 3; + +using key_type = std::uint32_t; +using filter_type = cuco::bloom_filter; +using ref_type = filter_type::ref_type<>; + +struct bloom_payload { + ref_type filter; +}; + +constexpr std::array k_build_rows{100'000, 500'000, 1'000'000}; +constexpr std::array k_build_cols{128, 512, 1024}; +constexpr std::array k_search_rows{10'000, 25'000}; +constexpr std::array k_search_cols{128, 512, 1024}; +constexpr std::array k_valid_pcts{10, 50, 90}; +constexpr std::array k_values{64, 256, 1024}; + +bool is_valid_row(key_type row_id, int valid_pct) +{ + // Deterministic ~valid_pct% membership independent of dataset size. + return (static_cast(row_id) * 2654435761ULL) % 100ULL < + static_cast(valid_pct); +} + +std::size_t bloom_num_blocks(std::size_t num_valid_rows) +{ + // Scale filter size with the number of inserted keys; keep a reasonable minimum. + std::size_t blocks = std::max(256, num_valid_rows / 8); + blocks = std::min(blocks, static_cast(1 << 20)); + return blocks; +} + +double compute_recall(std::vector const& expected, + std::vector const& actual, + int64_t n_queries, + int64_t k, + int64_t n_rows) +{ + std::size_t match_count = 0; + std::size_t total_count = static_cast(n_queries) * static_cast(k); + for (int64_t q = 0; q < n_queries; ++q) { + for (int64_t ki = 0; ki < k; ++ki) { + auto const act = actual[static_cast(q * k + ki)]; + if (act >= static_cast(n_rows)) { continue; } + for (int64_t kj = 0; kj < k; ++kj) { + if (expected[static_cast(q * k + kj)] == act) { + ++match_count; + break; + } + } + } + } + return total_count == 0 ? 0.0 + : static_cast(match_count) / static_cast(total_count); +} + +std::vector copy_neighbors_to_host(raft::device_resources const& res, + raft::device_matrix_view neighbors) +{ + std::vector host(neighbors.size()); + auto stream = raft::resource::get_cuda_stream(res); + raft::copy(host.data(), neighbors.data_handle(), host.size(), stream); + raft::resource::sync_stream(res); + return host; +} + +struct filter_assets { + cuvs::core::bitset removed_bitset; + cuvs::neighbors::filtering::bitset_filter bitset_filter; + filter_type bloom; + rmm::device_uvector bloom_payload; + cuvs::neighbors::filtering::bloom_filter bloom_filter; + float filtering_rate{0.0f}; +}; + +filter_assets make_filters(raft::device_resources const& res, + int64_t n_rows, + int valid_pct, + rmm::cuda_stream_view stream) +{ + std::vector valid_ids_host; + std::vector removed_ids_host; + valid_ids_host.reserve(static_cast(n_rows)); + removed_ids_host.reserve(static_cast(n_rows)); + + for (int64_t i = 0; i < n_rows; ++i) { + auto const row = static_cast(i); + if (is_valid_row(row, valid_pct)) { + valid_ids_host.push_back(row); + } else { + removed_ids_host.push_back(i); + } + } + + auto removed_ids = + raft::make_device_vector(res, static_cast(removed_ids_host.size())); + if (!removed_ids_host.empty()) { + raft::copy(removed_ids.data_handle(), removed_ids_host.data(), removed_ids_host.size(), stream); + } + + auto removed_bitset = cuvs::core::bitset(res, removed_ids.view(), n_rows); + auto bitset_filter = + cuvs::neighbors::filtering::bitset_filter(removed_bitset.view()); + auto bloom = filter_type{bloom_num_blocks(valid_ids_host.size()), {}, {}, {}, stream}; + auto payload_device = rmm::device_uvector{1, stream}; + float const filtering_rate = static_cast(100 - valid_pct) / 100.0f; + + if (!valid_ids_host.empty()) { + rmm::device_uvector valid_ids_device(valid_ids_host.size(), stream); + raft::copy(valid_ids_device.data(), valid_ids_host.data(), valid_ids_host.size(), stream); + bloom.add_async( + valid_ids_device.data(), valid_ids_device.data() + valid_ids_device.size(), stream); + } + + bloom_payload host_payload{bloom.ref()}; + raft::copy(payload_device.data(), &host_payload, 1, stream); + auto bloom_filter_obj = + cuvs::neighbors::filtering::bloom_filter(payload_device.data(), filtering_rate); + + raft::resource::sync_stream(res); + return filter_assets{std::move(removed_bitset), + std::move(bitset_filter), + std::move(bloom), + std::move(payload_device), + std::move(bloom_filter_obj), + filtering_rate}; +} + +struct benchmark_case { + int64_t build_n_rows; + int64_t build_n_cols; + int64_t search_n_rows; + int64_t search_n_cols; + int valid_pct; + int64_t k; +}; + +struct csv_row { + benchmark_case config; + std::string filter_name; + double build_time_ms; + double avg_search_latency_ms; + double avg_latency_per_query_ms; + double recall; +}; + +void write_csv_header(std::ostream& os) +{ + os << "build_n_rows,build_n_cols,search_n_rows,search_n_cols,valid_pct,filter_type," + "build_time_ms,avg_search_latency_ms,avg_latency_per_query_ms,recall,k,warmup_runs," + "timed_runs\n"; +} + +void write_csv_row(std::ostream& os, csv_row const& row) +{ + os << row.config.build_n_rows << ',' << row.config.build_n_cols << ',' << row.config.search_n_rows + << ',' << row.config.search_n_cols << ',' << row.config.valid_pct << ',' << row.filter_name + << ',' << row.build_time_ms << ',' << row.avg_search_latency_ms << ',' + << row.avg_latency_per_query_ms << ',' << row.recall << ',' << row.config.k << ',' + << k_warmup_runs << ',' << k_timed_runs << '\n'; +} + +template +double time_cuda_ms(raft::device_resources const& res, int runs, Fn&& fn) +{ + auto stream = raft::resource::get_cuda_stream(res); + cudaEvent_t start{}; + cudaEvent_t stop{}; + RAFT_CUDA_TRY(cudaEventCreate(&start)); + RAFT_CUDA_TRY(cudaEventCreate(&stop)); + + RAFT_CUDA_TRY(cudaEventRecord(start, stream)); + for (int i = 0; i < runs; ++i) { + fn(); + } + RAFT_CUDA_TRY(cudaEventRecord(stop, stream)); + RAFT_CUDA_TRY(cudaEventSynchronize(stop)); + + float elapsed_ms = 0.0f; + RAFT_CUDA_TRY(cudaEventElapsedTime(&elapsed_ms, start, stop)); + RAFT_CUDA_TRY(cudaEventDestroy(start)); + RAFT_CUDA_TRY(cudaEventDestroy(stop)); + return static_cast(elapsed_ms) / static_cast(runs); +} + +void append_cases(std::vector& cases, + std::vector const& build_rows, + std::vector const& build_cols, + std::vector const& search_rows, + std::vector const& search_cols, + std::vector const& valid_pcts, + std::vector const& k_sweep) +{ + for (auto build_n_rows : build_rows) { + for (auto build_n_cols : build_cols) { + for (auto search_n_rows : search_rows) { + for (auto search_n_cols : search_cols) { + if (search_n_cols != build_n_cols) { continue; } + for (auto valid_pct : valid_pcts) { + for (auto k : k_sweep) { + cases.push_back(benchmark_case{ + build_n_rows, build_n_cols, search_n_rows, search_n_cols, valid_pct, k}); + } + } + } + } + } + } +} + +std::vector make_cases(bool quick) +{ + std::vector cases; + if (quick) { + append_cases(cases, {100'000}, {128}, {10'000}, {128}, {1, 50}, {64}); + } else { + append_cases(cases, + {k_build_rows.begin(), k_build_rows.end()}, + {k_build_cols.begin(), k_build_cols.end()}, + {k_search_rows.begin(), k_search_rows.end()}, + {k_search_cols.begin(), k_search_cols.end()}, + {k_valid_pcts.begin(), k_valid_pcts.end()}, + {k_values.begin(), k_values.end()}); + } + return cases; +} + +constexpr std::size_t k_max_bf_bytes = 20ULL << 30; // skip BF recall above this estimate +constexpr std::size_t k_max_bf_chunk_bytes = 2ULL << 30; // cap each BF chunk when computing recall + +std::size_t estimate_bf_distance_matrix_bytes(int64_t n_queries, int64_t n_dataset) +{ + return static_cast(n_queries) * static_cast(n_dataset) * sizeof(float); +} + +bool should_compute_bf_recall(int64_t n_queries, int64_t n_dataset) +{ + return estimate_bf_distance_matrix_bytes(n_queries, n_dataset) <= k_max_bf_bytes; +} + +int64_t choose_gt_chunk_queries(int64_t n_dataset) +{ + int64_t chunk = 256; + while (chunk > 1 && estimate_bf_distance_matrix_bytes(chunk, n_dataset) > k_max_bf_chunk_bytes) { + chunk /= 2; + } + return chunk; +} + +std::vector brute_force_ground_truth( + raft::device_resources const& res, + cuvs::neighbors::brute_force::index& bf_index, + cuvs::neighbors::brute_force::search_params const& bf_search_params, + raft::device_matrix_view queries, + cuvs::neighbors::filtering::bitset_filter const& bitset_filter, + int64_t k, + int64_t gt_chunk_queries) +{ + int64_t const n_queries = queries.extent(0); + std::vector gt_host(static_cast(n_queries * k)); + auto stream = raft::resource::get_cuda_stream(res); + + for (int64_t query_offset = 0; query_offset < n_queries; query_offset += gt_chunk_queries) { + int64_t const chunk_queries = std::min(gt_chunk_queries, n_queries - query_offset); + auto query_chunk = raft::make_device_matrix_view( + queries.data_handle() + query_offset * queries.extent(1), chunk_queries, queries.extent(1)); + auto gt_neighbors = raft::make_device_matrix(res, chunk_queries, k); + auto gt_distances = raft::make_device_matrix(res, chunk_queries, k); + + cuvs::neighbors::brute_force::search(res, + bf_search_params, + bf_index, + raft::make_const_mdspan(query_chunk), + gt_neighbors.view(), + gt_distances.view(), + bitset_filter); + raft::resource::sync_stream(res); + + std::vector chunk_host(static_cast(chunk_queries * k)); + raft::copy(chunk_host.data(), gt_neighbors.data_handle(), chunk_host.size(), stream); + raft::resource::sync_stream(res); + + for (int64_t q = 0; q < chunk_queries; ++q) { + for (int64_t ki = 0; ki < k; ++ki) { + gt_host[static_cast((query_offset + q) * k + ki)] = + static_cast(chunk_host[static_cast(q * k + ki)]); + } + } + } + + return gt_host; +} + +} // namespace + +int main(int argc, char** argv) +{ + std::string output_path = "cagra_filter_benchmark_results.csv"; + bool quick = false; + bool compute_ground_truth = false; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--quick") { + quick = true; + } else if (arg == "--ground-truth") { + compute_ground_truth = true; + } else if (arg == "--skip-ground-truth") { + compute_ground_truth = false; + } else if (arg == "--output" && i + 1 < argc) { + output_path = argv[++i]; + } else if (arg == "--help" || arg == "-h") { + std::cout << "Usage: " << argv[0] + << " [--quick] [--ground-truth] [--skip-ground-truth] [--output path.csv]\n" + << "\n" + << "Brute-force recall is skipped by default. Pass --ground-truth to compute it.\n"; + return 0; + } else { + output_path = arg; + } + } + + auto cases = make_cases(quick); + std::cout << "Running " << cases.size() << " benchmark configurations" + << (quick ? " (quick mode)" : "") + << (compute_ground_truth ? " (with ground-truth recall)" : " (ground-truth skipped)") + << std::endl; + + std::ofstream csv(output_path); + if (!csv) { + std::cerr << "Failed to open output file: " << output_path << std::endl; + return 1; + } + write_csv_header(csv); + + raft::device_resources res; + auto stream = raft::resource::get_cuda_stream(res); + + // Large enough for the biggest benchmark configuration (1M x 1024 dataset + index overhead). + rmm::mr::pool_memory_resource pool_mr(rmm::mr::get_current_device_resource_ref(), 16ULL << 30); + rmm::mr::set_current_device_resource(pool_mr); + + int64_t prev_build_rows = -1; + int64_t prev_build_cols = -1; + double last_build_time_ms = 0.0; + + std::optional> index; + std::optional> dataset; + std::optional> bf_index; + + cuvs::neighbors::cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; + index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params( + index_params.intermediate_graph_degree); + + cuvs::neighbors::cagra::search_params search_params; + search_params.algo = cuvs::neighbors::cagra::search_algo::MULTI_CTA; + search_params.itopk_size = 128; + search_params.thread_block_size = 256; + + cuvs::neighbors::brute_force::index_params bf_index_params; + cuvs::neighbors::brute_force::search_params bf_search_params; + + for (std::size_t case_idx = 0; case_idx < cases.size(); ++case_idx) { + auto const& cfg = cases[case_idx]; + + if (cfg.build_n_rows != prev_build_rows || cfg.build_n_cols != prev_build_cols) { + std::cout << "Building CAGRA index: n_rows=" << cfg.build_n_rows + << " n_cols=" << cfg.build_n_cols << std::endl; + + dataset.emplace( + raft::make_device_matrix(res, cfg.build_n_rows, cfg.build_n_cols)); + raft::random::RngState rng( + static_cast(cfg.build_n_rows * 17 + cfg.build_n_cols)); + raft::random::uniform(res, rng, dataset->data_handle(), dataset->size(), -1.0f, 1.0f); + + auto build_start = std::chrono::steady_clock::now(); + index.emplace( + cuvs::neighbors::cagra::build(res, index_params, raft::make_const_mdspan(dataset->view()))); + bf_index.reset(); + raft::resource::sync_stream(res); + auto build_end = std::chrono::steady_clock::now(); + last_build_time_ms = + std::chrono::duration(build_end - build_start).count(); + + prev_build_rows = cfg.build_n_rows; + prev_build_cols = cfg.build_n_cols; + } + + int64_t const total_queries = cfg.search_n_rows; + search_params.max_queries = total_queries; + search_params.itopk_size = static_cast(cfg.k); + + std::cout << "Case " << (case_idx + 1) << '/' << cases.size() + << ": build_n_rows=" << cfg.build_n_rows << " search_n_rows=" << total_queries + << " k=" << cfg.k << " valid_pct=" << cfg.valid_pct << '%' << std::endl; + + try { + auto queries = + raft::make_device_matrix(res, total_queries, cfg.search_n_cols); + raft::random::RngState query_rng(static_cast( + cfg.build_n_rows * 31 + cfg.search_n_rows * 17 + cfg.search_n_cols + cfg.valid_pct)); + raft::random::uniform(res, query_rng, queries.data_handle(), queries.size(), -1.0f, 1.0f); + + auto neighbors = raft::make_device_matrix(res, total_queries, cfg.k); + auto distances = raft::make_device_matrix(res, total_queries, cfg.k); + + auto filters = make_filters(res, cfg.build_n_rows, cfg.valid_pct, stream); + + bool const run_bf_recall = + compute_ground_truth && should_compute_bf_recall(total_queries, cfg.build_n_rows); + std::optional> gt_host; + if (run_bf_recall) { + if (!bf_index.has_value()) { + bf_index.emplace(cuvs::neighbors::brute_force::build( + res, bf_index_params, raft::make_const_mdspan(dataset->view()))); + } + gt_host.emplace(brute_force_ground_truth(res, + *bf_index, + bf_search_params, + queries.view(), + filters.bitset_filter, + cfg.k, + choose_gt_chunk_queries(cfg.build_n_rows))); + } else if (compute_ground_truth) { + auto const est_gib = + static_cast(estimate_bf_distance_matrix_bytes(total_queries, cfg.build_n_rows)) / + static_cast(1ULL << 30); + std::cout << " skipping brute-force recall (estimated " << est_gib + << " GiB distance matrix > 20 GiB limit)" << std::endl; + } + + auto run_cagra_search = [&](cuvs::neighbors::filtering::base_filter const& filter) { + cuvs::neighbors::cagra::search(res, + search_params, + *index, + raft::make_const_mdspan(queries.view()), + neighbors.view(), + distances.view(), + filter); + raft::resource::sync_stream(res); + }; + + struct filter_run { + std::string name; + cuvs::neighbors::filtering::base_filter const* filter; + }; + std::vector filter_runs{ + {"bitset", &filters.bitset_filter}, + {"bloom_filter", &filters.bloom_filter}, + }; + + for (auto const& fr : filter_runs) { + for (int w = 0; w < k_warmup_runs; ++w) { + run_cagra_search(*fr.filter); + } + + double const avg_search_ms = + time_cuda_ms(res, k_timed_runs, [&] { run_cagra_search(*fr.filter); }); + double const avg_per_query_ms = avg_search_ms / static_cast(total_queries); + + double const recall = [&]() { + if (!gt_host.has_value()) { return std::numeric_limits::quiet_NaN(); } + auto result_host = copy_neighbors_to_host(res, neighbors.view()); + return compute_recall(*gt_host, result_host, total_queries, cfg.k, cfg.build_n_rows); + }(); + + write_csv_row( + csv, csv_row{cfg, fr.name, last_build_time_ms, avg_search_ms, avg_per_query_ms, recall}); + csv.flush(); + + std::cout << " " << fr.name << ": search_ms=" << avg_search_ms + << " per_query_ms=" << avg_per_query_ms << " recall="; + if (gt_host.has_value()) { + std::cout << recall; + } else { + std::cout << "n/a"; + } + std::cout << std::endl; + } + } catch (std::exception const& ex) { + std::cerr << " case failed: " << ex.what() << std::endl; + for (auto const* filter_name : {"bitset", "bloom_filter"}) { + write_csv_row(csv, + csv_row{cfg, + filter_name, + last_build_time_ms, + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + csv.flush(); + } + } + } + + std::cout << "Wrote results to " << output_path << std::endl; + return 0; +}