Skip to content

Commit

Permalink
working iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
cole_foster@brown.edu committed Aug 6, 2024
1 parent 3b1f0fa commit 05b418f
Show file tree
Hide file tree
Showing 11 changed files with 639 additions and 786 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ jobs:
conda activate hsp
conda install matplotlib
pip install h5py
cd hnswlib/
cd src/
pip install .
- name: Run benchmark
shell: bash -el {0}
run: |
conda activate hsp
python3 search/search.py
python3 eval/eval.py
python3 eval/plot.py res.csv --size 300K
python3 eval-2023/eval.py
python3 eval-2023/plot.py res.csv --size 300K
- uses: actions/upload-artifact@v3
with:
name: Results on 300k
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
data/
result/
results/
dev/
slurm*
run.sh
run.sh
res.csv
result_*
78 changes: 78 additions & 0 deletions eval-2023/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import h5py
import numpy as np
import os
import csv
from pathlib import Path
from urllib.request import urlretrieve

data_directory = "/users/cfoste18/scratch/datasets/LAION"

def download(src, dst):
if not os.path.exists(dst):
os.makedirs(Path(dst).parent, exist_ok=True)
print('downloading %s -> %s...' % (src, dst))
urlretrieve(src, dst)

def get_groundtruth(size="100K"):
url = f"http://ingeotec.mx/~sadit/sisap2024-data/gold-standard-dbsize={size}--public-queries-2024-laion2B-en-clip768v2-n=10k.h5"
#url = f"https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge/laion2B-en-public-gold-standard-v2-{size}.h5"

out_fn = os.path.join(data_directory, f"groundtruth-{size}.h5")
download(url, out_fn)
gt_f = h5py.File(out_fn, "r")
true_I = np.array(gt_f['knns'])
gt_f.close()
return true_I

def get_all_results(dirname):
for root, _, files in os.walk(dirname):
for fn in files:
if os.path.splitext(fn)[-1] != ".h5":
continue
try:
f = h5py.File(os.path.join(root, fn), "r")
yield f
f.close()
except:
print("Unable to read", fn)

def get_recall(I, gt, k):
assert k <= I.shape[1]
assert len(I) == len(gt)

n = len(I)
recall = 0
for i in range(n):
recall += len(set(I[i, :k]) & set(gt[i, :k]))
return recall / (n * k)

def return_h5_str(f, param):
if param not in f:
return 0
x = f[param][()]
if type(x) == np.bytes_:
return x.decode()
return x


if __name__ == "__main__":
true_I_cache = {}

columns = ["data", "size", "algo", "buildtime", "querytime", "params", "recall"]

with open('res.csv', 'w', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=columns)
writer.writeheader()
for res in get_all_results("result"):
try:
size = res.attrs["size"]
d = dict(res.attrs)
except:
size = res["size"][()].decode()
d = {k: return_h5_str(res, k) for k in columns}
if size not in true_I_cache:
true_I_cache[size] = get_groundtruth(size)
recall = get_recall(np.array(res["knns"]), true_I_cache[size], 10)
d['recall'] = recall
print(d["data"], d["algo"], d["params"], "=>", recall)
writer.writerow(d)
1 change: 1 addition & 0 deletions eval-2023/note.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
I am unsure how to use the recent versions of the "eval" submodule, resorting to the 2023 version
102 changes: 102 additions & 0 deletions eval-2023/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# This is based on https://github.com/matsui528/annbench/blob/main/plot.py
import argparse
import csv
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import sys
from itertools import cycle


marker = cycle(('p', '^', 'h', 'x', 'o', 's', '*', '+', 'D', '1', 'X'))
linestyle = cycle((':', '-', '--'))

def draw(lines, xlabel, ylabel, title, filename, with_ctrl, width, height):
"""
Visualize search results and save them as an image
Args:
lines (list): search results. list of dict.
xlabel (str): label of x-axis, usually "recall"
ylabel (str): label of y-axis, usually "query per sec"
title (str): title of the result_img
filename (str): output file name of image
with_ctrl (bool): show control parameters or not
width (int): width of the figure
height (int): height of the figure
"""
plt.figure(figsize=(width, height))

for line in lines:
for key in ["xs", "ys", "label", "ctrls"]:
assert key in line

for line in lines:
plt.plot(line["xs"], line["ys"], label=line["label"], marker=next(marker), linestyle=next(linestyle))
if with_ctrl:
for x, y, ctrl in zip(line["xs"], line["ys"], line["ctrls"]):
plt.annotate(text=str(ctrl), xy=(x, y), xytext=(x, y+50))

plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.grid(which="both")
plt.yscale("log")
# plt.xlim([0.75, 1.0])
plt.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
plt.title(title)
plt.savefig(filename, bbox_inches='tight')
plt.cla()

def get_pareto_frontier(line):
data = sorted(zip(line["ys"], line["xs"], line["ctrls"]),reverse=True)
line["xs"] = []
line["ys"] = []
line["ctrls"] = []

cur = 0
for y, x, label in data:
if x > cur:
cur = x
line["xs"].append(x)
line["ys"].append(y)
line["ctrls"].append(label)

return line

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--size",
default="100K"
)
parser.add_argument("csvfile")
args = parser.parse_args()

with open(args.csvfile, newline="") as csvfile:
reader = csv.DictReader(csvfile)
data = list(reader)

lines = {}
for res in data:
if res["size"] != args.size:
continue
dataset = res["data"]
algo = res["algo"]
# label = dataset + algo
label = algo
if label not in lines:
lines[label] = {
"xs": [],
"ys": [],
"ctrls": [],
"label": label,
}
lines[label]["xs"].append(float(res["recall"]))
lines[label]["ys"].append(10000/float(res["querytime"])) # FIX query size hardcoded
try:
run_identifier = res["params"].split("query=")[1]
except:
run_identifier = res["params"]
lines[label]["ctrls"].append(run_identifier)

draw([get_pareto_frontier(line) for line in lines.values()],
"Recall @ 30", "QPS (1/s)", "Result", f"result_{args.size}.png", True, 10, 8)
130 changes: 0 additions & 130 deletions search/search-hnsw.py

This file was deleted.

Loading

0 comments on commit 05b418f

Please sign in to comment.