Commit 399388e

mengdilin authored and facebook-github-bot committed
interrupt for NNDescent (facebookresearch#3432)
Summary: Addresses the issue in facebookresearch#3173 for `IndexNNDescent`. An interrupt is already implemented for its [search](https://fburl.com/code/iwn3tqic) API, so I looked into its `add` API. A run on a CPU-only machine with a dataset of nb = 10 million, iter = 10, K = 32, d = 32 shows that the bulk of the cost comes from [nndescent](https://fburl.com/code/5rdb1p5o). Each iteration of `nndescent` takes about 12 seconds: 70-80% of the time is spent in the `join` method (about 10 seconds per iteration) and 20-30% in `update` (about 2 seconds per iteration). Adding the interrupt to `join` should suffice to terminate the program quickly when users hit Ctrl+C (happy to move the interrupt elsewhere if we think otherwise).

Reviewed By: junjieqi, mdouze

Differential Revision: D57300514
1 parent 4d06d70 commit 399388e

3 files changed: +59 -15 lines changed

.circleci/config.yml

+3 -7
@@ -39,16 +39,12 @@ jobs:
           name: Install clang-format
           command: |
             apt-get update -y
-            apt-get install -y wget
-            apt install -y lsb-release wget software-properties-common gnupg
-            wget https://apt.llvm.org/llvm.sh
-            chmod u+x llvm.sh
-            ./llvm.sh 18
-            apt-get install -y git-core clang-format-18
+            apt-get install -y curl tar gzip
+            bash .circleci/setup-clang-format.sh
       - run:
           name: Verify clang-format
           command: |
-            git ls-files | grep -E '\.(cpp|h|cu|cuh)$' | xargs clang-format-18 -i
+            git ls-files | grep -E '\.(cpp|h|cu|cuh)$' | xargs clang-format -i
             if git diff --quiet; then
               echo "Formatting OK!"
             else

.circleci/setup-clang-format.sh

+43
@@ -0,0 +1,43 @@
+#!/bin/bash
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+set -euo pipefail
+# https://reproducible-builds.org/docs/archives/
+deterministic_tar_gz() {
+  # Use year 2030 to thwart tmpreaper.
+  tar \
+    --sort=name \
+    --mtime=2030-01-01T00:00:00Z \
+    --owner=0 --group=0 --numeric-owner \
+    -cf- \
+    "${@:2}" \
+    | gzip -9n \
+    > "$1"
+}
+
+# To curl from devservers
+if host -W 1 fwdproxy >/dev/null; then
+  curl() { HTTP_PROXY=fwdproxy:8080 HTTPS_PROXY=fwdproxy:8080 command curl "$@"; }
+fi
+
+### CHANGE THESE VARIABLES WHEN UPDATING ###
+
+# https://pypi.org/project/clang-format/18.1.3
+
+
+LINUX_X86_64_URL=https://files.pythonhosted.org/packages/d5/9c/4f3806d20397790b3cd80aef89d295bf399581804f5c5758b6207e54e902/clang_format-18.1.3-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+
+###
+NAME=${LINUX_X86_64_URL##*/}
+
+set -x
+curl -L -o "$NAME.zip" "$LINUX_X86_64_URL"
+mkdir "clang-format"
+unzip -q -d "clang-format" "$NAME.zip"
+echo "clang-format/clang_format/data/bin"
+deterministic_tar_gz "$NAME.tar.gz" -C "clang-format/clang_format/data/bin" "clang-format"
+tar tvf "$NAME.tar.gz"
+
+echo "$PWD"
+export PATH="${PWD}/clang-format/clang_format/data/bin:$PATH"
+echo $PATH

faiss/impl/NNDescent.cpp

+13 -8
@@ -154,15 +154,20 @@ NNDescent::NNDescent(const int d, const int K) : K(K), d(d) {
 NNDescent::~NNDescent() {}
 
 void NNDescent::join(DistanceComputer& qdis) {
+    idx_t check_period = InterruptCallback::get_period_hint(d * search_L);
+    for (idx_t i0 = 0; i0 < (idx_t)ntotal; i0 += check_period) {
+        idx_t i1 = std::min(i0 + check_period, (idx_t)ntotal);
 #pragma omp parallel for default(shared) schedule(dynamic, 100)
-    for (int n = 0; n < ntotal; n++) {
-        graph[n].join([&](int i, int j) {
-            if (i != j) {
-                float dist = qdis.symmetric_dis(i, j);
-                graph[i].insert(j, dist);
-                graph[j].insert(i, dist);
-            }
-        });
+        for (idx_t n = i0; n < i1; n++) {
+            graph[n].join([&](int i, int j) {
+                if (i != j) {
+                    float dist = qdis.symmetric_dis(i, j);
+                    graph[i].insert(j, dist);
+                    graph[j].insert(i, dist);
+                }
+            });
+        }
+        InterruptCallback::check();
     }
 }
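
The change to `join` follows a common pattern: split a long loop into chunks and poll for an interrupt between chunks, so the hot inner loop stays free of checks. The snippet below is a minimal standalone sketch of that pattern, not the faiss implementation. `InterruptFlag` and `chunked_loop` are hypothetical names introduced only for illustration, and the sketch assumes that the check signals an interrupt by throwing an exception.

```cpp
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for an interrupt facility; this is NOT the faiss API.
struct InterruptFlag {
    std::atomic<bool> requested{false};

    // Throws if an interrupt has been requested.
    void check() const {
        if (requested.load()) {
            throw std::runtime_error("computation interrupted");
        }
    }
};

// Runs work(i) for every i in [0, n), in chunks of check_period items,
// and polls the flag once after each chunk.
// Returns the number of items processed when no interrupt occurs.
template <typename Work>
int64_t chunked_loop(
        int64_t n,
        int64_t check_period,
        const InterruptFlag& flag,
        Work work) {
    int64_t done = 0;
    for (int64_t i0 = 0; i0 < n; i0 += check_period) {
        int64_t i1 = std::min(i0 + check_period, n);
        // In the commit, this inner loop is the OpenMP parallel region.
        for (int64_t i = i0; i < i1; i++) {
            work(i);
            done++;
        }
        // Polled outside the inner loop, so a chunk always runs to completion.
        flag.check();
    }
    return done;
}
```

With this structure, an interrupt requested in the middle of a chunk takes effect only once that chunk has finished, so the chunk size trades responsiveness against polling overhead. A caller would wrap `chunked_loop` in a `try` block and catch `std::runtime_error` to handle the interruption.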
