Skip to content

Commit e2f3bac

Browse files
Change index_cpu_to_gpu to throw for indices not implemented on GPU (#3336)
Summary: Issue: #3269 List of implemented GPU indices: https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU#implemented-indexes Reviewed By: junjieqi Differential Revision: D55577576
1 parent da9f292 commit e2f3bac

8 files changed

+97
-32
lines changed

faiss/gpu/GpuCloner.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ Index* ToGpuCloner::clone_Index(const Index* index) {
205205
config.usePrecomputedTables = usePrecomputed;
206206
config.use_raft = use_raft;
207207
config.interleavedLayout = use_raft;
208+
config.enableCpuFallback = enableCpuFallback;
208209

209210
GpuIndexIVFPQ* res = new GpuIndexIVFPQ(provider, ipq, config);
210211

@@ -214,8 +215,11 @@ Index* ToGpuCloner::clone_Index(const Index* index) {
214215

215216
return res;
216217
} else {
217-
// default: use CPU cloner
218-
return Cloner::clone_Index(index);
218+
// use CPU cloner if CPU fallback is enabled
219+
if (enableCpuFallback) {
220+
return Cloner::clone_Index(index);
221+
}
222+
FAISS_THROW_MSG("This index type is not implemented on GPU.");
219223
}
220224
}
221225

@@ -224,8 +228,6 @@ faiss::Index* index_cpu_to_gpu(
224228
int device,
225229
const faiss::Index* index,
226230
const GpuClonerOptions* options) {
227-
auto index_pq = dynamic_cast<const faiss::IndexPQ*>(index);
228-
FAISS_THROW_IF_MSG(index_pq, "This index type is not implemented on GPU.");
229231
GpuClonerOptions defaults;
230232
ToGpuCloner cl(provider, device, options ? *options : defaults);
231233
return cl.clone_Index(index);

faiss/gpu/GpuClonerOptions.h

+5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ struct GpuClonerOptions {
4343
#else
4444
bool use_raft = false;
4545
#endif
46+
47+
/// enable CPU fallback; when set to true, the cloner will clone to CPU
48+
/// the index components that are not impemented on GPU; when set to false,
49+
/// the cloner will throw an exception if it cannot convert the index to GPU
50+
bool enableCpuFallback = false;
4651
};
4752

4853
struct GpuMultipleClonerOptions : public GpuClonerOptions {

faiss/gpu/GpuIndex.h

+5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ struct GpuIndexConfig {
4343
#else
4444
bool use_raft = false;
4545
#endif
46+
47+
/// enable CPU fallback; when set to true, the cloner will clone to CPU
48+
/// the index components that are not impemented on GPU; when set to false,
49+
/// the cloner will throw an exception if it cannot convert the index to GPU
50+
bool enableCpuFallback = false;
4651
};
4752

4853
/// A centralized function that determines whether RAFT should

faiss/gpu/GpuIndexIVF.cu

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ void GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
173173
GpuResourcesProviderFromInstance pfi(getResources());
174174

175175
GpuClonerOptions options;
176+
options.enableCpuFallback = config_.enableCpuFallback;
176177
auto cloner = ToGpuCloner(&pfi, getDevice(), options);
177178

178179
quantizer = cloner.clone_Index(index->quantizer);

faiss/gpu/test/test_gpu_index.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,10 @@ class TestGpuAutoTune(unittest.TestCase):
589589

590590
def test_params(self):
591591
index = faiss.index_factory(32, "IVF65536_HNSW,PQ16")
592-
index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index)
592+
res = faiss.StandardGpuResources()
593+
options = faiss.GpuClonerOptions()
594+
options.enableCpuFallback = True
595+
index = faiss.index_cpu_to_gpu(res, 0, index, options)
593596
ps = faiss.GpuParameterSpace()
594597
ps.initialize(index)
595598
for i in range(ps.parameter_ranges.size()):

faiss/gpu/test/test_index_cpu_to_gpu.py

+73-19
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,80 @@
44

55

66
class TestMoveToGpu(unittest.TestCase):
7-
def test_index_cpu_to_gpu(self):
7+
8+
@classmethod
9+
def setUpClass(cls):
10+
cls.res = faiss.StandardGpuResources()
11+
12+
def create_index(self, factory_string):
813
dimension = 128
914
n = 2500
1015
db_vectors = np.random.random((n, dimension)).astype('float32')
11-
code_size = 16
12-
res = faiss.StandardGpuResources()
13-
index_pq = faiss.IndexPQ(dimension, code_size, 6)
14-
index_pq.train(db_vectors)
15-
index_pq.add(db_vectors)
16-
self.assertRaisesRegex(Exception, ".*not implemented.*",
17-
faiss.index_cpu_to_gpu, res, 0, index_pq)
18-
19-
def test_index_cpu_to_gpu_does_not_throw_with_index_flat(self):
20-
dimension = 128
21-
n = 100
22-
db_vectors = np.random.random((n, dimension)).astype('float32')
23-
res = faiss.StandardGpuResources()
24-
index_flat = faiss.IndexFlatL2(dimension)
25-
index_flat.add(db_vectors)
16+
index = faiss.index_factory(dimension, factory_string)
17+
index.train(db_vectors)
18+
if factory_string.startswith("IDMap"):
19+
index.add_with_ids(db_vectors, np.arange(n))
20+
else:
21+
index.add(db_vectors)
22+
return index
23+
24+
def create_and_clone(self, factory_string,
25+
enableCpuFallback=None,
26+
use_raft=None):
27+
idx = self.create_index(factory_string)
28+
config = faiss.GpuClonerOptions()
29+
if enableCpuFallback is not None:
30+
config.enableCpuFallback = enableCpuFallback
31+
if use_raft is not None:
32+
config.use_raft = use_raft
33+
faiss.index_cpu_to_gpu(self.res, 0, idx, config)
34+
35+
def verify_throws_on_unsupported_index(self, factory_string):
36+
try:
37+
self.create_and_clone(factory_string)
38+
except Exception as e:
39+
if "not implemented" not in str(e):
40+
self.fail("Expected an exception but no exception was "
41+
"thrown for factory_string: %s." % factory_string)
42+
43+
def verify_succeeds_on_supported_index(self, factory_string, use_raft=None):
2644
try:
27-
faiss.index_cpu_to_gpu(res, 0, index_flat)
28-
except Exception:
29-
self.fail("index_cpu_to_gpu() threw an unexpected exception.")
45+
self.create_and_clone(factory_string, use_raft=use_raft)
46+
except Exception as e:
47+
self.fail("Unexpected exception thrown factory_string: "
48+
"%s; error message: %s." % (factory_string, str(e)))
49+
50+
def verify_succeeds_on_unsupported_index_with_fallback_enabled(
51+
self, factory_string, use_raft=None):
52+
try:
53+
self.create_and_clone(factory_string, enableCpuFallback=True,
54+
use_raft=use_raft)
55+
except Exception as e:
56+
self.fail("Unexpected exception thrown factory_string: "
57+
"%s; error message: %s." % (factory_string, str(e)))
58+
59+
def test_index_cpu_to_gpu_unsupported_indices(self):
60+
self.verify_throws_on_unsupported_index("PQ16")
61+
self.verify_throws_on_unsupported_index("LSHrt")
62+
self.verify_throws_on_unsupported_index("HNSW")
63+
self.verify_throws_on_unsupported_index("HNSW,PQ16")
64+
self.verify_throws_on_unsupported_index("IDMap,PQ16")
65+
self.verify_throws_on_unsupported_index("IVF256,ITQ64,SH1.2")
66+
67+
def test_index_cpu_to_gpu_supported_indices(self):
68+
self.verify_succeeds_on_supported_index("Flat")
69+
self.verify_succeeds_on_supported_index("IVF1,Flat")
70+
self.verify_succeeds_on_supported_index("IVF32,PQ8")
71+
72+
# set use_raft to false, this index type is not supported on RAFT
73+
self.verify_succeeds_on_supported_index("IVF32,SQ8", use_raft=False)
74+
75+
def test_index_cpu_to_gpu_unsupported_indices_with_fallback_enabled(self):
76+
self.verify_succeeds_on_unsupported_index_with_fallback_enabled("IDMap,Flat")
77+
self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA12,IVF32,Flat")
78+
self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA32,IVF32,PQ8")
79+
self.verify_succeeds_on_unsupported_index_with_fallback_enabled("PCA32,IVF32,PQ8np")
80+
81+
# set use_raft to false, this index type is not supported on RAFT
82+
self.verify_succeeds_on_unsupported_index_with_fallback_enabled(
83+
"PCA32,IVF32,SQ8", use_raft=False)

faiss/gpu/test/test_multi_gpu.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -244,5 +244,7 @@ def test_cpu_to_gpu_IVFFlat(self):
244244
def test_set_gpu_param(self):
245245
index = faiss.index_factory(12, "PCAR8,IVF10,PQ4")
246246
res = faiss.StandardGpuResources()
247-
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
247+
options = faiss.GpuClonerOptions()
248+
options.enableCpuFallback = True
249+
gpu_index = faiss.index_cpu_to_gpu(res, 0, index, options)
248250
faiss.GpuParameterSpace().set_index_parameter(gpu_index, "nprobe", 3)

faiss/impl/FaissAssert.h

-7
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,6 @@
9494
} \
9595
} while (false)
9696

97-
#define FAISS_THROW_IF_MSG(X, MSG) \
98-
do { \
99-
if (X) { \
100-
FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
101-
} \
102-
} while (false)
103-
10497
#define FAISS_THROW_IF_NOT_MSG(X, MSG) \
10598
do { \
10699
if (!(X)) { \

0 commit comments

Comments
 (0)