Skip to content

Commit d0d3af7

Browse files
add range_search() to IndexRefine
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent adb1884 commit d0d3af7

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

faiss/IndexRefine.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,45 @@ void IndexRefine::search(
168168
}
169169
}
170170

171+
void IndexRefine::range_search(
172+
idx_t n,
173+
const float* x,
174+
float radius,
175+
RangeSearchResult* result,
176+
const SearchParameters* params_in) const {
177+
const IndexRefineSearchParameters* params = nullptr;
178+
if (params_in) {
179+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
180+
FAISS_THROW_IF_NOT_MSG(
181+
params, "IndexRefine params have incorrect type");
182+
}
183+
184+
SearchParameters* base_index_params =
185+
(params != nullptr) ? params->base_index_params : nullptr;
186+
187+
base_index->range_search(n, x, radius, result, base_index_params);
188+
189+
#pragma omp parallel if (n > 1)
190+
{
191+
std::unique_ptr<DistanceComputer> dc(
192+
refine_index->get_distance_computer());
193+
194+
#pragma omp for
195+
for (idx_t i = 0; i < n; i++) {
196+
dc->set_query(x + i * d);
197+
198+
// reevaluate distances
199+
const size_t idx_start = result->lims[i];
200+
const size_t idx_end = result->lims[i + 1];
201+
202+
for (size_t j = idx_start; j < idx_end; j++) {
203+
const auto label = result->labels[j];
204+
result->distances[j] = (*dc)(label);
205+
}
206+
}
207+
}
208+
}
209+
171210
void IndexRefine::reconstruct(idx_t key, float* recons) const {
172211
refine_index->reconstruct(key, recons);
173212
}

faiss/IndexRefine.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ struct IndexRefine : Index {
5454
idx_t* labels,
5555
const SearchParameters* params = nullptr) const override;
5656

57+
void range_search(
58+
idx_t n,
59+
const float* x,
60+
float radius,
61+
RangeSearchResult* result,
62+
const SearchParameters* params = nullptr) const override;
63+
5764
// reconstruct is routed to the refine_index
5865
void reconstruct(idx_t key, float* recons) const override;
5966

0 commit comments

Comments
 (0)