Skip to content

Commit 162e6ce

Browse files
add range_search() to IndexRefine (#4022)
Summary: This is very convenient to have `range_seach()` in `IndexRefine`. Unlike the plain `search()` method, `range_search()` just reevaluates the computed distances from the baseline index. The labels are not re-sorted according to new distances, because this is not listed as a requirement in a method description https://github.com/facebookresearch/faiss/blob/adb188411a98c3af5b7295c7016e5f46fee9eb07/faiss/Index.h#L150-L161 https://github.com/facebookresearch/faiss/blob/adb188411a98c3af5b7295c7016e5f46fee9eb07/faiss/impl/AuxIndexStructures.h#L35 Pull Request resolved: #4022 Reviewed By: mnorris11 Differential Revision: D66116082 Pulled By: gtwang01 fbshipit-source-id: 915aca2570d5863c876c9497d4c885e270b9b220
1 parent 9590ad2 commit 162e6ce

File tree

3 files changed

+97
-1
lines changed

3 files changed

+97
-1
lines changed

faiss/IndexRefine.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,45 @@ void IndexRefine::search(
166166
}
167167
}
168168

169+
void IndexRefine::range_search(
170+
idx_t n,
171+
const float* x,
172+
float radius,
173+
RangeSearchResult* result,
174+
const SearchParameters* params_in) const {
175+
const IndexRefineSearchParameters* params = nullptr;
176+
if (params_in) {
177+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
178+
FAISS_THROW_IF_NOT_MSG(
179+
params, "IndexRefine params have incorrect type");
180+
}
181+
182+
SearchParameters* base_index_params =
183+
(params != nullptr) ? params->base_index_params : nullptr;
184+
185+
base_index->range_search(n, x, radius, result, base_index_params);
186+
187+
#pragma omp parallel if (n > 1)
188+
{
189+
std::unique_ptr<DistanceComputer> dc(
190+
refine_index->get_distance_computer());
191+
192+
#pragma omp for
193+
for (idx_t i = 0; i < n; i++) {
194+
dc->set_query(x + i * d);
195+
196+
// reevaluate distances
197+
const size_t idx_start = result->lims[i];
198+
const size_t idx_end = result->lims[i + 1];
199+
200+
for (size_t j = idx_start; j < idx_end; j++) {
201+
const auto label = result->labels[j];
202+
result->distances[j] = (*dc)(label);
203+
}
204+
}
205+
}
206+
}
207+
169208
void IndexRefine::reconstruct(idx_t key, float* recons) const {
170209
refine_index->reconstruct(key, recons);
171210
}

faiss/IndexRefine.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ struct IndexRefine : Index {
5454
idx_t* labels,
5555
const SearchParameters* params = nullptr) const override;
5656

57+
void range_search(
58+
idx_t n,
59+
const float* x,
60+
float radius,
61+
RangeSearchResult* result,
62+
const SearchParameters* params = nullptr) const override;
63+
5764
// reconstruct is routed to the refine_index
5865
void reconstruct(idx_t key, float* recons) const override;
5966

tests/test_refine.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import unittest
99
import faiss
1010

11-
from faiss.contrib import datasets
11+
from faiss.contrib import datasets, evaluation
1212

1313

1414
class TestDistanceComputer(unittest.TestCase):
@@ -119,3 +119,53 @@ def test_rflat(self):
119119
def test_refine_sq8(self):
120120
# this case uses the IndexRefine class
121121
self.do_test("IVF8,PQ2x4np,Refine(SQ8)")
122+
123+
124+
class TestIndexRefineRangeSearch(unittest.TestCase):
125+
126+
def do_test(self, factory_string):
127+
d = 32
128+
radius = 8
129+
130+
ds = datasets.SyntheticDataset(d, 1024, 512, 256)
131+
132+
index = faiss.index_factory(d, factory_string)
133+
index.train(ds.get_train())
134+
index.add(ds.get_database())
135+
xq = ds.get_queries()
136+
xb = ds.get_database()
137+
138+
# perform a range_search
139+
lims_1, D1, I1 = index.range_search(xq, radius)
140+
141+
# create a baseline (FlatL2)
142+
index_flat = faiss.IndexFlatL2(d)
143+
index_flat.train(ds.get_train())
144+
index_flat.add(ds.get_database())
145+
146+
lims_ref, Dref, Iref = index_flat.range_search(xq, radius)
147+
148+
# add a refine index on top of the index
149+
index_r = faiss.IndexRefine(index, index_flat)
150+
lims_2, D2, I2 = index_r.range_search(xq, radius)
151+
152+
# validate: refined range_search() keeps indices untouched
153+
precision_1, recall_1 = evaluation.range_PR(lims_ref, Iref, lims_1, I1)
154+
155+
precision_2, recall_2 = evaluation.range_PR(lims_ref, Iref, lims_2, I2)
156+
157+
self.assertAlmostEqual(recall_1, recall_2)
158+
159+
# validate: refined range_search() updates distances, and new distances are correct L2 distances
160+
for iq in range(0, ds.nq):
161+
start_lim = lims_2[iq]
162+
end_lim = lims_2[iq + 1]
163+
for i_lim in range(start_lim, end_lim):
164+
idx = I2[i_lim]
165+
l2_dis = np.sum(np.square(xq[iq : iq + 1,] - xb[idx : idx + 1,]))
166+
167+
self.assertAlmostEqual(l2_dis, D2[i_lim], places=4)
168+
169+
170+
def test_refine_1(self):
171+
self.do_test("SQ4")

0 commit comments

Comments
 (0)