Skip to content

Commit dc55e11

Browse files
mengdilin and facebook-github-bot
authored and committed
add hnsw flat benchmark (facebookresearch#3857)
Summary: Pull Request resolved: facebookresearch#3857. Adds benchmarking for HNSW flat. ServiceLab requires us to register our Python benchmarks with its custom Python function in order to export the metrics correctly. I decided to place the ServiceLab-specific code inside the `faiss/perf_tests/servicelab` folder so as not to expose it to open source, while the actual benchmarking logic for HNSW lives in `faiss/perf_tests/bench_hnsw.py`, which will be exposed to open source. Reviewed By: kuarora Differential Revision: D62316706 fbshipit-source-id: 6f88ed70ae78fa309a347371645fb012e25b55da
1 parent d104275 commit dc55e11

File tree

1 file changed

+221
-0
lines changed

1 file changed

+221
-0
lines changed

faiss/perf_tests/bench_hnsw.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import argparse
2+
import resource
3+
import time
4+
from contextlib import contextmanager
5+
from dataclasses import dataclass
6+
from typing import Dict, Generator, List, Optional
7+
8+
import faiss # @manual=//faiss/python:pyfaiss
9+
import numpy as np
10+
from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib
11+
Dataset,
12+
SyntheticDataset,
13+
)
14+
15+
# Microseconds per second; used to convert rusage/perf_counter seconds
# into the integer *_time_us counters exported below.
US_IN_S = 1_000_000
16+
17+
18+
@dataclass
class PerfCounters:
    """Timing counters (in seconds) captured around one benchmarked phase."""

    # Elapsed real time, from a time.perf_counter() delta.
    wall_time_s: float = 0.0
    # CPU time spent in user mode, from a getrusage ru_utime delta.
    user_time_s: float = 0.0
    # CPU time spent in kernel mode, from a getrusage ru_stime delta.
    system_time_s: float = 0.0
23+
24+
25+
@contextmanager
def timed_execution() -> Generator[PerfCounters, None, None]:
    """Yield a PerfCounters that is populated when the with-block exits.

    Wall-clock time is taken from time.perf_counter(); user and system
    CPU time come from resource.getrusage(RUSAGE_SELF). The fields of
    the yielded object are only filled in after the block completes.
    """
    counters = PerfCounters()
    start_wall = time.perf_counter()
    start_rusage = resource.getrusage(resource.RUSAGE_SELF)
    yield counters
    end_wall = time.perf_counter()
    end_rusage = resource.getrusage(resource.RUSAGE_SELF)
    counters.wall_time_s = end_wall - start_wall
    counters.user_time_s = end_rusage.ru_utime - start_rusage.ru_utime
    counters.system_time_s = end_rusage.ru_stime - start_rusage.ru_stime
36+
37+
38+
def is_perf_counter(key: str) -> bool:
    """Return True when *key* names a timing counter (``*_time_us``)."""
    suffix = "_time_us"
    return key.endswith(suffix)
40+
41+
42+
def accumulate_perf_counter(
    phase: str,
    t: PerfCounters,
    counters: Dict[str, int]
):
    """Record *t*'s timings in *counters* as integer microseconds.

    Keys are prefixed with *phase* (e.g. ``add_wall_time_us``) so the
    add and search measurements can live in one flat mapping.
    """
    samples = {
        "wall": t.wall_time_s,
        "user": t.user_time_s,
        "system": t.system_time_s,
    }
    for kind, seconds in samples.items():
        counters[f"{phase}_{kind}_time_us"] = int(seconds * US_IN_S)
50+
51+
52+
def run_on_dataset(
    ds: Dataset,
    M: int,
    num_threads: int,
    efSearch: int = 16,
    efConstruction: int = 40,
) -> Dict[str, int]:
    """Build an IndexHNSWFlat over *ds* and time the add and search phases.

    Args:
        ds: dataset supplying database and query vectors.
        M: HNSW graph degree.
        num_threads: OpenMP thread count handed to faiss.
        efSearch: search-time candidate-list size.
        efConstruction: construction-time candidate-list size.

    Returns:
        Flat mapping of counter name -> int: ``add_*_time_us`` and
        ``search_*_time_us`` timings plus the run parameters
        (nb, nq, efSearch, efConstruction, M, d).
    """
    xq = ds.get_queries()
    xb = ds.get_database()

    nb, d = xb.shape
    nq, d = xq.shape

    k = 10  # neighbors retrieved per query
    # pyre-ignore[16]: Module `faiss` has no attribute `omp_set_num_threads`.
    faiss.omp_set_num_threads(num_threads)
    index = faiss.IndexHNSWFlat(d, M)
    # Fix: honor the efConstruction argument. Previously this was
    # hard-coded to 40, so the parameter (and the counter reported
    # below) did not reflect the index actually being benchmarked.
    index.hnsw.efConstruction = efConstruction
    counters: Dict[str, int] = {}
    with timed_execution() as t:
        index.add(xb)
    accumulate_perf_counter("add", t, counters)
    counters["nb"] = nb

    index.hnsw.efSearch = efSearch
    with timed_execution() as t:
        # Results are discarded; only the timing matters here.
        index.search(xq, k)
    accumulate_perf_counter("search", t, counters)
    counters["nq"] = nq
    counters["efSearch"] = efSearch
    counters["efConstruction"] = efConstruction
    counters["M"] = M
    counters["d"] = d

    return counters
