Skip to content

Commit c080f3c

Browse files
Michael Norrisfacebook-github-bot
Michael Norris
authored andcommitted
Add more unit testing for IndexHNSW [1/n] (facebookresearch#4054)
Summary: Part 1 of more HNSW unit tests Reviewed By: junjieqi Differential Revision: D66690398
1 parent 1ac6f37 commit c080f3c

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

tests/test_graph_based.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ def test_hnsw_unbounded_queue(self):
7373

7474
self.io_and_retest(index, Dhnsw, Ihnsw)
7575

76+
def test_hnsw_no_init_level0(self):
77+
d = self.xq.shape[1]
78+
79+
index = faiss.IndexHNSWFlat(d, 16)
80+
index.init_level0 = False
81+
index.add(self.xb)
82+
Dhnsw, Ihnsw = index.search(self.xq, 1)
83+
84+
self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 25)
85+
86+
self.io_and_retest(index, Dhnsw, Ihnsw)
87+
7688
def io_and_retest(self, index, Dhnsw, Ihnsw):
7789
index2 = faiss.deserialize_index(faiss.serialize_index(index))
7890
Dhnsw2, Ihnsw2 = index2.search(self.xq, 1)
@@ -175,16 +187,31 @@ def test_abs_inner_product(self):
175187
xb = self.xb - self.xb.mean(axis=0) # need to be centered to give interesting directions
176188
xq = self.xq - self.xq.mean(axis=0)
177189
Dref, Iref = faiss.knn(xq, xb, 10, faiss.METRIC_ABS_INNER_PRODUCT)
178-
190+
179191
index = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_ABS_INNER_PRODUCT)
180192
index.add(xb)
181193
Dnew, Inew = index.search(xq, 10)
182194

183195
inter = faiss.eval_intersection(Iref, Inew)
184196
# 4769 vs. 500*10
185197
self.assertGreater(inter, Iref.size * 0.9)
186-
187-
198+
199+
def test_hnsw_reset(self):
200+
d = self.xb.shape[1]
201+
index_flat = faiss.IndexFlat(d)
202+
index_flat.add(self.xb)
203+
self.assertEqual(index_flat.ntotal, self.xb.shape[0])
204+
index_hnsw = faiss.IndexHNSW(index_flat)
205+
index_hnsw.add(self.xb)
206+
# * 2 because we add to storage twice. This is just for testing
207+
# that storage gets cleared correctly.
208+
self.assertEqual(index_hnsw.ntotal, self.xb.shape[0] * 2)
209+
210+
index_hnsw.reset()
211+
212+
self.assertEqual(index_flat.ntotal, 0)
213+
self.assertEqual(index_hnsw.ntotal, 0)
214+
188215
class Issue3684(unittest.TestCase):
189216

190217
def test_issue3684(self):

0 commit comments

Comments
 (0)