Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 147 additions & 1 deletion rust/cuvs/src/brute_force.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
*/
//! Brute Force KNN

use std::ffi::CString;
use std::io::{Write, stderr};
use std::path::Path;

use crate::distance_type::DistanceType;
use crate::dlpack::ManagedTensor;
use crate::error::{Result, check_cuvs};
use crate::error::{Error, Result, check_cuvs};
use crate::resources::Resources;

/// Brute Force KNN Index
Expand All @@ -20,6 +22,17 @@ pub struct Index {
_dataset: Option<ManagedTensor>,
}

/// Convert a filesystem path into a `CString` suitable for the cuVS C API,
/// returning `Error::InvalidArgument` instead of panicking for paths that are
/// not valid UTF-8 or that contain an interior NUL byte.
fn path_to_cstring(path: &Path) -> Result<CString> {
let path_str = path
.to_str()
.ok_or_else(|| Error::InvalidArgument(format!("path is not valid UTF-8: {path:?}")))?;
CString::new(path_str)
.map_err(|e| Error::InvalidArgument(format!("path contains an interior NUL byte: {e}")))
}

impl Index {
/// Builds a new Brute Force KNN Index from the dataset for efficient search.
///
Expand Down Expand Up @@ -87,6 +100,40 @@ impl Index {
))
}
}

/// Writes this Brute Force index to a file.
///
/// The serialization format may change between releases, so a file written
/// by a previous version of cuVS is not guaranteed to load.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `filename` - Destination path for the serialized index
///
/// # Errors
///
/// Returns the error produced by `path_to_cstring` when `filename` cannot be
/// converted into a C string, or the error produced by `check_cuvs` for the
/// status returned by `cuvsBruteForceSerialize`.
pub fn serialize<P: AsRef<Path>>(&self, res: &Resources, filename: P) -> Result<()> {
    let target = filename.as_ref();
    let c_path = path_to_cstring(target)?;
    // SAFETY: `c_path` is a NUL-terminated string that stays alive until this
    // function returns. `res.0` and `self.inner` are assumed to be valid cuVS
    // handles owned by their wrappers (NOTE(review): confirm against
    // `Resources` and `Index::new`).
    unsafe {
        let status = ffi::cuvsBruteForceSerialize(res.0, c_path.as_ptr(), self.inner);
        check_cuvs(status)
    }
}

/// Reads a Brute Force index back from a file.
///
/// The serialization format may change between releases, so a file written
/// by a previous version of cuVS is not guaranteed to load.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `filename` - Path of the file that holds the serialized index
///
/// # Errors
///
/// Returns the error produced by `path_to_cstring` when `filename` cannot be
/// converted into a C string, the error from `Index::new` when the handle
/// cannot be created, or the error produced by `check_cuvs` for the status
/// returned by `cuvsBruteForceDeserialize`.
pub fn deserialize<P: AsRef<Path>>(res: &Resources, filename: P) -> Result<Index> {
    let source = filename.as_ref();
    let c_path = path_to_cstring(source)?;

    // The handle is created before the C call on purpose: if loading fails,
    // `loaded` is dropped on the way out and its `Drop` releases the C-side
    // index allocation.
    let loaded = Index::new()?;

    // SAFETY: `c_path` is a NUL-terminated string that outlives the call.
    // `res.0` and `loaded.inner` are assumed to be valid cuVS handles
    // (NOTE(review): confirm against `Resources` and `Index::new`).
    let status = unsafe {
        check_cuvs(ffi::cuvsBruteForceDeserialize(res.0, c_path.as_ptr(), loaded.inner))
    };

    // On failure the closure is never called and `loaded` is dropped with it.
    status.map(|()| loaded)
}
}

