|
4 | 4 |
|
5 | 5 |
|
6 | 6 | class TestMoveToGpu(unittest.TestCase):
|
7 |
| - def test_index_cpu_to_gpu(self): |
| 7 | + |
| 8 | + @classmethod |
| 9 | + def setUpClass(cls): |
| 10 | + cls.res = faiss.StandardGpuResources() |
| 11 | + |
| 12 | + def create_index(self, factory_string): |
8 | 13 | dimension = 128
|
9 | 14 | n = 2500
|
10 | 15 | db_vectors = np.random.random((n, dimension)).astype('float32')
|
11 |
| - code_size = 16 |
12 |
| - res = faiss.StandardGpuResources() |
13 |
| - index_pq = faiss.IndexPQ(dimension, code_size, 6) |
14 |
| - index_pq.train(db_vectors) |
15 |
| - index_pq.add(db_vectors) |
16 |
| - self.assertRaisesRegex(Exception, ".*not implemented.*", |
17 |
| - faiss.index_cpu_to_gpu, res, 0, index_pq) |
18 |
| - |
19 |
| - def test_index_cpu_to_gpu_does_not_throw_with_index_flat(self): |
20 |
| - dimension = 128 |
21 |
| - n = 100 |
22 |
| - db_vectors = np.random.random((n, dimension)).astype('float32') |
23 |
| - res = faiss.StandardGpuResources() |
24 |
| - index_flat = faiss.IndexFlatL2(dimension) |
25 |
| - index_flat.add(db_vectors) |
| 16 | + index = faiss.index_factory(dimension, factory_string) |
| 17 | + index.train(db_vectors) |
| 18 | + if factory_string.startswith("IDMap"): |
| 19 | + index.add_with_ids(db_vectors, np.arange(n)) |
| 20 | + else: |
| 21 | + index.add(db_vectors) |
| 22 | + return index |
| 23 | + |
| 24 | + def create_and_clone(self, factory_string, |
| 25 | + enableCpuFallback=None, |
| 26 | + use_raft=None): |
| 27 | + idx = self.create_index(factory_string) |
| 28 | + config = faiss.GpuClonerOptions() |
| 29 | + if enableCpuFallback is not None: |
| 30 | + config.enableCpuFallback = enableCpuFallback |
| 31 | + if use_raft is not None: |
| 32 | + config.use_raft = use_raft |
| 33 | + faiss.index_cpu_to_gpu(self.res, 0, idx, config) |
| 34 | + |
| 35 | + def verify_throws_on_unsupported_index(self, factory_string): |
| 36 | + try: |
| 37 | + self.create_and_clone(factory_string) |
| 38 | + except Exception as e: |
| 39 | + if "not implemented" not in str(e): |
| 40 | + self.fail("Expected an exception but no exception was " |
| 41 | + "thrown for factory_string: %s." % factory_string) |
| 42 | + |
| 43 | + def verify_succeeds_on_supported_index(self, factory_string, use_raft=None): |
26 | 44 | try:
|
27 |
| - faiss.index_cpu_to_gpu(res, 0, index_flat) |
28 |
| - except Exception: |
29 |
| - self.fail("index_cpu_to_gpu() threw an unexpected exception.") |
| 45 | + self.create_and_clone(factory_string, use_raft=use_raft) |
| 46 | + except Exception as e: |
| 47 | + self.fail("Unexpected exception thrown factory_string: " |
| 48 | + "%s; error message: %s." % (factory_string, str(e))) |
| 49 | + |
| 50 | + def verify_succeeds_on_unsupported_index_with_fallback_enabled( |
| 51 | + self, factory_string, use_raft=None): |
| 52 | + try: |
| 53 | + self.create_and_clone(factory_string, enableCpuFallback=True, |
| 54 | + use_raft=use_raft) |
| 55 | + except Exception as e: |
| 56 | + self.fail("Unexpected exception thrown factory_string: " |
| 57 | + "%s; error message: %s." % (factory_string, str(e))) |
| 58 | + |
| 59 | + def test_index_cpu_to_gpu_unsupported_indices(self): |
| 60 | + self.verify_throws_on_unsupported_index("PQ16") |
| 61 | + self.verify_throws_on_unsupported_index("LSHrt") |
| 62 | + self.verify_throws_on_unsupported_index("HNSW") |
| 63 | + self.verify_throws_on_unsupported_index("HNSW,PQ16") |
| 64 | + self.verify_throws_on_unsupported_index("IDMap,PQ16") |
| 65 | + self.verify_throws_on_unsupported_index("IVF256,ITQ64,SH1.2") |
| 66 | + |
| 67 | + def test_index_cpu_to_gpu_supported_indices(self): |
| 68 | + self.verify_succeeds_on_supported_index("Flat") |
| 69 | + self.verify_succeeds_on_supported_index("IVF1,Flat") |
| 70 | + self.verify_succeeds_on_supported_index("IVF32,PQ8") |
| 71 | + |
| 72 | + # set use_raft to false, this index type is not supported on RAFT |
| 73 | + self.verify_succeeds_on_supported_index("IVF32,SQ8", use_raft=False) |
| 74 | + |
| 75 | + def test_index_cpu_to_gpu_unsupported_indices_with_fallback_enabled(self): |
| 76 | + self.verify_succeeds_on_unsupported_index_with_fallback_enabled("IDMap,Flat") |
| 77 | + self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA12,IVF32,Flat") |
| 78 | + self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA32,IVF32,PQ8") |
| 79 | + self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA32,IVF32,PQ8np") |
| 80 | + |
| 81 | + # set use_raft to false, this index type is not supported on RAFT |
| 82 | + self.verify_succeeds_on_unsupported_index_with_fallback_enabled( |
| 83 | + "PCA32,IVF32,SQ8", use_raft=False) |
0 commit comments