Skip to content

Commit 0bc9090

Browse files
kuarora authored and facebook-github-bot committed
Support of skip_ids in merge_from_multiple function of OnDiskInvertedLists (#3327)
Summary: **Context** 1. [Issue 2621](#2621) discusses an inconsistency between OnDiskInvertedLists and InvertedLists. OnDiskInvertedLists is supposed to handle disk-based multiple index shards. Thus, we should name the merge entry point differently when merging inverted lists from index shards. 2. [Issue 2876](#2876) provides a use case for shifting ids when merging inverted lists from different shards. **In this diff**, 1. To address #1 above, I renamed the merge_from function to merge_from_multiple without touching the merge_from base class. Why so? To continue to allow merging an inverted list from one index into the on-disk inverted list of another index. 2. To address #2 above, I added support for shift_ids in merge_from_multiple to shift ids coming from different shards. This can be used when each shard has the same set of ids but different data. This is not recommended if ids are already unique across shards. Reviewed By: mdouze Differential Revision: D55482518
1 parent c9c86f0 commit 0bc9090

File tree

5 files changed

+115
-16
lines changed

5 files changed

+115
-16
lines changed

contrib/ondisk.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
def merge_ondisk(
14-
trained_index: faiss.Index, shard_fnames: List[str], ivfdata_fname: str
14+
trained_index: faiss.Index, shard_fnames: List[str], ivfdata_fname: str, shift_ids=False
1515
) -> None:
1616
"""Add the contents of the indexes stored in shard_fnames into the index
1717
trained_index. The on-disk data is stored in ivfdata_fname"""
@@ -51,7 +51,7 @@ def merge_ondisk(
5151
ivf_vector.push_back(ivf)
5252

5353
LOG.info("merge %d inverted lists " % ivf_vector.size())
54-
ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size())
54+
ntotal = invlists.merge_from_multiple(ivf_vector.data(), ivf_vector.size(), shift_ids)
5555

5656
# now replace the inverted lists in the output index
5757
index.ntotal = index_ivf.ntotal = ntotal

faiss/invlists/OnDiskInvertedLists.cpp

+19-4
Original file line numberDiff line numberDiff line change
@@ -565,22 +565,27 @@ void OnDiskInvertedLists::free_slot(size_t offset, size_t capacity) {
565565
/*****************************************
566566
* Compact form
567567
*****************************************/
568-
569-
size_t OnDiskInvertedLists::merge_from(
568+
size_t OnDiskInvertedLists::merge_from_multiple(
570569
const InvertedLists** ils,
571570
int n_il,
571+
bool shift_ids,
572572
bool verbose) {
573573
FAISS_THROW_IF_NOT_MSG(
574574
totsize == 0, "works only on an empty InvertedLists");
575575

576576
std::vector<size_t> sizes(nlist);
577+
std::vector<size_t> shift_id_offsets(n_il);
577578
for (int i = 0; i < n_il; i++) {
578579
const InvertedLists* il = ils[i];
579580
FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
580581

581582
for (size_t j = 0; j < nlist; j++) {
582583
sizes[j] += il->list_size(j);
583584
}
585+
586+
size_t il_totsize = il->compute_ntotal();
587+
shift_id_offsets[i] =
588+
(shift_ids && i > 0) ? shift_id_offsets[i - 1] + il_totsize : 0;
584589
}
585590

586591
size_t cums = 0;
@@ -605,11 +610,21 @@ size_t OnDiskInvertedLists::merge_from(
605610
const InvertedLists* il = ils[i];
606611
size_t n_entry = il->list_size(j);
607612
l.size += n_entry;
613+
ScopedIds scope_ids(il, j);
614+
const idx_t* scope_ids_data = scope_ids.get();
615+
std::vector<idx_t> new_ids;
616+
if (shift_ids) {
617+
new_ids.resize(n_entry);
618+
for (size_t k = 0; k < n_entry; k++) {
619+
new_ids[k] = scope_ids[k] + shift_id_offsets[i];
620+
}
621+
scope_ids_data = new_ids.data();
622+
}
608623
update_entries(
609624
j,
610625
l.size - n_entry,
611626
n_entry,
612-
ScopedIds(il, j).get(),
627+
scope_ids_data,
613628
ScopedCodes(il, j).get());
614629
}
615630
assert(l.size == l.capacity);
@@ -638,7 +653,7 @@ size_t OnDiskInvertedLists::merge_from(
638653
size_t OnDiskInvertedLists::merge_from_1(
639654
const InvertedLists* ils,
640655
bool verbose) {
641-
return merge_from(&ils, 1, verbose);
656+
return merge_from_multiple(&ils, 1, verbose);
642657
}
643658

644659
void OnDiskInvertedLists::crop_invlists(size_t l0, size_t l1) {

faiss/invlists/OnDiskInvertedLists.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ struct OnDiskInvertedLists : InvertedLists {
101101

102102
// copy all inverted lists into *this, in compact form (without
103103
// allocating slots)
104-
size_t merge_from(
104+
size_t merge_from_multiple(
105105
const InvertedLists** ils,
106106
int n_il,
107+
bool shift_ids = false,
107108
bool verbose = false);
108109

109110
/// same as merge_from for a single invlist

tests/test_contrib.py

+66-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import platform
1010
import os
1111
import random
12+
import shutil
1213
import tempfile
1314

1415
from faiss.contrib import datasets
@@ -17,15 +18,13 @@
1718
from faiss.contrib import ivf_tools
1819
from faiss.contrib import clustering
1920
from faiss.contrib import big_batch_search
21+
from faiss.contrib.ondisk import merge_ondisk
2022

2123
from common_faiss_tests import get_dataset_2
22-
try:
23-
from faiss.contrib.exhaustive_search import \
24-
knn_ground_truth, knn, range_ground_truth, \
25-
range_search_max_results, exponential_query_iterator
26-
except:
27-
pass # Submodule import broken in python 2.
28-
24+
from faiss.contrib.exhaustive_search import \
25+
knn_ground_truth, knn, range_ground_truth, \
26+
range_search_max_results, exponential_query_iterator
27+
from contextlib import contextmanager
2928

3029
@unittest.skipIf(platform.python_version_tuple()[0] < '3',
3130
'Submodule import broken in python 2.')
@@ -674,3 +673,63 @@ def test_code_set(self):
674673
np.testing.assert_equal(
675674
np.sort(np.unique(codes, axis=0), axis=None),
676675
np.sort(codes[inserted], axis=None))
676+
677+
678+
@unittest.skipIf(platform.system() == 'Windows',
679+
'OnDiskInvertedLists is unsupported on Windows.')
680+
class TestMerge(unittest.TestCase):
681+
@contextmanager
682+
def temp_directory(self):
683+
temp_dir = tempfile.mkdtemp()
684+
try:
685+
yield temp_dir
686+
finally:
687+
shutil.rmtree(temp_dir)
688+
689+
def do_test_ondisk_merge(self, shift_ids=False):
690+
with self.temp_directory() as tmpdir:
691+
# only train and add index to disk without adding elements.
692+
# this will create empty inverted lists.
693+
ds = datasets.SyntheticDataset(32, 2000, 200, 20)
694+
index = faiss.index_factory(ds.d, "IVF32,Flat")
695+
index.train(ds.get_train())
696+
faiss.write_index(index, tmpdir + "/trained.index")
697+
698+
# create 4 shards and add elements to them
699+
ns = 4 # number of shards
700+
701+
for bno in range(ns):
702+
index = faiss.read_index(tmpdir + "/trained.index")
703+
i0, i1 = int(bno * ds.nb / ns), int((bno + 1) * ds.nb / ns)
704+
if shift_ids:
705+
index.add_with_ids(ds.xb[i0:i1], np.arange(0, ds.nb / ns))
706+
else:
707+
index.add_with_ids(ds.xb[i0:i1], np.arange(i0, i1))
708+
faiss.write_index(index, tmpdir + "/block_%d.index" % bno)
709+
710+
# construct the output index and merge them on disk
711+
index = faiss.read_index(tmpdir + "/trained.index")
712+
block_fnames = [tmpdir + "/block_%d.index" % bno for bno in range(4)]
713+
714+
merge_ondisk(
715+
index, block_fnames, tmpdir + "/merged_index.ivfdata", shift_ids
716+
)
717+
faiss.write_index(index, tmpdir + "/populated.index")
718+
719+
# perform a search from index on disk
720+
index = faiss.read_index(tmpdir + "/populated.index")
721+
index.nprobe = 5
722+
D, I = index.search(ds.xq, 5)
723+
724+
# ground-truth
725+
gtI = ds.get_groundtruth(5)
726+
727+
recall_at_1 = (I[:, :1] == gtI[:, :1]).sum() / float(ds.xq.shape[0])
728+
self.assertGreaterEqual(recall_at_1, 0.5)
729+
730+
def test_ondisk_merge(self):
731+
self.do_test_ondisk_merge()
732+
733+
def test_ondisk_merge_with_shift_ids(self):
734+
# verified that recall is same for test_ondisk_merge and
735+
self.do_test_ondisk_merge(True)

tests/test_merge.cpp

+26-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ size_t nq = 100;
3232
int nindex = 4;
3333
int k = 10;
3434
int nlist = 40;
35+
int shard_size = nb / nindex;
3536

3637
struct CommonData {
3738
std::vector<float> database;
@@ -100,7 +101,7 @@ int compare_merged(
100101
auto il = new faiss::OnDiskInvertedLists(
101102
index0->nlist, index0->code_size, filename.c_str());
102103

103-
il->merge_from(lists.data(), lists.size());
104+
il->merge_from_multiple(lists.data(), lists.size(), shift_ids);
104105

105106
index0->replace_invlists(il, true);
106107
index0->ntotal = ntotal;
@@ -110,11 +111,14 @@ int compare_merged(
110111
nq, cd.queries.data(), k, newD.data(), newI.data());
111112

112113
size_t ndiff = 0;
114+
bool adjust_ids = shift_ids && !standard_merge;
113115
for (size_t i = 0; i < k * nq; i++) {
114-
if (refI[i] != newI[i]) {
116+
idx_t new_id = adjust_ids ? refI[i] % shard_size : refI[i];
117+
if (refI[i] != new_id) {
115118
ndiff++;
116119
}
117120
}
121+
118122
return ndiff;
119123
}
120124

@@ -220,3 +224,23 @@ TEST(MERGE, merge_flat_ondisk_2) {
220224
int ndiff = compare_merged(&index_shards, false, false);
221225
EXPECT_GE(0, ndiff);
222226
}
227+
228+
// now use ondisk specific merge and use shift ids
229+
TEST(MERGE, merge_flat_ondisk_3) {
230+
faiss::IndexShards index_shards(d, false, false);
231+
index_shards.own_indices = true;
232+
233+
std::vector<idx_t> ids;
234+
for (int i = 0; i < nb; ++i) {
235+
int id = i % shard_size;
236+
ids.push_back(id);
237+
}
238+
for (int i = 0; i < nindex; i++) {
239+
index_shards.add_shard(
240+
new faiss::IndexIVFFlat(&cd.quantizer, d, nlist));
241+
}
242+
EXPECT_TRUE(index_shards.is_trained);
243+
index_shards.add_with_ids(nb, cd.database.data(), ids.data());
244+
int ndiff = compare_merged(&index_shards, true, false);
245+
EXPECT_GE(0, ndiff);
246+
}

0 commit comments

Comments
 (0)