impl Drop for Index {
Expand Down Expand Up @@ -168,4 +215,103 @@ mod tests {
fn test_l2() {
test_bfknn(DistanceType::L2Expanded);
}

/// Number of rows in the random dataset built by the serialization tests.
const N_DATAPOINTS: usize = 16;
/// Number of columns (features per row) in that random dataset.
const N_FEATURES: usize = 8;

/// Runs a search for the first `n_queries` rows of `dataset` against `index`,
/// asking for `k` neighbors per query, and asserts that every query row comes
/// back as its own top-1 neighbor.
///
/// Panics (failing the calling test) if any tensor transfer, the search
/// itself, or the self-neighbor check fails.
fn search_and_verify_self_neighbors(
    res: &Resources,
    index: &Index,
    dataset: &ndarray::Array2<f32>,
    n_queries: usize,
    k: usize,
) {
    // Query rows are a view into the host dataset, uploaded via `to_device`.
    let query_rows = dataset.slice(s![0..n_queries, ..]);
    let device_queries = ManagedTensor::from(&query_rows).to_device(res).unwrap();

    // Both result buffers have one row per query and `k` columns.
    let out_shape = (n_queries, k);

    let mut host_neighbors = ndarray::Array::<i64, _>::zeros(out_shape);
    let device_neighbors = ManagedTensor::from(&host_neighbors).to_device(res).unwrap();

    let mut host_distances = ndarray::Array::<f32, _>::zeros(out_shape);
    let device_distances = ManagedTensor::from(&host_distances).to_device(res).unwrap();

    index
        .search(res, &device_queries, &device_neighbors, &device_distances)
        .expect("search failed");

    // Copy the results back, then sync the stream before reading the host
    // buffers.
    device_distances.to_host(res, &mut host_distances).unwrap();
    device_neighbors.to_host(res, &mut host_neighbors).unwrap();
    res.sync_stream().unwrap();

    (0..n_queries).for_each(|i| {
        assert_eq!(
            host_neighbors[[i, 0]],
            i as i64,
            "query {i} should be its own nearest neighbor"
        );
    });
}

/// Round-trips a Brute Force index through `serialize` / `deserialize` and
/// checks that the reloaded index still answers self-neighbor queries.
#[test]
fn test_brute_force_serialize_deserialize() {
    /// Removes the wrapped file when dropped, so the temporary index file is
    /// cleaned up even when one of the assertions below panics.
    struct TempFileGuard(std::path::PathBuf);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best effort: the file may never have been created, and a
            // cleanup failure must not mask the real test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let res = Resources::new().unwrap();

    // Keep `dataset` (the host array) in this scope for the whole test: the
    // device dataset view stored inside the index borrows its shape, so the
    // host array must not be moved while the index is alive.
    let dataset =
        ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
    let device_dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();
    let index = Index::build(&res, DistanceType::L2Expanded, None, device_dataset)
        .expect("failed to build brute force index");
    res.sync_stream().unwrap();

    // Process id plus a nanosecond timestamp keeps concurrent test runs from
    // colliding on the same file name.
    let unique = format!(
        "test_brute_force_index_{}_{}.bin",
        std::process::id(),
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock before unix epoch")
            .as_nanos()
    );
    // From here on the guard owns the path and deletes the file on every exit
    // path, including unwinding from a failed assertion.
    let temp_file = TempFileGuard(std::env::temp_dir().join(unique));
    let filepath = &temp_file.0;

    index.serialize(&res, filepath).expect("failed to serialize brute force index");

    assert!(filepath.exists(), "serialized index file should exist");
    assert!(
        std::fs::metadata(filepath).unwrap().len() > 0,
        "serialized index file should not be empty"
    );

    let loaded_index =
        Index::deserialize(&res, filepath).expect("failed to deserialize brute force index");

    // The deserialized index should still find each query as its own
    // nearest neighbor.
    search_and_verify_self_neighbors(&res, &loaded_index, &dataset, 4, 4);
}

/// A filename with an interior NUL byte must be reported as an
/// `InvalidArgument` error by `serialize` instead of causing a panic.
#[test]
fn test_brute_force_serialize_rejects_interior_nul() {
    let res = Resources::new().unwrap();

    // A real index is needed only so that there is something to call
    // `serialize` on.
    let host_data =
        ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
    let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();
    let index = Index::build(&res, DistanceType::L2Expanded, None, device_data)
        .expect("failed to build brute force index");
    res.sync_stream().unwrap();

    // `PathBuf::from` on Unix preserves arbitrary bytes, so the embedded NUL
    // survives into the path handed to `serialize`.
    let nul_path = std::path::PathBuf::from("/tmp/has\0nul.bin");
    let outcome = index.serialize(&res, &nul_path);

    let err = outcome.expect_err("serialize should reject paths with interior NUL");
    assert!(matches!(err, Error::InvalidArgument(_)), "expected InvalidArgument, got {err:?}");
}
}
Loading