Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SearchParameters support for IndexBinaryFlat #4055

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fddbd3e
Add support for search params to IndexBinaryFlat
gustavz Dec 4, 2024
bf90fc1
update tests
gustavz Dec 4, 2024
c5c9cab
revert default param changes
gustavz Dec 4, 2024
d70e64b
add missing sel to hamming.h, add no heap test case, simplify valid_c…
gustavz Dec 12, 2024
cd8b5d3
fix import, no heap test, linting
gustavz Dec 16, 2024
b7e73e1
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Dec 16, 2024
e96237e
lint
gustavz Dec 19, 2024
b3db31b
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Dec 19, 2024
d92184e
remove default from definition
gustavz Dec 24, 2024
d40594f
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Dec 24, 2024
16027c6
add #include <faiss/impl/IDSelector.h> to hamming.cpp
gustavz Jan 14, 2025
54d8256
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Jan 14, 2025
983ae3d
add faiss namespace to IDSelector in hamming.cpp
gustavz Jan 21, 2025
44d7977
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Jan 21, 2025
fc1f675
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Jan 22, 2025
dfccc69
add faiss:: to IDSelector in hamming.h
gustavz Feb 13, 2025
a5ec86c
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Feb 13, 2025
56f13ea
add params to IndexBinary replacement_search and replacement_range_se…
gustavz Feb 14, 2025
95af571
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Feb 14, 2025
75df1b5
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
mnorris11 Feb 20, 2025
5b424cb
update tests
gustavz Feb 28, 2025
3b56fc2
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Feb 28, 2025
8be85c8
small test optimization
gustavz Feb 28, 2025
9d16d52
lint
gustavz Mar 11, 2025
1c52c9c
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gustavz Mar 11, 2025
9f851fc
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gtwang01 Mar 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ void IndexBinaryFlat::search(
int32_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
// Extract IDSelector from params if present
const IDSelector* sel = params ? params->sel : nullptr;
FAISS_THROW_IF_NOT(k > 0);

const idx_t block_size = query_batch_size;
Expand All @@ -60,7 +60,8 @@ void IndexBinaryFlat::search(
ntotal,
code_size,
/* ordered = */ true,
approx_topk_mode);
approx_topk_mode,
sel);
} else {
hammings_knn_mc(
x + s * code_size,
Expand All @@ -70,7 +71,8 @@ void IndexBinaryFlat::search(
k,
code_size,
distances + s * k,
labels + s * k);
labels + s * k,
sel);
}
}
}
Expand Down Expand Up @@ -107,9 +109,8 @@ void IndexBinaryFlat::range_search(
int radius,
RangeSearchResult* result,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
const IDSelector* sel = params ? params->sel : nullptr;
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result, sel);
}

} // namespace faiss
40 changes: 28 additions & 12 deletions faiss/utils/approx_topk_hamming/approx_topk_hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ struct HeapWithBucketsForHamming32<
// output distances
int* const __restrict bh_val,
// output indices, each being within [0, n) range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
// forward a call to bs_addn with 1 beam
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, sel);
}

static void bs_addn(
Expand All @@ -66,7 +68,9 @@ struct HeapWithBucketsForHamming32<
int* const __restrict bh_val,
// output indices, each being within [0, n_per_beam * beam_size)
// range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
//
using C = CMax<int, int64_t>;

Expand Down Expand Up @@ -95,11 +99,20 @@ struct HeapWithBucketsForHamming32<
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
uint32_t hamming_distances[8];
uint8_t valid_counter = 0;
for (size_t j8 = 0; j8 < 8; j8++) {
hamming_distances[j8] = hc.hamming(
binary_vectors +
(j8 + j * 8 + ip + n_per_beam * beam_index) *
code_size);
const uint32_t idx = j8 + j * 8 + ip + n_per_beam * beam_index;
if (!sel || sel->is_member(idx)) {
hamming_distances[j8] = hc.hamming(
binary_vectors + idx * code_size);
valid_counter++;
} else {
hamming_distances[j8] = std::numeric_limits<int32_t>::max();
}
}

if (valid_counter == 0) {
continue; // Skip if all vectors are filtered out
}

// loop. Compiler should get rid of unneeded ops
Expand Down Expand Up @@ -157,7 +170,8 @@ struct HeapWithBucketsForHamming32<
const auto value = min_distances_scalar[j8];
const auto index = min_indices_scalar[j8];

if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
if (value < std::numeric_limits<int32_t>::max() &&
C::cmp2(bh_val[0], value, bh_ids[0], index)) {
heap_replace_top<C>(
k, bh_val, bh_ids, value, index);
}
Expand All @@ -168,11 +182,13 @@ struct HeapWithBucketsForHamming32<
// process leftovers
for (uint32_t ip = nb; ip < n_per_beam; ip++) {
const auto index = ip + n_per_beam * beam_index;
const auto value =
hc.hamming(binary_vectors + (index)*code_size);
if (!sel || sel->is_member(index)) {
const auto value =
hc.hamming(binary_vectors + (index)*code_size);

if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
}
}
}
}
Expand Down
41 changes: 27 additions & 14 deletions faiss/utils/hamming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ void hammings_knn_hc(
size_t n2,
bool order = true,
bool init_heap = true,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK) {
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK
const IDSelector* sel = nullptr) {
size_t k = ha->k;
if (init_heap)
ha->heapify();
Expand Down Expand Up @@ -205,7 +206,7 @@ void hammings_knn_hc(
NB, \
BD, \
HammingComputer>:: \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_); \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, sel); \
break;

