Skip to content

Commit d14e700

Browse files
mengdilinfacebook-github-bot
authored andcommitted
interrupt for NNDescent (facebookresearch#3432)
Summary: Addresses the issue in facebookresearch#3173 for `IndexNNDescent`, I see that there is already interrupt implemented for it's [search](https://fburl.com/code/iwn3tqic) API, so I looked into it's `add` API. For a given dataset nb = 10 mil, iter = 10, K = 32, d = 32 on a CPU only machine reveals that bulk of the cost comes from [nndescent](https://fburl.com/code/5rdb1p5o). For every iteration of `nndescent` takes around ~12 seconds, ~70-80% of the time is spent on `join` method (~10 seconds per iteration) and ~20-30% spent on `update` (~2 second per iteration). Adding the interrupt on the `join` should suffice on quickly terminating the program when users hit ctrl+C (happy to move the interrupt elsewhere if we think otherwise) Reviewed By: junjieqi, mdouze Differential Revision: D57300514
1 parent 4d06d70 commit d14e700

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

.circleci/config.yml

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
- run:
4949
name: Verify clang-format
5050
command: |
51+
clang-format-18 -- version
5152
git ls-files | grep -E '\.(cpp|h|cu|cuh)$' | xargs clang-format-18 -i
5253
if git diff --quiet; then
5354
echo "Formatting OK!"

faiss/impl/NNDescent.cpp

+13-8
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,20 @@ NNDescent::NNDescent(const int d, const int K) : K(K), d(d) {
154154
NNDescent::~NNDescent() {}
155155

156156
void NNDescent::join(DistanceComputer& qdis) {
157+
idx_t check_period = InterruptCallback::get_period_hint(d * search_L);
158+
for (idx_t i0 = 0; i0 < (idx_t)ntotal; i0 += check_period) {
159+
idx_t i1 = std::min(i0 + check_period, (idx_t)ntotal);
157160
#pragma omp parallel for default(shared) schedule(dynamic, 100)
158-
for (int n = 0; n < ntotal; n++) {
159-
graph[n].join([&](int i, int j) {
160-
if (i != j) {
161-
float dist = qdis.symmetric_dis(i, j);
162-
graph[i].insert(j, dist);
163-
graph[j].insert(i, dist);
164-
}
165-
});
161+
for (idx_t n = i0; n < i1; n++) {
162+
graph[n].join([&](int i, int j) {
163+
if (i != j) {
164+
float dist = qdis.symmetric_dis(i, j);
165+
graph[i].insert(j, dist);
166+
graph[j].insert(i, dist);
167+
}
168+
});
169+
}
170+
InterruptCallback::check();
166171
}
167172
}
168173

0 commit comments

Comments
 (0)