Skip to content

Commit 2956fbe

Browse files
bshethmetaaalekhpatel07
authored andcommitted
First attempt at LSH matching with nbits (facebookresearch#3679)
Summary: Pull Request resolved: facebookresearch#3679 T195237796 Claims we should be able to incldue nbits in the LSH factory string. Their example is: ``` index = faiss.index_factory(128, 'LSH16rt') Returns the following error. faiss/index_factory.cpp:880: could not parse index string LSHrt_16 ``` This is my first attempt at modifying the regex to accept an integer for nbits. Can an expert help me understand what the domain of accepted strings should be so I can modify the regex as necessary? Reviewed By: ramilbakhshyiev Differential Revision: D60054776 fbshipit-source-id: e47074eb9986b7c1c702832fc0bf758f60f45290
1 parent 5fe1aeb commit 2956fbe

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

faiss/index_factory.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,12 @@ Index* parse_other_indexes(
530530
}
531531

532532
// IndexLSH
533-
if (match("LSH(r?)(t?)")) {
534-
bool rotate_data = sm[1].length() > 0;
535-
bool train_thresholds = sm[2].length() > 0;
533+
if (match("LSH([0-9]*)(r?)(t?)")) {
534+
int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
535+
bool rotate_data = sm[2].length() > 0;
536+
bool train_thresholds = sm[3].length() > 0;
536537
FAISS_THROW_IF_NOT(metric == METRIC_L2);
537-
return new IndexLSH(d, d, rotate_data, train_thresholds);
538+
return new IndexLSH(d, nbits, rotate_data, train_thresholds);
538539
}
539540

540541
// IndexLattice

tests/test_factory.py

+6
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ def test_factory_NSG(self):
119119
assert index.nlist == 65536 and index_nsg.nsg.R == 64
120120
assert index.pq.M == 2 and index.pq.nbits == 8
121121

122+
def test_factory_lsh(self):
123+
index = faiss.index_factory(128, 'LSHrt')
124+
self.assertEqual(index.nbits, 128)
125+
index = faiss.index_factory(128, 'LSH16rt')
126+
self.assertEqual(index.nbits, 16)
127+
122128
def test_factory_fast_scan(self):
123129
index = faiss.index_factory(56, "PQ28x4fs")
124130
self.assertEqual(index.bbs, 32)

0 commit comments

Comments
 (0)