switch (approx_topk_mode) {
Expand All @@ -215,6 +216,9 @@ void hammings_knn_hc(
HANDLE_APPROX(32, 2)
default: {
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
if (sel && !sel->is_member(j)) {
continue;
}
dis = hc.hamming(bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_replace_top<hamdis_t>(
Expand All @@ -239,7 +243,8 @@ void hammings_knn_mc(
size_t nb,
size_t k,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const IDSelector* sel) {
const int nBuckets = bytes_per_code * 8 + 1;
std::vector<int> all_counters(na * nBuckets, 0);
std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
Expand All @@ -260,7 +265,9 @@ void hammings_knn_mc(
#pragma omp parallel for
for (int64_t i = 0; i < na; ++i) {
for (size_t j = j0; j < j1; ++j) {
cs[i].update_counter(b + j * bytes_per_code, j);
if (!sel || sel->is_member(j)) {
cs[i].update_counter(b + j * bytes_per_code, j);
}
}
}
}
Expand Down Expand Up @@ -292,7 +299,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* res) {
RangeSearchResult* res,
const IDSelector* sel) {
#pragma omp parallel
{
RangeSearchPartialResult pres(res);
Expand All @@ -304,9 +312,11 @@ void hamming_range_search(
RangeQueryResult& qres = pres.new_result(i);

for (size_t j = 0; j < nb; j++) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
if (!sel || sel->is_member(j)) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
}
}
yi += code_size;
}
Expand Down Expand Up @@ -490,10 +500,11 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int order,
ApproxTopK_mode_t approx_topk_mode) {
ApproxTopK_mode_t approx_topk_mode
const IDSelector* sel) {
Run_hammings_knn_hc r;
dispatch_HammingComputer(
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode, sel);
}

void hammings_knn_mc(
Expand All @@ -504,10 +515,11 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const IDSelector* sel) {
Run_hammings_knn_mc r;
dispatch_HammingComputer(
ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
ncodes, r, ncodes, a, b, na, nb, k, distances, labels, sel);
}

void hamming_range_search(
Expand All @@ -517,10 +529,11 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* result) {
RangeSearchResult* result,
const IDSelector* sel = nullptr) {
Run_hamming_range_search r;
dispatch_HammingComputer(
code_size, r, a, b, na, nb, radius, code_size, result);
code_size, r, a, b, na, nb, radius, code_size, result, sel);
}

/* Count number of matches given a max threshold */
Expand Down
9 changes: 6 additions & 3 deletions faiss/utils/hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int ordered,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const IDSelector* sel = nullptr);

/* Legacy alias to hammings_knn_hc. */
void hammings_knn(
Expand Down Expand Up @@ -166,7 +167,8 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* distances,
int64_t* labels);
int64_t* labels,
const IDSelector* sel = nullptr);

/** same as hammings_knn except we are doing a range search with radius */
void hamming_range_search(
Expand All @@ -176,7 +178,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t ncodes,
RangeSearchResult* result);
RangeSearchResult* result,
const IDSelector* sel = nullptr);

