|
8 | 8 | import unittest
|
9 | 9 | import faiss
|
10 | 10 |
|
11 |
| -from faiss.contrib import datasets |
| 11 | +from faiss.contrib import datasets, evaluation |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class TestDistanceComputer(unittest.TestCase):
|
@@ -119,3 +119,53 @@ def test_rflat(self):
|
119 | 119 | def test_refine_sq8(self):
|
120 | 120 | # this case uses the IndexRefine class
|
121 | 121 | self.do_test("IVF8,PQ2x4np,Refine(SQ8)")
|
| 122 | + |
| 123 | + |
| 124 | +class TestIndexRefineRangeSearch(unittest.TestCase): |
| 125 | + |
| 126 | + def do_test(self, factory_string): |
| 127 | + d = 32 |
| 128 | + radius = 8 |
| 129 | + |
| 130 | + ds = datasets.SyntheticDataset(d, 1024, 512, 256) |
| 131 | + |
| 132 | + index = faiss.index_factory(d, factory_string) |
| 133 | + index.train(ds.get_train()) |
| 134 | + index.add(ds.get_database()) |
| 135 | + xq = ds.get_queries() |
| 136 | + xb = ds.get_database() |
| 137 | + |
| 138 | + # perform a range_search |
| 139 | + lims_1, D1, I1 = index.range_search(xq, radius) |
| 140 | + |
| 141 | + # create a baseline (FlatL2) |
| 142 | + index_flat = faiss.IndexFlatL2(d) |
| 143 | + index_flat.train(ds.get_train()) |
| 144 | + index_flat.add(ds.get_database()) |
| 145 | + |
| 146 | + lims_ref, Dref, Iref = index_flat.range_search(xq, radius) |
| 147 | + |
| 148 | + # add a refine index on top of the index |
| 149 | + index_r = faiss.IndexRefine(index, index_flat) |
| 150 | + lims_2, D2, I2 = index_r.range_search(xq, radius) |
| 151 | + |
| 152 | + # validate: refined range_search() keeps indices untouched |
| 153 | + precision_1, recall_1 = evaluation.range_PR(lims_ref, Iref, lims_1, I1) |
| 154 | + |
| 155 | + precision_2, recall_2 = evaluation.range_PR(lims_ref, Iref, lims_2, I2) |
| 156 | + |
| 157 | + self.assertAlmostEqual(recall_1, recall_2) |
| 158 | + |
| 159 | + # validate: refined range_search() updates distances, and new distances are correct L2 distances |
| 160 | + for iq in range(0, ds.nq): |
| 161 | + start_lim = lims_2[iq] |
| 162 | + end_lim = lims_2[iq + 1] |
| 163 | + for i_lim in range(start_lim, end_lim): |
| 164 | + idx = I2[i_lim] |
| 165 | + l2_dis = np.sum(np.square(xq[iq : iq + 1,] - xb[idx : idx + 1,])) |
| 166 | + |
| 167 | + self.assertAlmostEqual(l2_dis, D2[i_lim], places=4) |
| 168 | + |
| 169 | + |
| 170 | + def test_refine_1(self): |
| 171 | + self.do_test("SQ4") |
0 commit comments