Skip to content

Commit 9b9b023

Browse files
add range_search() to IndexRefine
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent 5637bb8 commit 9b9b023

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

faiss/IndexRefine.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,45 @@ void IndexRefine::search(
166166
}
167167
}
168168

169+
void IndexRefine::range_search(
170+
idx_t n,
171+
const float* x,
172+
float radius,
173+
RangeSearchResult* result,
174+
const SearchParameters* params_in) const {
175+
const IndexRefineSearchParameters* params = nullptr;
176+
if (params_in) {
177+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
178+
FAISS_THROW_IF_NOT_MSG(
179+
params, "IndexRefine params have incorrect type");
180+
}
181+
182+
SearchParameters* base_index_params =
183+
(params != nullptr) ? params->base_index_params : nullptr;
184+
185+
base_index->range_search(n, x, radius, result, base_index_params);
186+
187+
#pragma omp parallel if (n > 1)
188+
{
189+
std::unique_ptr<DistanceComputer> dc(
190+
refine_index->get_distance_computer());
191+
192+
#pragma omp for
193+
for (idx_t i = 0; i < n; i++) {
194+
dc->set_query(x + i * d);
195+
196+
// reevaluate distances
197+
const size_t idx_start = result->lims[i];
198+
const size_t idx_end = result->lims[i + 1];
199+
200+
for (size_t j = idx_start; j < idx_end; j++) {
201+
const auto label = result->labels[j];
202+
result->distances[j] = (*dc)(label);
203+
}
204+
}
205+
}
206+
}
207+
169208
void IndexRefine::reconstruct(idx_t key, float* recons) const {
170209
refine_index->reconstruct(key, recons);
171210
}

faiss/IndexRefine.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ struct IndexRefine : Index {
5454
idx_t* labels,
5555
const SearchParameters* params = nullptr) const override;
5656

57+
void range_search(
58+
idx_t n,
59+
const float* x,
60+
float radius,
61+
RangeSearchResult* result,
62+
const SearchParameters* params = nullptr) const override;
63+
5764
// reconstruct is routed to the refine_index
5865
void reconstruct(idx_t key, float* recons) const override;
5966

0 commit comments

Comments
 (0)