/* Counting the number of matches or of cross-matches (without returning them)
For use with function that assume pre-allocated memory */
Expand Down
79 changes: 58 additions & 21 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,31 @@ class TestSelector(unittest.TestCase):
combinations as possible.
"""

def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10):
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, params=None):
""" Verify that the id selector returns the subset of results that are
members according to the IDSelector.
Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor"
params: optional SearchParameters object to override default settings
"""
ds = datasets.SyntheticDataset(32, 1000, 100, 20)
index = faiss.index_factory(ds.d, index_key, mt)
index.train(ds.get_train())
d = 32 # make sure dimension is multiple of 8 for binary
ds = datasets.SyntheticDataset(d, 1000, 100, 20)

if index_key == "BinaryFlat":
# Create proper binary vectors following test_index_binary.py pattern
rs = np.random.RandomState(123)
xb = rs.randint(256, size=(ds.nb, d // 8), dtype='uint8')
xq = rs.randint(256, size=(ds.nq, d // 8), dtype='uint8')
xt = None # No training needed for binary flat
index = faiss.IndexBinaryFlat(d)
# Use smaller radius for Hamming distance
base_radius = 4
else:
xb = ds.get_database()
xq = ds.get_queries()
xt = ds.get_train()
index = faiss.index_factory(d, index_key, mt)
index.train(xt)
base_radius = float('inf') # Will be set based on results

# reference result
if "range" in id_selector_type:
Expand All @@ -54,20 +71,22 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
subset = np.setxor1d(lhs_subset, rhs_subset)
else:
rs = np.random.RandomState(123)
subset = rs.choice(ds.nb, 50, replace=False).astype("int64")
# add_with_ids not supported for all index types
# index.add_with_ids(ds.get_database()[subset], subset)
index.add(ds.get_database()[subset])
subset = rs.choice(ds.nb, 50, replace=False).astype('int64')

index.add(xb[subset])
if "IVF" in index_key and id_selector_type == "range_sorted":
self.assertTrue(index.check_ids_sorted())
Dref, Iref0 = index.search(ds.get_queries(), k)
Dref, Iref0 = index.search(xq, k)
Iref = subset[Iref0]
Iref[Iref0 < 0] = -1

radius = float(Dref[Iref > 0].max()) * 1.01
if base_radius == float('inf'):
radius = float(Dref[Iref > 0].max()) * 1.01
else:
radius = base_radius

try:
Rlims_ref, RDref, RIref = index.range_search(
ds.get_queries(), radius)
Rlims_ref, RDref, RIref = index.range_search(xq, radius)
except RuntimeError as e:
if "not implemented" in str(e):
have_range_search = False
Expand All @@ -81,7 +100,7 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR

# result with selector: fill full database and search with selector
index.reset()
index.add(ds.get_database())
index.add(xb)
if id_selector_type == "range":
sel = faiss.IDSelectorRange(30, 80)
elif id_selector_type == "range_sorted":
Expand Down Expand Up @@ -118,18 +137,22 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
else:
sel = faiss.IDSelectorBatch(subset)

params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)
Dnew, Inew = index.search(ds.get_queries(), k, params=params)
if params is None:
params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)
else:
# Use provided params but ensure selector is set
params.sel = sel

Dnew, Inew = index.search(xq, k, params=params)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_almost_equal(Dref, Dnew, decimal=5)

if have_range_search:
Rlims_new, RDnew, RInew = index.range_search(
ds.get_queries(), radius, params=params)
Rlims_new, RDnew, RInew = index.range_search(xq, radius, params=params)
np.testing.assert_array_equal(Rlims_ref, Rlims_new)
RDref, RIref = sort_range_res_2(Rlims_ref, RDref, RIref)
np.testing.assert_array_equal(RIref, RInew)
Expand Down Expand Up @@ -284,6 +307,20 @@ def test_bounds(self):
distances, indices = index_ip.search(xb[:2], k=3, params=search_params)
distances, indices = index_l2.search(xb[:2], k=3, params=search_params)

def test_BinaryFlat(self):
self.do_test_id_selector("BinaryFlat")

def test_BinaryFlat_id_range(self):
self.do_test_id_selector("BinaryFlat", id_selector_type="range")

def test_BinaryFlat_id_array(self):
self.do_test_id_selector("BinaryFlat", id_selector_type="array")

def test_BinaryFlat_no_heap(self):
params = faiss.SearchParameters()
params.use_heap = False
self.do_test_id_selector("BinaryFlat", params=params)


class TestSearchParams(unittest.TestCase):

Expand Down
Loading