Skip to content

Commit c5edaa0

Browse files
authored
Merge pull request #36 from miamia0/optimize
[fix] fix panic when not all vectors can be searched because searcher.seen is not cleared during lower_search
2 parents 1caa704 + 4ce5e6d commit c5edaa0

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

src/hnsw/hnsw_const.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ where
6464
params,
6565
}
6666
}
67+
68+
/// Creates an empty instance whose `zero` and `features` vectors are
/// pre-allocated to hold `capacity` entries.
///
/// `metric` and `params` are stored as given, `layers` starts empty, and the
/// PRNG is seeded with `R::Seed::default()`.
pub fn new_with_capacity(metric: Met, params: Params, capacity: usize) -> Self {
69+
Self {
70+
metric,
71+
// Reserve space up front for `capacity` entries.
zero: Vec::with_capacity(capacity),
72+
features: Vec::with_capacity(capacity),
73+
layers: vec![],
74+
prng: R::from_seed(R::Seed::default()),
75+
params,
76+
}
77+
}
6778
}
6879

6980
impl<Met, T, R, const M: usize, const M0: usize> Knn for Hnsw<Met, T, R, M, M0>
@@ -374,12 +385,14 @@ where
374385
// See Algorithm 5 line 5 of the paper. The paper makes no further comment on why `1` was chosen.
375386
let &Neighbor { index, distance } = searcher.nearest.first().unwrap();
376387
searcher.nearest.clear();
388+
searcher.seen.clear();
377389
// Update the node to the next layer.
378390
let new_index = layer[index].next_node as usize;
379391
let candidate = Neighbor {
380392
index: new_index,
381393
distance,
382394
};
395+
searcher.seen.insert(layer[index].zero_node);
383396
// Insert the index of the nearest neighbor into the nearest pool for the next layer.
384397
searcher.nearest.push(candidate);
385398
// Insert the index into the candidate pool as well.

tests/simple.rs

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Useful tests for debugging since they are hand-written and easy to see the debugging output.
22
33
use hnsw::{Hnsw, Searcher};
4+
use itertools::Itertools;
45
use rand_pcg::Pcg64;
56
use space::{Metric, Neighbor};
67

@@ -18,13 +19,43 @@ impl Metric<&[f64]> for Euclidean {
1819
}
1920
}
2021

22+
/// Test-only reference index: keeps every pushed vector together with a
/// `usize` tag so that `search` can scan them all linearly.
struct TestBruteForceHelper {
23+
// Each entry is `(tag, vector)`; `search` returns the tags.
vectors: Vec<(usize, Vec<f64>)>,
24+
}
25+
26+
impl TestBruteForceHelper {
27+
/// Creates a helper with no stored vectors.
fn new() -> Self {
28+
Self {
29+
vectors: Vec::new(),
30+
}
31+
}
32+
33+
/// Appends one `(tag, vector)` pair to the stored vectors.
fn push(&mut self, v: (usize, Vec<f64>)) {
34+
self.vectors.push(v);
35+
}
36+
37+
/// Returns the tags of at most `top_k` stored vectors, ordered by
/// ascending `Euclidean` distance to `query`.
///
/// Every stored vector is measured against `query`, so the cost grows
/// with the number of pushed vectors.
fn search(&self, query: &[f64], top_k: usize) -> Vec<usize> {
38+
let metric = Euclidean;
39+
// Pair each tag with its distance to `query`.
let mut candidates: Vec<(usize, u64)> = self
40+
.vectors
41+
.iter()
42+
.map(|v| (v.0.clone(), metric.distance(&query, &v.1.as_slice())))
43+
.collect_vec();
44+
45+
// Distances are annotated as `u64`, for which `partial_cmp` always yields
// `Some`, so the `Equal` fallback is not expected to trigger. `sort_by` is
// stable, so entries with equal distance keep their push order.
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
46+
47+
// Keep only the closest `top_k` entries and drop the distances.
candidates.into_iter().take(top_k).map(|v| v.0).collect()
48+
}
49+
}
50+
2151
fn test_hnsw() -> (
2252
Hnsw<Euclidean, &'static [f64], Pcg64, 12, 24>,
2353
Searcher<u64>,
54+
TestBruteForceHelper,
2455
) {
2556
let mut searcher = Searcher::default();
2657
let mut hnsw = Hnsw::new(Euclidean);
27-
58+
let mut helper = TestBruteForceHelper::new();
2859
let features = [
2960
&[0.0, 0.0, 0.0, 1.0],
3061
&[0.0, 0.0, 1.0, 0.0],
@@ -36,11 +67,12 @@ fn test_hnsw() -> (
3667
&[1.0, 0.0, 0.0, 1.0],
3768
];
3869

39-
for &feature in &features {
40-
hnsw.insert(feature, &mut searcher);
70+
for (index, feature) in features.iter().enumerate() {
71+
helper.push((index, feature.to_vec()));
72+
hnsw.insert(*feature, &mut searcher);
4173
}
4274

43-
(hnsw, searcher)
75+
(hnsw, searcher, helper)
4476
}
4577

4678
#[test]
@@ -50,7 +82,7 @@ fn insertion() {
5082

5183
#[test]
5284
fn nearest_neighbor() {
53-
let (hnsw, mut searcher) = test_hnsw();
85+
let (hnsw, mut searcher, helper) = test_hnsw();
5486
let searcher = &mut searcher;
5587
let mut neighbors = [Neighbor {
5688
index: !0,
@@ -101,4 +133,17 @@ fn nearest_neighbor() {
101133
}
102134
]
103135
);
136+
// Check that `nearest` does not panic for any `topk` and that its results match the brute-force search.
137+
for topk in 0..8 {
138+
let mut neighbors = vec![
139+
Neighbor {
140+
index: !0,
141+
distance: !0,
142+
};
143+
topk
144+
];
145+
hnsw.nearest(&&[0.0, 0.0, 0.0, 1.0][..], 24, searcher, &mut neighbors);
146+
let result = neighbors.iter().map(|item| item.index).collect_vec();
147+
assert_eq!(result, helper.search(&[0.0, 0.0, 0.0, 1.0], topk));
148+
}
104149
}

0 commit comments

Comments
 (0)