|
4 | 4 |
|
5 | 5 |
|
6 | 6 | class TestMoveToGpu(unittest.TestCase):
|
7 |
| - def test_index_cpu_to_gpu(self): |
| 7 | + |
| 8 | + @classmethod |
| 9 | + def setUpClass(cls): |
| 10 | + cls.res = faiss.StandardGpuResources() |
| 11 | + |
| 12 | + def create_index(self, factory_string): |
8 | 13 | dimension = 128
|
9 | 14 | n = 2500
|
10 | 15 | db_vectors = np.random.random((n, dimension)).astype('float32')
|
11 |
| - code_size = 16 |
12 |
| - res = faiss.StandardGpuResources() |
13 |
| - index_pq = faiss.IndexPQ(dimension, code_size, 6) |
14 |
| - index_pq.train(db_vectors) |
15 |
| - index_pq.add(db_vectors) |
16 |
| - self.assertRaisesRegex(Exception, ".*not implemented.*", |
17 |
| - faiss.index_cpu_to_gpu, res, 0, index_pq) |
18 |
| - |
19 |
| - def test_index_cpu_to_gpu_does_not_throw_with_index_flat(self): |
20 |
| - dimension = 128 |
21 |
| - n = 100 |
22 |
| - db_vectors = np.random.random((n, dimension)).astype('float32') |
23 |
| - res = faiss.StandardGpuResources() |
24 |
| - index_flat = faiss.IndexFlatL2(dimension) |
25 |
| - index_flat.add(db_vectors) |
| 16 | + index = faiss.index_factory(dimension, factory_string) |
| 17 | + index.train(db_vectors) |
| 18 | + if factory_string.startswith("IDMap"): |
| 19 | + index.add_with_ids(db_vectors, np.arange(n)) |
| 20 | + else: |
| 21 | + index.add(db_vectors) |
| 22 | + return index |
| 23 | + |
| 24 | + def create_and_clone(self, factory_string, |
| 25 | + allowCpuCoarseQuantizer=None, |
| 26 | + use_raft=None): |
| 27 | + idx = self.create_index(factory_string) |
| 28 | + config = faiss.GpuClonerOptions() |
| 29 | + if allowCpuCoarseQuantizer is not None: |
| 30 | + config.allowCpuCoarseQuantizer = allowCpuCoarseQuantizer |
| 31 | + if use_raft is not None: |
| 32 | + config.use_raft = use_raft |
| 33 | + faiss.index_cpu_to_gpu(self.res, 0, idx, config) |
| 34 | + |
| 35 | + def verify_throws_not_implemented_exception(self, factory_string): |
| 36 | + try: |
| 37 | + self.create_and_clone(factory_string) |
| 38 | + except Exception as e: |
| 39 | + if "not implemented" not in str(e): |
| 40 | + self.fail("Expected an exception but no exception was " |
| 41 | + "thrown for factory_string: %s." % factory_string) |
| 42 | + |
| 43 | + def verify_clones_successfully(self, factory_string, |
| 44 | + allowCpuCoarseQuantizer=None, |
| 45 | + use_raft=None): |
| 46 | + try: |
| 47 | + self.create_and_clone( |
| 48 | + factory_string, |
| 49 | + allowCpuCoarseQuantizer=allowCpuCoarseQuantizer, |
| 50 | + use_raft=use_raft) |
| 51 | + except Exception as e: |
| 52 | + self.fail("Unexpected exception thrown factory_string: " |
| 53 | + "%s; error message: %s." % (factory_string, str(e))) |
| 54 | + |
| 55 | + def test_not_implemented_indices(self): |
| 56 | + self.verify_throws_not_implemented_exception("PQ16") |
| 57 | + self.verify_throws_not_implemented_exception("LSHrt") |
| 58 | + self.verify_throws_not_implemented_exception("HNSW") |
| 59 | + self.verify_throws_not_implemented_exception("HNSW,PQ16") |
| 60 | + self.verify_throws_not_implemented_exception("IDMap,PQ16") |
| 61 | + self.verify_throws_not_implemented_exception("IVF256,ITQ64,SH1.2") |
| 62 | + |
| 63 | + def test_implemented_indices(self): |
| 64 | + self.verify_clones_successfully("Flat") |
| 65 | + self.verify_clones_successfully("IVF1,Flat") |
| 66 | + self.verify_clones_successfully("IVF32,PQ8") |
| 67 | + self.verify_clones_successfully("IDMap,Flat") |
| 68 | + self.verify_clones_successfully("PCA12,IVF32,Flat") |
| 69 | + self.verify_clones_successfully("PCA32,IVF32,PQ8") |
| 70 | + self.verify_clones_successfully("PCA32,IVF32,PQ8np") |
| 71 | + |
| 72 | + # set use_raft to false, these index types are not supported on RAFT |
| 73 | + self.verify_clones_successfully("IVF32,SQ8", use_raft=False) |
| 74 | + self.verify_clones_successfully( |
| 75 | + "PCA32,IVF32,SQ8", use_raft=False) |
| 76 | + |
| 77 | + def test_with_flag(self): |
| 78 | + self.verify_clones_successfully("IVF32_HNSW,Flat", |
| 79 | + allowCpuCoarseQuantizer=True) |
| 80 | + self.verify_clones_successfully("IVF256(PQ2x4fs),Flat", |
| 81 | + allowCpuCoarseQuantizer=True) |
| 82 | + |
| 83 | + def test_with_flag_set_to_false(self): |
26 | 84 | try:
|
27 |
| - faiss.index_cpu_to_gpu(res, 0, index_flat) |
28 |
| - except Exception: |
29 |
| - self.fail("index_cpu_to_gpu() threw an unexpected exception.") |
| 85 | + self.verify_clones_successfully("IVF32_HNSW,Flat", |
| 86 | + allowCpuCoarseQuantizer=False) |
| 87 | + except Exception as e: |
| 88 | + if "set the flag to true to allow the CPU fallback" not in str(e): |
| 89 | + self.fail("Unexepected error message thrown: %s." % str(e)) |
0 commit comments