diff --git a/rust/cuvs/src/brute_force.rs b/rust/cuvs/src/brute_force.rs
index 413e8b0fb1..adead3efba 100644
--- a/rust/cuvs/src/brute_force.rs
+++ b/rust/cuvs/src/brute_force.rs
@@ -4,11 +4,13 @@
  */
 //! Brute Force KNN
 
+use std::ffi::CString;
 use std::io::{Write, stderr};
+use std::path::Path;
 
 use crate::distance_type::DistanceType;
 use crate::dlpack::ManagedTensor;
-use crate::error::{Result, check_cuvs};
+use crate::error::{Error, Result, check_cuvs};
 use crate::resources::Resources;
 
 /// Brute Force KNN Index
@@ -20,6 +22,17 @@ pub struct Index {
     _dataset: Option<ManagedTensor>,
 }
 
+/// Convert a filesystem path into a `CString` suitable for the cuVS C API,
+/// returning `Error::InvalidArgument` instead of panicking for paths that are
+/// not valid UTF-8 or that contain an interior NUL byte.
+fn path_to_cstring(path: &Path) -> Result<CString> {
+    let path_str = path
+        .to_str()
+        .ok_or_else(|| Error::InvalidArgument(format!("path is not valid UTF-8: {path:?}")))?;
+    CString::new(path_str)
+        .map_err(|e| Error::InvalidArgument(format!("path contains an interior NUL byte: {e}")))
+}
+
 impl Index {
     /// Builds a new Brute Force KNN Index from the dataset for efficient search.
     ///
@@ -87,6 +100,56 @@ impl Index {
             ))
         }
     }
