Skip to content

Commit 835b3ea

Browse files
Michael Norrisfacebook-github-bot
Michael Norris
authored andcommitted
Fix IVF quantizer centroid sharding so IDs are generated (#4197)
Summary: Pull Request resolved: #4197 Ivan and I discussed 2 problems: 1. We may want to try to offload/shard PQ or SQ table data if there is a big enough win (pending) 2. IDs seem to be random after sharding. This diff solves 2. Root cause is that we add to quantizer without IDs. Instead, we wrap in IndexIDMap2 (which provides reconstruction, whereas IndexIDMap does not). Laser's quantizers are Flat and HNSW, so we can wrap like this. Reviewed By: ivansopin Differential Revision: D69832788 fbshipit-source-id: 331b6d1cf52666f5dac61e2b52302d46b0a83708
1 parent 65222b3 commit 835b3ea

File tree

3 files changed

+136
-31
lines changed

3 files changed

+136
-31
lines changed

faiss/IVFlib.cpp

+65-15
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <omp.h>
1010

1111
#include <memory>
12+
#include <numeric>
1213

1314
#include <faiss/IndexAdditiveQuantizer.h>
1415
#include <faiss/IndexIVFAdditiveQuantizer.h>
@@ -529,20 +530,30 @@ void handle_ivf(
529530
faiss::IndexIVF* index,
530531
int64_t shard_count,
531532
const std::string& filename_template,
532-
ShardingFunction* sharding_function) {
533+
ShardingFunction* sharding_function,
534+
bool generate_ids) {
533535
std::vector<faiss::IndexIVF*> sharded_indexes(shard_count);
534536
auto clone = static_cast<faiss::IndexIVF*>(faiss::clone_index(index));
535537
clone->quantizer->reset();
536538
for (int64_t i = 0; i < shard_count; i++) {
537539
sharded_indexes[i] =
538540
static_cast<faiss::IndexIVF*>(faiss::clone_index(clone));
541+
if (generate_ids) {
542+
// Assume the quantizer does not natively support add_with_ids.
543+
sharded_indexes[i]->quantizer =
544+
new IndexIDMap2(sharded_indexes[i]->quantizer);
545+
}
539546
}
540547

541548
// assign centroids to each sharded Index based on sharding_function, and
542549
// add them to the quantizer of each sharded index
543550
std::vector<std::vector<float>> sharded_centroids(shard_count);
551+
std::vector<std::vector<idx_t>> xids(shard_count);
544552
for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
545553
int64_t shard_id = (*sharding_function)(i, shard_count);
554+
// Since the quantizer does not natively support add_with_ids, we simply
555+
// generate them.
556+
xids[shard_id].push_back(i);
546557
float* reconstructed = new float[index->quantizer->d];
547558
index->quantizer->reconstruct(i, reconstructed);
548559
sharded_centroids[shard_id].insert(
@@ -552,9 +563,16 @@ void handle_ivf(
552563
delete[] reconstructed;
553564
}
554565
for (int64_t i = 0; i < shard_count; i++) {
555-
sharded_indexes[i]->quantizer->add(
556-
sharded_centroids[i].size() / index->quantizer->d,
557-
sharded_centroids[i].data());
566+
if (generate_ids) {
567+
sharded_indexes[i]->quantizer->add_with_ids(
568+
sharded_centroids[i].size() / index->quantizer->d,
569+
sharded_centroids[i].data(),
570+
xids[i].data());
571+
} else {
572+
sharded_indexes[i]->quantizer->add(
573+
sharded_centroids[i].size() / index->quantizer->d,
574+
sharded_centroids[i].data());
575+
}
558576
}
559577

560578
for (int64_t i = 0; i < shard_count; i++) {
@@ -572,7 +590,8 @@ void handle_binary_ivf(
572590
faiss::IndexBinaryIVF* index,
573591
int64_t shard_count,
574592
const std::string& filename_template,
575-
ShardingFunction* sharding_function) {
593+
ShardingFunction* sharding_function,
594+
bool generate_ids) {
576595
std::vector<faiss::IndexBinaryIVF*> sharded_indexes(shard_count);
577596

578597
auto clone = static_cast<faiss::IndexBinaryIVF*>(
@@ -582,14 +601,23 @@ void handle_binary_ivf(
582601
for (int64_t i = 0; i < shard_count; i++) {
583602
sharded_indexes[i] = static_cast<faiss::IndexBinaryIVF*>(
584603
faiss::clone_binary_index(clone));
604+
if (generate_ids) {
605+
// Assume the quantizer does not natively support add_with_ids.
606+
sharded_indexes[i]->quantizer =
607+
new IndexBinaryIDMap2(sharded_indexes[i]->quantizer);
608+
}
585609
}
586610

587611
// assign centroids to each sharded Index based on sharding_function, and
588612
// add them to the quantizer of each sharded index
589613
int64_t reconstruction_size = index->quantizer->d / 8;
590614
std::vector<std::vector<uint8_t>> sharded_centroids(shard_count);
615+
std::vector<std::vector<idx_t>> xids(shard_count);
591616
for (int64_t i = 0; i < index->quantizer->ntotal; i++) {
592617
int64_t shard_id = (*sharding_function)(i, shard_count);
618+
// Since the quantizer does not natively support add_with_ids, we simply
619+
// generate them.
620+
xids[shard_id].push_back(i);
593621
uint8_t* reconstructed = new uint8_t[reconstruction_size];
594622
index->quantizer->reconstruct(i, reconstructed);
595623
sharded_centroids[shard_id].insert(
@@ -599,9 +627,16 @@ void handle_binary_ivf(
599627
delete[] reconstructed;
600628
}
601629
for (int64_t i = 0; i < shard_count; i++) {
602-
sharded_indexes[i]->quantizer->add(
603-
sharded_centroids[i].size() / reconstruction_size,
604-
sharded_centroids[i].data());
630+
if (generate_ids) {
631+
sharded_indexes[i]->quantizer->add_with_ids(
632+
sharded_centroids[i].size() / reconstruction_size,
633+
sharded_centroids[i].data(),
634+
xids[i].data());
635+
} else {
636+
sharded_indexes[i]->quantizer->add(
637+
sharded_centroids[i].size() / reconstruction_size,
638+
sharded_centroids[i].data());
639+
}
605640
}
606641

607642
for (int64_t i = 0; i < shard_count; i++) {
@@ -620,7 +655,8 @@ void sharding_helper(
620655
IndexType* index,
621656
int64_t shard_count,
622657
const std::string& filename_template,
623-
ShardingFunction* sharding_function) {
658+
ShardingFunction* sharding_function,
659+
bool generate_ids) {
624660
FAISS_THROW_IF_MSG(index->quantizer->ntotal == 0, "No centroids to shard.");
625661
FAISS_THROW_IF_MSG(
626662
filename_template.find("%d") == std::string::npos,
@@ -636,30 +672,44 @@ void sharding_helper(
636672
dynamic_cast<faiss::IndexIVF*>(index),
637673
shard_count,
638674
filename_template,
639-
sharding_function);
675+
sharding_function,
676+
generate_ids);
640677
} else if (typeid(IndexType) == typeid(faiss::IndexBinaryIVF)) {
641678
handle_binary_ivf(
642679
dynamic_cast<faiss::IndexBinaryIVF*>(index),
643680
shard_count,
644681
filename_template,
645-
sharding_function);
682+
sharding_function,
683+
generate_ids);
646684
}
647685
}
648686

649687
void shard_ivf_index_centroids(
650688
faiss::IndexIVF* index,
651689
int64_t shard_count,
652690
const std::string& filename_template,
653-
ShardingFunction* sharding_function) {
654-
sharding_helper(index, shard_count, filename_template, sharding_function);
691+
ShardingFunction* sharding_function,
692+
bool generate_ids) {
693+
sharding_helper(
694+
index,
695+
shard_count,
696+
filename_template,
697+
sharding_function,
698+
generate_ids);
655699
}
656700

657701
void shard_binary_ivf_index_centroids(
658702
faiss::IndexBinaryIVF* index,
659703
int64_t shard_count,
660704
const std::string& filename_template,
661-
ShardingFunction* sharding_function) {
662-
sharding_helper(index, shard_count, filename_template, sharding_function);
705+
ShardingFunction* sharding_function,
706+
bool generate_ids) {
707+
sharding_helper(
708+
index,
709+
shard_count,
710+
filename_template,
711+
sharding_function,
712+
generate_ids);
663713
}
664714

665715
} // namespace ivflib

faiss/IVFlib.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,23 @@ struct DefaultShardingFunction : ShardingFunction {
191191
* @param filename_template Template for shard filenames.
192192
* @param sharding_function The function to shard by. The default is ith vector
193193
* mod shard_count.
194+
* @param generate_ids Generates ids using IndexIDMap2. If true, ids will
195+
* match the default ids in the unsharded index.
194196
* @return The number of shards written.
195197
*/
196198
void shard_ivf_index_centroids(
197199
IndexIVF* index,
198200
int64_t shard_count = 20,
199201
const std::string& filename_template = "shard.%d.index",
200-
ShardingFunction* sharding_function = nullptr);
202+
ShardingFunction* sharding_function = nullptr,
203+
bool generate_ids = false);
201204

202205
void shard_binary_ivf_index_centroids(
203206
faiss::IndexBinaryIVF* index,
204207
int64_t shard_count = 20,
205208
const std::string& filename_template = "shard.%d.index",
206-
ShardingFunction* sharding_function = nullptr);
209+
ShardingFunction* sharding_function = nullptr,
210+
bool generate_ids = false);
207211

208212
} // namespace ivflib
209213
} // namespace faiss

tests/test_ivflib.py

+65-14
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def default_sharding_function(self, i, shard_count):
199199
return i % shard_count
200200

201201
def verify_sharded_ivf_indexes(
202-
self, template, xb, shard_count, sharding_function):
202+
self, template, xb, shard_count, sharding_function, generate_ids=True):
203203
sharded_indexes_counters = [0] * shard_count
204204
sharded_indexes = []
205205
for i in range(shard_count):
@@ -208,15 +208,21 @@ def verify_sharded_ivf_indexes(
208208
else:
209209
index = faiss.read_index(template % i)
210210
sharded_indexes.append(index)
211+
211212
# Reconstruct and verify each centroid
212-
nb = len(xb)
213-
for i in range(nb):
214-
shard_id = sharding_function(i, shard_count)
215-
reconstructed = sharded_indexes[shard_id].quantizer.reconstruct(
216-
sharded_indexes_counters[shard_id])
217-
sharded_indexes_counters[shard_id] += 1
218-
print(f"reconstructed: {reconstructed} xb[i]: {xb[i]}")
219-
np.testing.assert_array_equal(reconstructed, xb[i])
213+
if generate_ids:
214+
for i in range(len(xb)):
215+
shard_id = sharding_function(i, shard_count)
216+
reconstructed = sharded_indexes[shard_id].quantizer.reconstruct(i)
217+
np.testing.assert_array_equal(reconstructed, xb[i])
218+
else:
219+
for i in range(len(xb)):
220+
shard_id = sharding_function(i, shard_count)
221+
reconstructed = sharded_indexes[shard_id].quantizer.reconstruct(
222+
sharded_indexes_counters[shard_id])
223+
sharded_indexes_counters[shard_id] += 1
224+
np.testing.assert_array_equal(reconstructed, xb[i])
225+
220226
# Clean up
221227
for i in range(shard_count):
222228
os.remove(template % i)
@@ -245,7 +251,9 @@ def test_save_index_shards_by_centroids_flat_quantizer_default_sharding(
245251
faiss.shard_ivf_index_centroids(
246252
index,
247253
shard_count,
248-
template
254+
template,
255+
None,
256+
True
249257
)
250258
self.verify_sharded_ivf_indexes(
251259
template, xb, shard_count, self.default_sharding_function)
@@ -264,7 +272,8 @@ def test_save_index_shards_by_centroids_flat_quantizer_custom_sharding(
264272
index,
265273
shard_count,
266274
template,
267-
self.custom_sharding_function
275+
self.custom_sharding_function,
276+
True
268277
)
269278
self.verify_sharded_ivf_indexes(
270279
template, xb, shard_count, self.custom_sharding_function)
@@ -282,7 +291,8 @@ def test_save_index_shards_by_centroids_hnsw_quantizer(self):
282291
index,
283292
shard_count,
284293
template,
285-
None
294+
None,
295+
True
286296
)
287297
self.verify_sharded_ivf_indexes(
288298
template, xb, shard_count, self.default_sharding_function)
@@ -299,7 +309,9 @@ def test_save_index_shards_by_centroids_binary_flat_quantizer(self):
299309
faiss.shard_binary_ivf_index_centroids(
300310
index,
301311
shard_count,
302-
template
312+
template,
313+
None,
314+
True
303315
)
304316
self.verify_sharded_ivf_indexes(
305317
template, xb, shard_count, self.default_sharding_function)
@@ -316,7 +328,46 @@ def test_save_index_shards_by_centroids_binary_hnsw_quantizer(self):
316328
faiss.shard_binary_ivf_index_centroids(
317329
index,
318330
shard_count,
319-
template
331+
template,
332+
None,
333+
True
320334
)
321335
self.verify_sharded_ivf_indexes(
322336
template, xb, shard_count, self.default_sharding_function)
337+
338+
def test_save_index_shards_without_id_generation(self):
339+
xb = np.random.randint(256, size=(self.nb, int(self.d / 8))).astype('uint8')
340+
quantizer = faiss.IndexBinaryHNSW(self.d, 32)
341+
index = faiss.IndexBinaryIVF(quantizer, self.d, self.nlist)
342+
shard_count = 5
343+
344+
index.quantizer.add(xb)
345+
346+
template = str(random.randint(0, 100000)) + "shard.%d.index"
347+
faiss.shard_binary_ivf_index_centroids(
348+
index,
349+
shard_count,
350+
template,
351+
None,
352+
False
353+
)
354+
self.verify_sharded_ivf_indexes(
355+
template, xb, shard_count, self.default_sharding_function, False)
356+
357+
xb = np.random.rand(self.nb, self.d).astype('float32')
358+
quantizer = faiss.IndexHNSWFlat(self.d, 32)
359+
index = faiss.IndexIVFFlat(quantizer, self.d, self.nlist)
360+
shard_count = 23
361+
362+
index.quantizer.add(xb)
363+
364+
template = str(random.randint(0, 100000)) + "shard.%d.index"
365+
faiss.shard_ivf_index_centroids(
366+
index,
367+
shard_count,
368+
template,
369+
None,
370+
False
371+
)
372+
self.verify_sharded_ivf_indexes(
373+
template, xb, shard_count, self.default_sharding_function, False)

0 commit comments

Comments
 (0)