11
11
from dataclasses import dataclass
12
12
from typing import ClassVar , Dict , List , Optional
13
13
14
- import faiss # @manual=//faiss/python:pyfaiss_gpu
14
+ import faiss # @manual=//faiss/python:pyfaiss
15
15
import numpy as np
16
16
from faiss .benchs .bench_fw .descriptors import IndexBaseDescriptor
17
17
18
- from faiss .contrib .evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
18
+ from faiss .contrib .evaluation import ( # @manual=//faiss/contrib:faiss_contrib
19
19
knn_intersection_measure ,
20
20
OperatingPointsWithRanges ,
21
21
)
22
- from faiss .contrib .factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
22
+ from faiss .contrib .factory_tools import ( # @manual=//faiss/contrib:faiss_contrib
23
23
reverse_index_factory ,
24
24
)
25
- from faiss .contrib .ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu
25
+ from faiss .contrib .ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib
26
26
add_preassigned ,
27
27
replace_ivf_quantizer ,
28
28
)
@@ -635,11 +635,12 @@ def get_index_name(self) -> Optional[str]:
635
635
636
636
def fetch_index (self ):
637
637
# read index from file if it is already available
638
+ index_filename = None
638
639
if self .index_path :
639
640
index_filename = os .path .basename (self .index_path )
640
- else :
641
+ elif self . index_name :
641
642
index_filename = self .index_name + "index"
642
- if self .io .file_exist (index_filename ):
643
+ if index_filename and self .io .file_exist (index_filename ):
643
644
if self .index_path :
644
645
index = self .io .read_index (
645
646
index_filename ,
@@ -681,7 +682,7 @@ def fetch_index(self):
681
682
)
682
683
assert index .ntotal == xb .shape [0 ] or index_ivf .ntotal == xb .shape [0 ]
683
684
logger .info ("Added vectors to index" )
684
- if self .serialize_full_index :
685
+ if self .serialize_full_index and index_filename :
685
686
codec_size = self .io .write_index (index , index_filename )
686
687
assert codec_size is not None
687
688
@@ -908,6 +909,7 @@ def get_codec(self):
908
909
class IndexFromFactory (Index ):
909
910
factory : Optional [str ] = None
910
911
training_vectors : Optional [DatasetDescriptor ] = None
912
+ assemble_opaque : bool = True
911
913
912
914
def __post_init__ (self ):
913
915
super ().__post_init__ ()
@@ -916,6 +918,19 @@ def __post_init__(self):
916
918
if self .factory != "Flat" and self .training_vectors is None :
917
919
raise ValueError (f"training_vectors is not set for { self .factory } " )
918
920
921
def get_codec_name(self):
    """Return the codec name, building and caching one if the parent has none.

    The parent implementation is consulted first. Only when it yields
    ``None`` is a name assembled from the factory string (commas replaced
    by underscores), the dimension ``d``, the upper-cased metric, the
    training-vectors filename (skipped for the ``"Flat"`` factory) and,
    when present, the construction parameters. The assembled name is
    stored on ``self.codec_name`` before being returned.
    """
    if super().get_codec_name() is not None:
        # The parent already supplied a name; hand back the stored attribute.
        return self.codec_name

    pieces = [
        f"{self.factory.replace(',', '_')}.",
        f"d_{self.d}.{self.metric.upper()}.",
    ]
    if self.factory != "Flat":
        # Every factory other than "Flat" contributes its training set to the name.
        assert self.training_vectors is not None
        pieces.append(self.training_vectors.get_filename("xt"))
    if self.construction_params is not None:
        pieces.append(
            IndexBaseDescriptor.param_dict_list_to_name(self.construction_params)
        )

    self.codec_name = "".join(pieces)
    return self.codec_name
933
+
919
934
def fetch_meta (self , dry_run = False ):
920
935
meta_filename = self .get_codec_name () + "json"
921
936
if self .io .file_exist (meta_filename ):
@@ -1021,14 +1036,13 @@ def get_quantizer(self, dry_run, pretransform=None):
1021
1036
def assemble (self , dry_run ):
1022
1037
logger .info (f"assemble { self .factory } " )
1023
1038
model = self .get_model ()
1024
- opaque = True
1025
1039
t_aggregate = 0
1026
1040
# try:
1027
1041
# reverse_index_factory(model)
1028
1042
# opaque = False
1029
1043
# except NotImplementedError:
1030
1044
# opaque = True
1031
- if opaque :
1045
+ if self . assemble_opaque :
1032
1046
codec = model
1033
1047
else :
1034
1048
if isinstance (model , faiss .IndexPreTransform ):
0 commit comments