9
9
#include < omp.h>
10
10
11
11
#include < memory>
12
+ #include < numeric>
12
13
13
14
#include < faiss/IndexAdditiveQuantizer.h>
14
15
#include < faiss/IndexIVFAdditiveQuantizer.h>
@@ -529,20 +530,30 @@ void handle_ivf(
529
530
faiss::IndexIVF* index,
530
531
int64_t shard_count,
531
532
const std::string& filename_template,
532
- ShardingFunction* sharding_function) {
533
+ ShardingFunction* sharding_function,
534
+ bool generate_ids) {
533
535
std::vector<faiss::IndexIVF*> sharded_indexes (shard_count);
534
536
auto clone = static_cast <faiss::IndexIVF*>(faiss::clone_index (index ));
535
537
clone->quantizer ->reset ();
536
538
for (int64_t i = 0 ; i < shard_count; i++) {
537
539
sharded_indexes[i] =
538
540
static_cast <faiss::IndexIVF*>(faiss::clone_index (clone));
541
+ if (generate_ids) {
542
+ // Assume the quantizer does not natively support add_with_ids.
543
+ sharded_indexes[i]->quantizer =
544
+ new IndexIDMap2 (sharded_indexes[i]->quantizer );
545
+ }
539
546
}
540
547
541
548
// assign centroids to each sharded Index based on sharding_function, and
542
549
// add them to the quantizer of each sharded index
543
550
std::vector<std::vector<float >> sharded_centroids (shard_count);
551
+ std::vector<std::vector<idx_t >> xids (shard_count);
544
552
for (int64_t i = 0 ; i < index ->quantizer ->ntotal ; i++) {
545
553
int64_t shard_id = (*sharding_function)(i, shard_count);
554
+ // Since the quantizer does not natively support add_with_ids, we simply
555
+ // generate them.
556
+ xids[shard_id].push_back (i);
546
557
float * reconstructed = new float [index ->quantizer ->d ];
547
558
index ->quantizer ->reconstruct (i, reconstructed);
548
559
sharded_centroids[shard_id].insert (
@@ -552,9 +563,16 @@ void handle_ivf(
552
563
delete[] reconstructed;
553
564
}
554
565
for (int64_t i = 0 ; i < shard_count; i++) {
555
- sharded_indexes[i]->quantizer ->add (
556
- sharded_centroids[i].size () / index ->quantizer ->d ,
557
- sharded_centroids[i].data ());
566
+ if (generate_ids) {
567
+ sharded_indexes[i]->quantizer ->add_with_ids (
568
+ sharded_centroids[i].size () / index ->quantizer ->d ,
569
+ sharded_centroids[i].data (),
570
+ xids[i].data ());
571
+ } else {
572
+ sharded_indexes[i]->quantizer ->add (
573
+ sharded_centroids[i].size () / index ->quantizer ->d ,
574
+ sharded_centroids[i].data ());
575
+ }
558
576
}
559
577
560
578
for (int64_t i = 0 ; i < shard_count; i++) {
@@ -572,7 +590,8 @@ void handle_binary_ivf(
572
590
faiss::IndexBinaryIVF* index,
573
591
int64_t shard_count,
574
592
const std::string& filename_template,
575
- ShardingFunction* sharding_function) {
593
+ ShardingFunction* sharding_function,
594
+ bool generate_ids) {
576
595
std::vector<faiss::IndexBinaryIVF*> sharded_indexes (shard_count);
577
596
578
597
auto clone = static_cast <faiss::IndexBinaryIVF*>(
@@ -582,14 +601,23 @@ void handle_binary_ivf(
582
601
for (int64_t i = 0 ; i < shard_count; i++) {
583
602
sharded_indexes[i] = static_cast <faiss::IndexBinaryIVF*>(
584
603
faiss::clone_binary_index (clone));
604
+ if (generate_ids) {
605
+ // Assume the quantizer does not natively support add_with_ids.
606
+ sharded_indexes[i]->quantizer =
607
+ new IndexBinaryIDMap2 (sharded_indexes[i]->quantizer );
608
+ }
585
609
}
586
610
587
611
// assign centroids to each sharded Index based on sharding_function, and
588
612
// add them to the quantizer of each sharded index
589
613
int64_t reconstruction_size = index ->quantizer ->d / 8 ;
590
614
std::vector<std::vector<uint8_t >> sharded_centroids (shard_count);
615
+ std::vector<std::vector<idx_t >> xids (shard_count);
591
616
for (int64_t i = 0 ; i < index ->quantizer ->ntotal ; i++) {
592
617
int64_t shard_id = (*sharding_function)(i, shard_count);
618
+ // Since the quantizer does not natively support add_with_ids, we simply
619
+ // generate them.
620
+ xids[shard_id].push_back (i);
593
621
uint8_t * reconstructed = new uint8_t [reconstruction_size];
594
622
index ->quantizer ->reconstruct (i, reconstructed);
595
623
sharded_centroids[shard_id].insert (
@@ -599,9 +627,16 @@ void handle_binary_ivf(
599
627
delete[] reconstructed;
600
628
}
601
629
for (int64_t i = 0 ; i < shard_count; i++) {
602
- sharded_indexes[i]->quantizer ->add (
603
- sharded_centroids[i].size () / reconstruction_size,
604
- sharded_centroids[i].data ());
630
+ if (generate_ids) {
631
+ sharded_indexes[i]->quantizer ->add_with_ids (
632
+ sharded_centroids[i].size () / reconstruction_size,
633
+ sharded_centroids[i].data (),
634
+ xids[i].data ());
635
+ } else {
636
+ sharded_indexes[i]->quantizer ->add (
637
+ sharded_centroids[i].size () / reconstruction_size,
638
+ sharded_centroids[i].data ());
639
+ }
605
640
}
606
641
607
642
for (int64_t i = 0 ; i < shard_count; i++) {
@@ -620,7 +655,8 @@ void sharding_helper(
620
655
IndexType* index,
621
656
int64_t shard_count,
622
657
const std::string& filename_template,
623
- ShardingFunction* sharding_function) {
658
+ ShardingFunction* sharding_function,
659
+ bool generate_ids) {
624
660
FAISS_THROW_IF_MSG (index ->quantizer ->ntotal == 0 , " No centroids to shard." );
625
661
FAISS_THROW_IF_MSG (
626
662
filename_template.find (" %d" ) == std::string::npos,
@@ -636,30 +672,44 @@ void sharding_helper(
636
672
dynamic_cast <faiss::IndexIVF*>(index ),
637
673
shard_count,
638
674
filename_template,
639
- sharding_function);
675
+ sharding_function,
676
+ generate_ids);
640
677
} else if (typeid (IndexType) == typeid (faiss::IndexBinaryIVF)) {
641
678
handle_binary_ivf (
642
679
dynamic_cast <faiss::IndexBinaryIVF*>(index ),
643
680
shard_count,
644
681
filename_template,
645
- sharding_function);
682
+ sharding_function,
683
+ generate_ids);
646
684
}
647
685
}
648
686
649
687
void shard_ivf_index_centroids (
650
688
faiss::IndexIVF* index,
651
689
int64_t shard_count,
652
690
const std::string& filename_template,
653
- ShardingFunction* sharding_function) {
654
- sharding_helper (index , shard_count, filename_template, sharding_function);
691
+ ShardingFunction* sharding_function,
692
+ bool generate_ids) {
693
+ sharding_helper (
694
+ index ,
695
+ shard_count,
696
+ filename_template,
697
+ sharding_function,
698
+ generate_ids);
655
699
}
656
700
657
701
void shard_binary_ivf_index_centroids (
658
702
faiss::IndexBinaryIVF* index,
659
703
int64_t shard_count,
660
704
const std::string& filename_template,
661
- ShardingFunction* sharding_function) {
662
- sharding_helper (index , shard_count, filename_template, sharding_function);
705
+ ShardingFunction* sharding_function,
706
+ bool generate_ids) {
707
+ sharding_helper (
708
+ index ,
709
+ shard_count,
710
+ filename_template,
711
+ sharding_function,
712
+ generate_ids);
663
713
}
664
714
665
715
} // namespace ivflib
0 commit comments