|
9 | 9 | import platform
|
10 | 10 | import os
|
11 | 11 | import random
|
| 12 | +import shutil |
12 | 13 | import tempfile
|
13 | 14 |
|
14 | 15 | from faiss.contrib import datasets
|
|
17 | 18 | from faiss.contrib import ivf_tools
|
18 | 19 | from faiss.contrib import clustering
|
19 | 20 | from faiss.contrib import big_batch_search
|
| 21 | +from faiss.contrib.ondisk import merge_ondisk |
20 | 22 |
|
21 | 23 | from common_faiss_tests import get_dataset_2
|
22 |
| -try: |
23 |
| - from faiss.contrib.exhaustive_search import \ |
24 |
| - knn_ground_truth, knn, range_ground_truth, \ |
25 |
| - range_search_max_results, exponential_query_iterator |
26 |
| -except: |
27 |
| - pass # Submodule import broken in python 2. |
28 |
| - |
| 24 | +from faiss.contrib.exhaustive_search import \ |
| 25 | + knn_ground_truth, knn, range_ground_truth, \ |
| 26 | + range_search_max_results, exponential_query_iterator |
| 27 | +from contextlib import contextmanager |
29 | 28 |
|
30 | 29 | @unittest.skipIf(platform.python_version_tuple()[0] < '3',
|
31 | 30 | 'Submodule import broken in python 2.')
|
@@ -674,3 +673,63 @@ def test_code_set(self):
|
674 | 673 | np.testing.assert_equal(
|
675 | 674 | np.sort(np.unique(codes, axis=0), axis=None),
|
676 | 675 | np.sort(codes[inserted], axis=None))
|
| 676 | + |
| 677 | + |
@unittest.skipIf(platform.system() == 'Windows',
                 'OnDiskInvertedLists is unsupported on Windows.')
class TestMerge(unittest.TestCase):
    """Test merging several on-disk IVF index shards with merge_ondisk.

    Builds one trained (empty) IVF index, writes ns shards each holding a
    slice of the database, merges the shards' inverted lists into a single
    on-disk .ivfdata file, and checks search recall on the merged index.
    """

    @contextmanager
    def temp_directory(self):
        """Yield a fresh temporary directory, removed on exit."""
        temp_dir = tempfile.mkdtemp()
        try:
            yield temp_dir
        finally:
            shutil.rmtree(temp_dir)

    def do_test_ondisk_merge(self, shift_ids=False):
        """Shared body for the two tests below.

        shift_ids: when True, each shard is written with shard-local ids
        [0, nb/ns) and merge_ondisk is asked to shift them back to global
        ids during the merge; when False, shards carry global ids already.
        """
        with self.temp_directory() as tmpdir:
            # only train and add index to disk without adding elements.
            # this will create empty inverted lists.
            ds = datasets.SyntheticDataset(32, 2000, 200, 20)
            index = faiss.index_factory(ds.d, "IVF32,Flat")
            index.train(ds.get_train())
            faiss.write_index(index, tmpdir + "/trained.index")

            # create ns shards and add elements to them
            ns = 4  # number of shards

            for bno in range(ns):
                index = faiss.read_index(tmpdir + "/trained.index")
                i0, i1 = int(bno * ds.nb / ns), int((bno + 1) * ds.nb / ns)
                if shift_ids:
                    # shard-local ids; integer arange (was a float-dtype
                    # np.arange(0, ds.nb / ns)) — same values, integral
                    # dtype as id arrays should be
                    index.add_with_ids(ds.xb[i0:i1], np.arange(i1 - i0))
                else:
                    index.add_with_ids(ds.xb[i0:i1], np.arange(i0, i1))
                faiss.write_index(index, tmpdir + "/block_%d.index" % bno)

            # construct the output index and merge the shards on disk
            index = faiss.read_index(tmpdir + "/trained.index")
            # reuse ns (was a hard-coded range(4)) so the file list always
            # matches the number of shards written above
            block_fnames = [
                tmpdir + "/block_%d.index" % bno for bno in range(ns)
            ]

            merge_ondisk(
                index, block_fnames, tmpdir + "/merged_index.ivfdata", shift_ids
            )
            faiss.write_index(index, tmpdir + "/populated.index")

            # perform a search from index on disk
            index = faiss.read_index(tmpdir + "/populated.index")
            index.nprobe = 5
            D, I = index.search(ds.xq, 5)

            # ground-truth
            gtI = ds.get_groundtruth(5)

            recall_at_1 = (I[:, :1] == gtI[:, :1]).sum() / float(ds.xq.shape[0])
            self.assertGreaterEqual(recall_at_1, 0.5)

    def test_ondisk_merge(self):
        self.do_test_ondisk_merge()

    def test_ondisk_merge_with_shift_ids(self):
        # recall should match test_ondisk_merge: merge_ondisk restores the
        # global ids when shift_ids=True
        self.do_test_ondisk_merge(True)
0 commit comments