88+
89+
90+
def run(
    d: int,
    nb: int,
    nq: int,
    M: int,
    num_threads: int,
    efSearch: int = 16,
    efConstruction: int = 40,
) -> Dict[str, int]:
    """Benchmark HNSW on a synthetic L2 dataset with the given sizes.

    The dataset is generated with a fixed seed so repeated runs see
    identical vectors; nt=0 because HNSW-Flat needs no training set.
    """
    dataset = SyntheticDataset(d=d, nb=nb, nt=0, nq=nq, metric="L2", seed=1338)
    return run_on_dataset(
        dataset,
        M=M,
        num_threads=num_threads,
        efSearch=efSearch,
        efConstruction=efConstruction,
    )
107+
108+
109+
def _merge_counters(
    element: Dict[str, int], accu: Optional[Dict[str, int]] = None
) -> Dict[str, int]:
    """Fold one iteration's counters into the running totals.

    On the first call (*accu* is None) a copy of *element* is returned.
    On later calls, timing counters are summed into *accu* in place;
    non-timing entries (parameters and sizes) are left as-is.
    """
    if accu is None:
        return dict(element)
    assert accu.keys() <= element.keys(), (
        "Accu keys must be a subset of element keys: "
        f"{accu.keys()} not a subset of {element.keys()}"
    )
    for key in accu:
        if is_perf_counter(key):
            accu[key] += element[key]
    return accu
123+
124+
125+
def run_with_iterations(
    iterations: int,
    d: int,
    nb: int,
    nq: int,
    M: int,
    num_threads: int,
    efSearch: int = 16,
    efConstruction: int = 40,
) -> Dict[str, int]:
    """Run the benchmark *iterations* times and sum the timing counters.

    Non-timing counters keep the value from the final iteration, which
    is identical across iterations since the parameters are fixed.
    """
    totals: Optional[Dict[str, int]] = None
    for _ in range(iterations):
        single_run = run(
            d=d,
            nb=nb,
            nq=nq,
            M=M,
            num_threads=num_threads,
            efSearch=efSearch,
            efConstruction=efConstruction,
        )
        totals = _merge_counters(single_run, totals)
    assert totals is not None
    return totals
149+
150+
151+
def _accumulate_counters(
152+
element: Dict[str, int], accu: Optional[Dict[str, List[int]]] = None
153+
) -> Dict[str, List[int]]:
154+
if accu is None:
155+
accu = {key: [value] for key, value in element.items()}
156+
return accu
157+
else:
158+
assert accu.keys() <= element.keys(), (
159+
"Accu keys must be a subset of element keys: "
160+
f"{accu.keys()} not a subset of {element.keys()}"
161+
)
162+
for key in accu.keys():
163+
accu[key].append(element[key])
164+
return accu
165+
166+
167+
def main():
    """Parse CLI flags, optionally warm up, run the benchmark, print stats.

    Prints mean +/- standard deviation (in microseconds) of each timing
    counter across the requested number of repetitions.
    """
    parser = argparse.ArgumentParser(description="Benchmark HNSW")
    parser.add_argument("-M", "--M", type=int, required=True)
    parser.add_argument("-t", "--num-threads", type=int, required=True)
    parser.add_argument("-w", "--warm-up-iterations", type=int, default=0)
    parser.add_argument("-i", "--num-iterations", type=int, default=20)
    parser.add_argument("-r", "--num-repetitions", type=int, default=20)
    parser.add_argument("-s", "--ef-search", type=int, default=16)
    parser.add_argument("-c", "--ef-construction", type=int, default=40)
    parser.add_argument("-n", "--nb", type=int, default=5000)
    parser.add_argument("-q", "--nq", type=int, default=500)
    parser.add_argument("-d", "--d", type=int, default=128)
    args = parser.parse_args()

    # Shared parameters for the warm-up and the measured repetitions.
    bench_kwargs = dict(
        d=args.d,
        nb=args.nb,
        nq=args.nq,
        M=args.M,
        num_threads=args.num_threads,
        efSearch=args.ef_search,
        efConstruction=args.ef_construction,
    )

    if args.warm_up_iterations > 0:
        print(f"Warming up for {args.warm_up_iterations} iterations...")
        # Warm-up results are discarded.
        run_with_iterations(iterations=args.warm_up_iterations, **bench_kwargs)

    print(
        f"Running benchmark with dataset(nb={args.nb}, nq={args.nq}, "
        f"d={args.d}), M={args.M}, num_threads={args.num_threads}, "
        f"efSearch={args.ef_search}, efConstruction={args.ef_construction}"
    )
    samples = None
    for _ in range(args.num_repetitions):
        iteration_totals = run_with_iterations(
            iterations=args.num_iterations, **bench_kwargs
        )
        samples = _accumulate_counters(iteration_totals, samples)
    assert samples is not None
    for counter, values in samples.items():
        if is_perf_counter(counter):
            # Mean and spread across repetitions, in microseconds.
            print(
                "%s t=%.3f us (± %.4f)" % (
                    counter,
                    np.mean(values),
                    np.std(values)
                )
            )


if __name__ == "__main__":
    # Fix: the original file defined main() but never invoked it, so the
    # benchmark could not be run as a standalone script.
    main()

0 commit comments

Comments
 (0)