+
+    /// Save the Brute Force index to file.
+    ///
+    /// The serialization format can be subject to change, therefore loading an
+    /// index saved with a previous version of cuVS is not guaranteed to work.
+    ///
+    /// # Arguments
+    ///
+    /// * `res` - Resources to use
+    /// * `filename` - The file path for saving the index
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::InvalidArgument` if `filename` is not valid UTF-8 or
+    /// contains an interior NUL byte, and propagates any failure reported by
+    /// `check_cuvs` for the underlying `cuvsBruteForceSerialize` call.
+    pub fn serialize<P: AsRef<Path>>(&self, res: &Resources, filename: P) -> Result<()> {
+        let c_filename = path_to_cstring(filename.as_ref())?;
+        // SAFETY: `c_filename` is an owned, NUL-terminated `CString` that outlives
+        // this call, so the pointer handed to the C API stays valid throughout.
+        unsafe { check_cuvs(ffi::cuvsBruteForceSerialize(res.0, c_filename.as_ptr(), self.inner)) }
+    }
+
+    /// Load a Brute Force index from file.
+    ///
+    /// The serialization format can be subject to change, therefore loading an
+    /// index saved with a previous version of cuVS is not guaranteed to work.
+    ///
+    /// # Arguments
+    ///
+    /// * `res` - Resources to use
+    /// * `filename` - The path of the file that stores the index
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::InvalidArgument` if `filename` is not valid UTF-8 or
+    /// contains an interior NUL byte, and propagates any failure from
+    /// `Index::new` or from the underlying `cuvsBruteForceDeserialize` call.
+    pub fn deserialize<P: AsRef<Path>>(res: &Resources, filename: P) -> Result<Index> {
+        let c_filename = path_to_cstring(filename.as_ref())?;
+        // Create the Index handle first so that any error path below still runs
+        // its `Drop` and releases the C-side index allocation.
+        let index = Index::new()?;
+        // SAFETY: `c_filename` is an owned, NUL-terminated `CString` that outlives
+        // this call, so the pointer handed to the C API stays valid throughout.
+        unsafe {
+            check_cuvs(ffi::cuvsBruteForceDeserialize(res.0, c_filename.as_ptr(), index.inner))?;
+        }
+        Ok(index)
+    }
 }
 
 impl Drop for Index {
@@ -168,4 +231,103 @@ mod tests {
     fn test_l2() {
         test_bfknn(DistanceType::L2Expanded);
     }
+
+    const N_DATAPOINTS: usize = 16;
+    const N_FEATURES: usize = 8;
+
+    /// Search the first `n_queries` rows of `dataset` against `index` and assert
+    /// each query finds itself as the top-1 neighbor.
+    fn search_and_verify_self_neighbors(
+        res: &Resources,
+        index: &Index,
+        dataset: &ndarray::Array2<f32>,
+        n_queries: usize,
+        k: usize,
+    ) {
+        let queries = dataset.slice(s![0..n_queries, ..]);
+        let queries = ManagedTensor::from(&queries).to_device(res).unwrap();
+
+        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
+        let neighbors = ManagedTensor::from(&neighbors_host).to_device(res).unwrap();
+
+        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
+        let distances = ManagedTensor::from(&distances_host).to_device(res).unwrap();
+
+        index.search(res, &queries, &neighbors, &distances).expect("search failed");
+
+        distances.to_host(res, &mut distances_host).unwrap();
+        neighbors.to_host(res, &mut neighbors_host).unwrap();
+        res.sync_stream().unwrap();
+
+        for i in 0..n_queries {
+            assert_eq!(
+                neighbors_host[[i, 0]],
+                i as i64,
+                "query {i} should be its own nearest neighbor"
+            );
+        }
+    }
+
+    #[test]
+    fn test_brute_force_serialize_deserialize() {
+        let res = Resources::new().unwrap();
+
+        // Keep `dataset` (the host array) in this scope for the whole test: the
+        // device dataset view stored inside the index borrows its shape, so the
+        // host array must not be moved while the index is alive.
+        let dataset =
+            ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
+        let device_dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();
+        let index = Index::build(&res, DistanceType::L2Expanded, None, device_dataset)
+            .expect("failed to build brute force index");
+        res.sync_stream().unwrap();
+
+        let unique = format!(
+            "test_brute_force_index_{}_{}.bin",
+            std::process::id(),
+            std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .expect("system clock before unix epoch")
+                .as_nanos()
+        );
+        let filepath = std::env::temp_dir().join(unique);
+        index.serialize(&res, &filepath).expect("failed to serialize brute force index");
+
+        assert!(filepath.exists(), "serialized index file should exist");
+        assert!(
+            std::fs::metadata(&filepath).unwrap().len() > 0,
+            "serialized index file should not be empty"
+        );
+
+        let loaded_index =
+            Index::deserialize(&res, &filepath).expect("failed to deserialize brute force index");
+
+        // The deserialized index should still find each query as its own
+        // nearest neighbor.
+        search_and_verify_self_neighbors(&res, &loaded_index, &dataset, 4, 4);
+
+        let _ = std::fs::remove_file(&filepath);
+    }
+
+    /// Passing a filename containing an interior NUL byte must surface as an
+    /// `InvalidArgument` error rather than panicking inside the serializer.
+    #[test]
+    fn test_brute_force_serialize_rejects_interior_nul() {
+        let res = Resources::new().unwrap();
+
+        let dataset =
+            ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
+        let device_dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();
+        let index = Index::build(&res, DistanceType::L2Expanded, None, device_dataset)
+            .expect("failed to build brute force index");
+        res.sync_stream().unwrap();
+
+        // `PathBuf::from` on Unix preserves arbitrary bytes, so we can embed a
+        // NUL byte in the path and confirm the helper rejects it.
+        let bad_path = std::path::PathBuf::from("/tmp/has\0nul.bin");
+        let err = index
+            .serialize(&res, &bad_path)
+            .expect_err("serialize should reject paths with interior NUL");
+        assert!(matches!(err, Error::InvalidArgument(_)), "expected InvalidArgument, got {err:?}");
+    }
 }