@@ -84,9 +84,7 @@ def sp(x):
84
84
b = btab [0 ]
85
85
dis_new = self .compute_dis_quant (codes , LUTq , biasq , a , b )
86
86
87
- # print(a, b, dis_ref.sum())
88
87
avg_realtive_error = np .abs (dis_new - dis_ref ).sum () / dis_ref .sum ()
89
- # print('a=', a, 'avg_relative_error=', avg_realtive_error)
90
88
self .assertLess (avg_realtive_error , 0.0005 )
91
89
92
90
def test_no_residual_ip (self ):
@@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
228
226
229
227
m3 = three_metrics (Da , Ia , Db , Ib )
230
228
231
-
232
- # print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10)
233
229
ref_results = {
234
230
(True , 1 ): [0.985 , 1.0 , 9.872 ],
235
231
(True , 0 ): [ 0.987 , 1.0 , 9.914 ],
@@ -261,36 +257,80 @@ class TestEquivPQ(unittest.TestCase):
261
257
262
258
def test_equiv_pq (self ):
263
259
ds = datasets .SyntheticDataset (32 , 2000 , 200 , 4 )
260
+ xq = ds .get_queries ()
264
261
265
262
index = faiss .index_factory (32 , "IVF1,PQ16x4np" )
266
263
index .by_residual = False
267
264
# force coarse quantizer
268
265
index .quantizer .add (np .zeros ((1 , 32 ), dtype = 'float32' ))
269
266
index .train (ds .get_train ())
270
267
index .add (ds .get_database ())
271
- Dref , Iref = index .search (ds . get_queries () , 4 )
268
+ Dref , Iref = index .search (xq , 4 )
272
269
273
270
index_pq = faiss .index_factory (32 , "PQ16x4np" )
274
271
index_pq .pq = index .pq
275
272
index_pq .is_trained = True
276
273
index_pq .codes = faiss . downcast_InvertedLists (
277
274
index .invlists ).codes .at (0 )
278
275
index_pq .ntotal = index .ntotal
279
- Dnew , Inew = index_pq .search (ds . get_queries () , 4 )
276
+ Dnew , Inew = index_pq .search (xq , 4 )
280
277
281
278
np .testing .assert_array_equal (Iref , Inew )
282
279
np .testing .assert_array_equal (Dref , Dnew )
283
280
284
281
index_pq2 = faiss .IndexPQFastScan (index_pq )
285
282
index_pq2 .implem = 12
286
- Dref , Iref = index_pq2 .search (ds . get_queries () , 4 )
283
+ Dref , Iref = index_pq2 .search (xq , 4 )
287
284
288
285
index2 = faiss .IndexIVFPQFastScan (index )
289
286
index2 .implem = 12
290
- Dnew , Inew = index2 .search (ds . get_queries () , 4 )
287
+ Dnew , Inew = index2 .search (xq , 4 )
291
288
np .testing .assert_array_equal (Iref , Inew )
292
289
np .testing .assert_array_equal (Dref , Dnew )
293
290
291
+ # test encode and decode
292
+
293
+ np .testing .assert_array_equal (
294
+ index_pq .sa_encode (xq ),
295
+ index2 .sa_encode (xq )
296
+ )
297
+
298
+ np .testing .assert_array_equal (
299
+ index_pq .sa_decode (index_pq .sa_encode (xq )),
300
+ index2 .sa_decode (index2 .sa_encode (xq ))
301
+ )
302
+
303
+ np .testing .assert_array_equal (
304
+ ((index_pq .sa_decode (index_pq .sa_encode (xq )) - xq ) ** 2 ).sum (1 ),
305
+ ((index2 .sa_decode (index2 .sa_encode (xq )) - xq ) ** 2 ).sum (1 )
306
+ )
307
+
308
+ def test_equiv_pq_encode_decode (self ):
309
+ ds = datasets .SyntheticDataset (32 , 1000 , 200 , 10 )
310
+ xq = ds .get_queries ()
311
+
312
+ index_ivfpq = faiss .index_factory (ds .d , "IVF10,PQ8x4np" )
313
+ index_ivfpq .train (ds .get_train ())
314
+
315
+ index_ivfpqfs = faiss .IndexIVFPQFastScan (index_ivfpq )
316
+
317
+ np .testing .assert_array_equal (
318
+ index_ivfpq .sa_encode (xq ),
319
+ index_ivfpqfs .sa_encode (xq )
320
+ )
321
+
322
+ np .testing .assert_array_equal (
323
+ index_ivfpq .sa_decode (index_ivfpq .sa_encode (xq )),
324
+ index_ivfpqfs .sa_decode (index_ivfpqfs .sa_encode (xq ))
325
+ )
326
+
327
+ np .testing .assert_array_equal (
328
+ ((index_ivfpq .sa_decode (index_ivfpq .sa_encode (xq )) - xq ) ** 2 )
329
+ .sum (1 ),
330
+ ((index_ivfpqfs .sa_decode (index_ivfpqfs .sa_encode (xq )) - xq ) ** 2 )
331
+ .sum (1 )
332
+ )
333
+
294
334
295
335
class TestIVFImplem12 (unittest .TestCase ):
296
336
@@ -463,7 +503,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
463
503
Dnew , Inew = index2 .search (ds .get_queries (), 10 )
464
504
465
505
m3 = three_metrics (Dref , Iref , Dnew , Inew )
466
- # print((by_residual, metric, d), ":", m3)
467
506
ref_m3_tab = {
468
507
(True , 1 , 32 ): (0.995 , 1.0 , 9.91 ),
469
508
(True , 0 , 32 ): (0.99 , 1.0 , 9.91 ),
@@ -554,7 +593,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
554
593
recall_ref = (Iref == gt ).sum () / nq
555
594
recall1 = (I1 == gt ).sum () / nq
556
595
557
- print (aq , st , by_residual , implem , metric_type , recall_ref , recall1 )
558
596
assert abs (recall_ref - recall1 ) < 0.051
559
597
560
598
def xx_test_accuracy (self ):
@@ -599,7 +637,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
599
637
recall_ref = (Iref == gt ).sum () / nq
600
638
recall1 = (I1 == gt ).sum () / nq
601
639
602
- print (aq , st , by_residual , implem , recall_ref , recall1 )
603
640
assert abs (recall_ref - recall1 ) < 0.05
604
641
605
642
def xx_test_rescale_accuracy (self ):
@@ -624,7 +661,6 @@ def subtest_from_ivfaq(self, implem):
624
661
nq = Iref .shape [0 ]
625
662
recall_ref = (Iref == gt ).sum () / nq
626
663
recall1 = (I1 == gt ).sum () / nq
627
- print (recall_ref , recall1 )
628
664
assert abs (recall_ref - recall1 ) < 0.02
629
665
630
666
def test_from_ivfaq (self ):
@@ -763,7 +799,6 @@ def subtest_accuracy(self, paq):
763
799
recall_ref = (Iref == gt ).sum () / nq
764
800
recall1 = (I1 == gt ).sum () / nq
765
801
766
- print (paq , recall_ref , recall1 )
767
802
assert abs (recall_ref - recall1 ) < 0.05
768
803
769
804
def test_accuracy_PLSQ (self ):
@@ -847,7 +882,6 @@ def do_test(self, metric=faiss.METRIC_L2):
847
882
# find a reasonable radius
848
883
D , I = index .search (ds .get_queries (), 10 )
849
884
radius = np .median (D [:, - 1 ])
850
- # print("radius=", radius)
851
885
lims1 , D1 , I1 = index .range_search (ds .get_queries (), radius )
852
886
853
887
index2 = faiss .IndexIVFPQFastScan (index )
@@ -860,7 +894,6 @@ def do_test(self, metric=faiss.METRIC_L2):
860
894
for i in range (ds .nq ):
861
895
ref = set (I1 [lims1 [i ]: lims1 [i + 1 ]])
862
896
new = set (I2 [lims2 [i ]: lims2 [i + 1 ]])
863
- print (ref , new )
864
897
nmiss += len (ref - new )
865
898
nextra += len (new - ref )
866
899
0 commit comments