78 changes: 78 additions & 0 deletions UPSTREAM_PR_BODY.md
@@ -0,0 +1,78 @@
## Add Rust bindings for the `refine` API

### What

This PR adds safe Rust bindings for the `cuvsRefine` C API in the `cuvs` crate.

Refinement is a post-processing step that follows an approximate
nearest-neighbors (ANN) search; it is exposed here as a free function, not an
index type. Given a per-query candidate list produced by an ANN method, it
recomputes exact distances against the original dataset and selects the
top-`k` among those candidates. This lets callers pair a cheap approximate
first pass with an exact re-rank over a small candidate set.
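
The re-rank described above can be sketched on the CPU in plain Rust. This is
an illustrative reference only, not the cuVS implementation: `squared_l2` and
`refine_one` are hypothetical helpers, the dataset is assumed to be a plain
`Vec<Vec<f32>>`, and squared L2 is assumed as the metric.

```rust
/// Squared Euclidean distance between two equal-length vectors.
/// (Hypothetical helper for illustration; not part of the `cuvs` crate.)
fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// CPU reference for a single query: score every candidate by its exact
/// distance to `query`, then keep the `k` nearest as `(index, distance)`.
/// (Hypothetical helper for illustration; not part of the `cuvs` crate.)
fn refine_one(
    dataset: &[Vec<f32>],
    query: &[f32],
    candidates: &[i64],
    k: usize,
) -> Vec<(i64, f32)> {
    let mut scored: Vec<(i64, f32)> = candidates
        .iter()
        .map(|&c| (c, squared_l2(&dataset[c as usize], query)))
        .collect();
    // Stable sort by distance, so tied candidates keep their input order.
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    scored
}
```

With the six-point dataset from the unit test, the query `(4.9, 4.9)`,
candidates `[0, 5, 3, 4]`, and `k = 3`, this sketch returns the indices
`[4, 3, 5]`: the far-away candidate `0` is dropped and the rest come back
nearest first.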

The new `cuvs::refine::refine` free function mirrors the shape of the existing
`cuvs::distance::pairwise_distance` wrapper — it takes `Resources`, input/output
`ManagedTensor`s, and a `DistanceType`, and returns `Result<()>`. No new index
struct is introduced.

```rust
pub fn refine(
    res: &Resources,
    dataset: &ManagedTensor,
    queries: &ManagedTensor,
    candidates: &ManagedTensor,
    metric: DistanceType,
    indices: &ManagedTensor,
    distances: &ManagedTensor,
) -> Result<()>
```

### Files changed

- `rust/cuvs/src/refine.rs` (new): `refine()` wrapper, doc comment with a
compile-checked (`no_run`) example, and a behavioral unit test.
- `rust/cuvs/src/lib.rs`: `pub mod refine;`.

### Reviewer notes

- **Bindings already existed.** `cuvsRefine` is already present in the generated
`rust/cuvs-sys/src/bindings.rs` (it lives in `core/all.h`, adjacent to the
`ivf_flat` block), so no `cuvs-sys` regeneration was required. This PR is
Rust-side only.
- **Contract from `c/src/neighbors/refine.cpp`:** all tensors must live in the
same memory space (all device or all host — the C layer rejects mixing).
`candidates` and output `indices` must be `int64`; output `distances` must be
`float32`; `queries`/`dataset` dtype codes must match. `k` is taken from the
output tensor shape (`[n_queries, k]`), and `n_candidates >= k`. The wrapper
forwards tensors as-is and surfaces these constraints in the doc comment;
validation is left to the C layer (consistent with the other wrappers).
- The free-function placement (`refine.rs` at the crate root, alongside
`distance/`) matches `pairwise_distance`. Open to relocating under a
`neighbors`-style module if the crate later groups neighbor ops.
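
For readers who want the contract in executable form, the constraints listed
above can be written as a small pre-check. This is a hypothetical sketch:
`Dtype`, `TensorMeta`, and `check_refine_contract` are illustrative names that
do not exist in the crate, and the wrapper in this PR does not perform these
checks itself (it leaves validation to the C layer).

```rust
/// Element types named in the contract. (Hypothetical type for illustration.)
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Dtype {
    F32,
    F16,
    I8,
    U8,
    I64,
}

/// Minimal stand-in for the tensor metadata the contract talks about.
/// (Hypothetical type for illustration; not `ManagedTensor`.)
#[derive(Clone, Copy, Debug)]
struct TensorMeta {
    dtype: Dtype,
    shape: [usize; 2],
    on_device: bool,
}

/// Checks the constraints listed in the reviewer notes and returns `k`,
/// which is read from the output `indices` shape `[n_queries, k]`.
fn check_refine_contract(
    dataset: &TensorMeta,
    queries: &TensorMeta,
    candidates: &TensorMeta,
    indices: &TensorMeta,
    distances: &TensorMeta,
) -> Result<usize, String> {
    let all = [dataset, queries, candidates, indices, distances];
    // All tensors must be on the device, or all on the host.
    if all.iter().any(|t| t.on_device != dataset.on_device) {
        return Err("all tensors must live in the same memory space".into());
    }
    if candidates.dtype != Dtype::I64 || indices.dtype != Dtype::I64 {
        return Err("candidates and indices must be int64".into());
    }
    if distances.dtype != Dtype::F32 {
        return Err("distances must be float32".into());
    }
    if queries.dtype != dataset.dtype {
        return Err("queries and dataset dtypes must match".into());
    }
    let k = indices.shape[1];
    if candidates.shape[1] < k {
        return Err("n_candidates must be >= k".into());
    }
    Ok(k)
}
```

Returning `k` from the check mirrors how the C API infers it from the output
tensor shape instead of taking it as a separate argument.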

### Testing summary

- `cargo build -p cuvs` — clean.
- `cargo test -p cuvs refine -- --test-threads=1` — the unit test
`test_refine_fixes_wrong_candidates` passes. It builds a small, well-separated
2-D dataset, hands `refine` deliberately **wrong / mis-ordered** candidate
lists (each containing a planted far-away noise index), and asserts that the
refined top-`k` exactly equals the brute-force exact top-`k`: the planted noise
candidates are evicted, the true nearest neighbor is restored to rank 0, the
refined index sets match the exact sets, and distances come back sorted
ascending. This verifies real re-ranking behavior, not merely that the call
succeeds.
- `cargo test -p cuvs --doc refine` — the doc example compiles.
- `cargo fmt -p cuvs -- --check` — clean.
- `cargo clippy -p cuvs` — no findings on the new code. (There is a pre-existing
`not_unsafe_ptr_arg_deref` lint on `resources.rs::set_cuda_stream` from a newer
clippy; it is untouched by this PR.)
- Built and tested against conda `libcuvs` 26.06 with the DLPack CMake package on
`CMAKE_PREFIX_PATH`, on a single CUDA device.

### Sibling-PR conflict note

This work was developed alongside a separate IVF-SQ bindings PR. Both touch
`rust/cuvs/src/lib.rs` (each adds one `pub mod` line). The additions are
independent and order-agnostic; whichever lands second will need a trivial
one-line merge in `lib.rs`. No other files overlap.
1 change: 1 addition & 0 deletions rust/cuvs/src/lib.rs
@@ -18,6 +18,7 @@ mod dlpack;
mod error;
pub mod ivf_flat;
pub mod ivf_pq;
pub mod refine;
mod resources;
pub mod vamana;

228 changes: 228 additions & 0 deletions rust/cuvs/src/refine.rs
@@ -0,0 +1,228 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
//! Refinement of approximate nearest neighbor results

use crate::distance_type::DistanceType;
use crate::dlpack::ManagedTensor;
use crate::error::{Result, check_cuvs};
use crate::resources::Resources;

/// Refine nearest neighbor search results.
///
/// Refinement is an operation that follows an approximate nearest neighbors
/// search. The approximate search has already selected `n_candidates` neighbor
/// candidates for each query. Refinement narrows that list down to the `k`
/// nearest neighbors by computing the exact distance from each query to each
/// of its candidates in the original dataset, then selecting the `k` closest.
///
/// All tensors must reside in the same memory space: either all on the device
/// or all on the host. The dataset and queries may be `f32`, `f16`, `i8`, or
/// `u8` (with matching dtype codes). The candidate and output index tensors
/// must be `i64`, and the output distance tensor must be `f32`.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `dataset` - A row-major matrix of the original dataset, shape `(n_rows, dims)`
/// * `queries` - A row-major matrix of the queries, shape `(n_queries, dims)`
/// * `candidates` - A row-major `i64` matrix of candidate indices into `dataset`,
///   shape `(n_queries, n_candidates)`, where `n_candidates >= k`
/// * `metric` - DistanceType used to rank candidates
/// * `indices` - Output `i64` matrix that receives the refined indices, shape
///   `(n_queries, k)`. `k` is inferred from this tensor's shape.
/// * `distances` - Output `f32` matrix that receives the refined distances,
///   shape `(n_queries, k)`
///
/// # Example
///
/// ```no_run
/// use cuvs::{ManagedTensor, Resources, Result};
/// use cuvs::distance_type::DistanceType;
/// use cuvs::refine::refine;
/// use ndarray::array;
///
/// fn do_refine() -> Result<()> {
/// let res = Resources::new()?;
///
/// // A tiny dataset with four 2-D points.
/// let dataset = array![[0.0f32, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]];
/// let queries = array![[0.1f32, 0.1]];
///
/// // Approximate candidates - includes the far-away point 3 by mistake.
/// let candidates = array![[3i64, 1, 0]];
///
/// let dataset_d = ManagedTensor::from(&dataset).to_device(&res)?;
/// let queries_d = ManagedTensor::from(&queries).to_device(&res)?;
/// let candidates_d = ManagedTensor::from(&candidates).to_device(&res)?;
///
/// let mut indices_host = ndarray::Array::<i64, _>::zeros((1, 2));
/// let mut distances_host = ndarray::Array::<f32, _>::zeros((1, 2));
/// let indices_d = ManagedTensor::from(&indices_host).to_device(&res)?;
/// let distances_d = ManagedTensor::from(&distances_host).to_device(&res)?;
///
/// refine(
/// &res,
/// &dataset_d,
/// &queries_d,
/// &candidates_d,
/// DistanceType::L2Expanded,
/// &indices_d,
/// &distances_d,
/// )?;
///
/// indices_d.to_host(&res, &mut indices_host)?;
/// distances_d.to_host(&res, &mut distances_host)?;
/// res.sync_stream()?;
///
/// // Point 0 is the true nearest neighbor; the wrong candidate 3 is dropped.
/// assert_eq!(indices_host[[0, 0]], 0);
/// Ok(())
/// }
/// ```
pub fn refine(
    res: &Resources,
    dataset: &ManagedTensor,
    queries: &ManagedTensor,
    candidates: &ManagedTensor,
    metric: DistanceType,
    indices: &ManagedTensor,
    distances: &ManagedTensor,
) -> Result<()> {
    unsafe {
        check_cuvs(ffi::cuvsRefine(
            res.0,
            dataset.as_ptr(),
            queries.as_ptr(),
            candidates.as_ptr(),
            metric,
            indices.as_ptr(),
            distances.as_ptr(),
        ))
    }
}

#[cfg(test)]
mod tests {
use super::*;

/// Refinement must repair a candidate list that contains deliberately
/// wrong entries: after refine, the top-k must equal the exact
/// brute-force top-k.
#[test]
fn test_refine_fixes_wrong_candidates() {
let res = Resources::new().unwrap();

// A small, well-separated 2-D dataset. The exact L2 ranking of every
// query is unambiguous, so we can hard-assert the refined output.
//
// index : point
// 0 : (0, 0)
// 1 : (1, 0)
// 2 : (0, 1)
// 3 : (2, 2)
// 4 : (5, 5)
// 5 : (9, 9)
let dataset = ndarray::array![
[0.0f32, 0.0],
[1.0, 0.0],
[0.0, 1.0],
[2.0, 2.0],
[5.0, 5.0],
[9.0, 9.0],
];

// Two queries near distinct clusters.
// q0 sits next to point 0; true top-3 = [0, 1, 2]
// q1 sits next to point 4; true top-3 = [4, 3, 5] (4 closest, then 3, then 5)
let queries = ndarray::array![[0.1f32, 0.1], [4.9, 4.9]];

// Candidate lists are intentionally mis-ordered and include far-away
// points. Each list contains the true top-3 in jumbled order plus one
// planted bad candidate (index 5 for q0, index 0 for q1). Refine must
// re-rank these exactly and select the correct nearest three.
let candidates = ndarray::array![
[5i64, 2, 0, 1], // q0: true nearest 0 is buried, 5 is far noise
[0i64, 5, 3, 4], // q1: true nearest 4 is last, 0 is far noise
];

let n_queries = 2;
let k = 3;

let dataset_d = ManagedTensor::from(&dataset).to_device(&res).unwrap();
let queries_d = ManagedTensor::from(&queries).to_device(&res).unwrap();
let candidates_d = ManagedTensor::from(&candidates).to_device(&res).unwrap();

let mut indices_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let indices_d = ManagedTensor::from(&indices_host).to_device(&res).unwrap();
let distances_d = ManagedTensor::from(&distances_host).to_device(&res).unwrap();

refine(
&res,
&dataset_d,
&queries_d,
&candidates_d,
DistanceType::L2Expanded,
&indices_d,
&distances_d,
)
.unwrap();

indices_d.to_host(&res, &mut indices_host).unwrap();
distances_d.to_host(&res, &mut distances_host).unwrap();
res.sync_stream().unwrap();

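// Exact brute-force top-3, independent of the candidate ordering.
// The values below are Euclidean distances; `L2Expanded` is expected to
// report their squares, which leaves the ranking unchanged.
// q0: distances to (0.1,0.1): 0 -> ~0.14, 1 -> ~0.91, 2 -> ~0.91, ...
// point 0 is strictly nearest; 1 and 2 are tied next.
// q1: distances to (4.9,4.9): 4 -> ~0.14, 3 -> ~4.1, 5 -> ~5.8.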
assert_eq!(
indices_host[[0, 0]],
0,
"q0 nearest must be repaired to index 0, got {:?}",
indices_host.row(0)
);
assert_eq!(
indices_host[[1, 0]],
4,
"q1 nearest must be repaired to index 4, got {:?}",
indices_host.row(1)
);

// The planted noise candidates (5 for q0, 0 for q1) must be evicted
// from the refined top-k.
let q0: Vec<i64> = indices_host.row(0).to_vec();
let q1: Vec<i64> = indices_host.row(1).to_vec();
assert!(!q0.contains(&5), "q0 must drop far candidate 5, got {:?}", q0);
assert!(!q1.contains(&0), "q1 must drop far candidate 0, got {:?}", q1);

// The refined top-3 sets must match the exact brute-force top-3 sets.
let mut q0_sorted = q0.clone();
q0_sorted.sort_unstable();
assert_eq!(q0_sorted, vec![0, 1, 2], "q0 refined set wrong: {:?}", q0);

let mut q1_sorted = q1.clone();
q1_sorted.sort_unstable();
assert_eq!(q1_sorted, vec![3, 4, 5], "q1 refined set wrong: {:?}", q1);

// Refined distances must be sorted ascending (nearest first) across
// the full top-k, and the first entry must be the small in-cluster
// distance, not noise.
for q in 0..n_queries {
for i in 0..k - 1 {
assert!(
distances_host[[q, i]] <= distances_host[[q, i + 1]],
"q{q} distances not ascending at {i}: {:?}",
(distances_host[[q, i]], distances_host[[q, i + 1]])
);
}
}
assert!(
distances_host[[0, 0]] < 1.0,
"q0 nearest distance should be small, got {}",
distances_host[[0, 0]]
);
}
}