-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
cole_foster@brown.edu
committed
Aug 6, 2024
1 parent
3b1f0fa
commit 05b418f
Showing
11 changed files
with
639 additions
and
786 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
data/ | ||
result/ | ||
results/ | ||
dev/ | ||
slurm* | ||
run.sh | ||
run.sh | ||
res.csv | ||
result_* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import h5py | ||
import numpy as np | ||
import os | ||
import csv | ||
from pathlib import Path | ||
from urllib.request import urlretrieve | ||
|
||
data_directory = "/users/cfoste18/scratch/datasets/LAION" | ||
|
||
def download(src, dst): | ||
if not os.path.exists(dst): | ||
os.makedirs(Path(dst).parent, exist_ok=True) | ||
print('downloading %s -> %s...' % (src, dst)) | ||
urlretrieve(src, dst) | ||
|
||
def get_groundtruth(size="100K"): | ||
url = f"http://ingeotec.mx/~sadit/sisap2024-data/gold-standard-dbsize={size}--public-queries-2024-laion2B-en-clip768v2-n=10k.h5" | ||
#url = f"https://sisap-23-challenge.s3.amazonaws.com/SISAP23-Challenge/laion2B-en-public-gold-standard-v2-{size}.h5" | ||
|
||
out_fn = os.path.join(data_directory, f"groundtruth-{size}.h5") | ||
download(url, out_fn) | ||
gt_f = h5py.File(out_fn, "r") | ||
true_I = np.array(gt_f['knns']) | ||
gt_f.close() | ||
return true_I | ||
|
||
def get_all_results(dirname): | ||
for root, _, files in os.walk(dirname): | ||
for fn in files: | ||
if os.path.splitext(fn)[-1] != ".h5": | ||
continue | ||
try: | ||
f = h5py.File(os.path.join(root, fn), "r") | ||
yield f | ||
f.close() | ||
except: | ||
print("Unable to read", fn) | ||
|
||
def get_recall(I, gt, k): | ||
assert k <= I.shape[1] | ||
assert len(I) == len(gt) | ||
|
||
n = len(I) | ||
recall = 0 | ||
for i in range(n): | ||
recall += len(set(I[i, :k]) & set(gt[i, :k])) | ||
return recall / (n * k) | ||
|
||
def return_h5_str(f, param): | ||
if param not in f: | ||
return 0 | ||
x = f[param][()] | ||
if type(x) == np.bytes_: | ||
return x.decode() | ||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
true_I_cache = {} | ||
|
||
columns = ["data", "size", "algo", "buildtime", "querytime", "params", "recall"] | ||
|
||
with open('res.csv', 'w', newline='') as csvfile: | ||
writer = csv.DictWriter(csvfile, fieldnames=columns) | ||
writer.writeheader() | ||
for res in get_all_results("result"): | ||
try: | ||
size = res.attrs["size"] | ||
d = dict(res.attrs) | ||
except: | ||
size = res["size"][()].decode() | ||
d = {k: return_h5_str(res, k) for k in columns} | ||
if size not in true_I_cache: | ||
true_I_cache[size] = get_groundtruth(size) | ||
recall = get_recall(np.array(res["knns"]), true_I_cache[size], 10) | ||
d['recall'] = recall | ||
print(d["data"], d["algo"], d["params"], "=>", recall) | ||
writer.writerow(d) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
I am unsure how to use the recent versions of the "eval" submodule, resorting to the 2023 version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# This is based on https://github.com/matsui528/annbench/blob/main/plot.py | ||
import argparse | ||
import csv | ||
import matplotlib | ||
matplotlib.use('agg') | ||
import matplotlib.pyplot as plt | ||
import sys | ||
from itertools import cycle | ||
|
||
|
||
marker = cycle(('p', '^', 'h', 'x', 'o', 's', '*', '+', 'D', '1', 'X')) | ||
linestyle = cycle((':', '-', '--')) | ||
|
||
def draw(lines, xlabel, ylabel, title, filename, with_ctrl, width, height): | ||
""" | ||
Visualize search results and save them as an image | ||
Args: | ||
lines (list): search results. list of dict. | ||
xlabel (str): label of x-axis, usually "recall" | ||
ylabel (str): label of y-axis, usually "query per sec" | ||
title (str): title of the result_img | ||
filename (str): output file name of image | ||
with_ctrl (bool): show control parameters or not | ||
width (int): width of the figure | ||
height (int): height of the figure | ||
""" | ||
plt.figure(figsize=(width, height)) | ||
|
||
for line in lines: | ||
for key in ["xs", "ys", "label", "ctrls"]: | ||
assert key in line | ||
|
||
for line in lines: | ||
plt.plot(line["xs"], line["ys"], label=line["label"], marker=next(marker), linestyle=next(linestyle)) | ||
if with_ctrl: | ||
for x, y, ctrl in zip(line["xs"], line["ys"], line["ctrls"]): | ||
plt.annotate(text=str(ctrl), xy=(x, y), xytext=(x, y+50)) | ||
|
||
plt.xlabel(xlabel) | ||
plt.ylabel(ylabel) | ||
plt.grid(which="both") | ||
plt.yscale("log") | ||
# plt.xlim([0.75, 1.0]) | ||
plt.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left") | ||
plt.title(title) | ||
plt.savefig(filename, bbox_inches='tight') | ||
plt.cla() | ||
|
||
def get_pareto_frontier(line): | ||
data = sorted(zip(line["ys"], line["xs"], line["ctrls"]),reverse=True) | ||
line["xs"] = [] | ||
line["ys"] = [] | ||
line["ctrls"] = [] | ||
|
||
cur = 0 | ||
for y, x, label in data: | ||
if x > cur: | ||
cur = x | ||
line["xs"].append(x) | ||
line["ys"].append(y) | ||
line["ctrls"].append(label) | ||
|
||
return line | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--size", | ||
default="100K" | ||
) | ||
parser.add_argument("csvfile") | ||
args = parser.parse_args() | ||
|
||
with open(args.csvfile, newline="") as csvfile: | ||
reader = csv.DictReader(csvfile) | ||
data = list(reader) | ||
|
||
lines = {} | ||
for res in data: | ||
if res["size"] != args.size: | ||
continue | ||
dataset = res["data"] | ||
algo = res["algo"] | ||
# label = dataset + algo | ||
label = algo | ||
if label not in lines: | ||
lines[label] = { | ||
"xs": [], | ||
"ys": [], | ||
"ctrls": [], | ||
"label": label, | ||
} | ||
lines[label]["xs"].append(float(res["recall"])) | ||
lines[label]["ys"].append(10000/float(res["querytime"])) # FIX query size hardcoded | ||
try: | ||
run_identifier = res["params"].split("query=")[1] | ||
except: | ||
run_identifier = res["params"] | ||
lines[label]["ctrls"].append(run_identifier) | ||
|
||
draw([get_pareto_frontier(line) for line in lines.values()], | ||
"Recall @ 30", "QPS (1/s)", "Result", f"result_{args.size}.png", True, 10, 8) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.