Skip to content

Commit 868813f

Browse files
authored
Merge branch 'main' into fix-swig
2 parents 829e258 + 14b8af6 commit 868813f

File tree

4 files changed

+79
-19
lines changed

4 files changed

+79
-19
lines changed

.circleci/config.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ jobs:
168168
command: |
169169
cd conda
170170
conda build faiss-gpu-raft --variants '{ "cudatoolkit": "<<parameters.cuda>>", "c_compiler_version": "<<parameters.compiler_version>>", "cxx_compiler_version": "<<parameters.compiler_version>>" }' \
171-
-c pytorch -c nvidia/label/cuda-<<parameters.cuda>> -c nvidia -c rapidsai -c conda-forge
171+
-c pytorch -c nvidia/label/cuda-<<parameters.cuda>> -c nvidia -c rapidsai -c rapidsai-nightly -c conda-forge
172172
- when:
173173
condition:
174174
and:
@@ -182,7 +182,7 @@ jobs:
182182
command: |
183183
cd conda
184184
conda build faiss-gpu-raft --variants '{ "cudatoolkit": "<<parameters.cuda>>", "c_compiler_version": "<<parameters.compiler_version>>", "cxx_compiler_version": "<<parameters.compiler_version>>" }' \
185-
--user pytorch --label <<parameters.label>> -c pytorch -c nvidia/label/cuda-<<parameters.cuda>> -c nvidia -c rapidsai -c conda-forge
185+
--user pytorch --label <<parameters.label>> -c pytorch -c nvidia/label/cuda-<<parameters.cuda>> -c nvidia -c rapidsai -c rapidsai-nightly -c conda-forge
186186
187187
build_cmake:
188188
parameters:

faiss/IndexIVF.h

+8
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,14 @@ struct IndexIVF : Index, IndexIVFInterface {
433433

434434
/* The standalone codec interface (except sa_decode that is specific) */
435435
size_t sa_code_size() const override;
436+
437+
/** encode a set of vectors
438+
* sa_encode will call encode_vectors with include_listno=true
439+
* @param n nb of vectors to encode
440+
* @param x the vectors to encode
441+
* @param bytes output array for the codes
442+
* (no return value: the codes are written to the bytes array)
443+
*/
436444
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
437445

438446
IndexIVF();

faiss/IndexIVFPQFastScan.cpp

+21-2
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,28 @@ void IndexIVFPQFastScan::compute_LUT(
286286
}
287287
}
288288

289-
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
289+
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
290290
const {
291-
pq.decode(bytes, x, n);
291+
size_t coarse_size = coarse_code_size();
292+
293+
#pragma omp parallel if (n > 1)
294+
{
295+
std::vector<float> residual(d);
296+
297+
#pragma omp for
298+
for (idx_t i = 0; i < n; i++) {
299+
const uint8_t* code = codes + i * (code_size + coarse_size);
300+
int64_t list_no = decode_listno(code);
301+
float* xi = x + i * d;
302+
pq.decode(code + coarse_size, xi);
303+
if (by_residual) {
304+
quantizer->reconstruct(list_no, residual.data());
305+
for (size_t j = 0; j < d; j++) {
306+
xi[j] += residual[j];
307+
}
308+
}
309+
}
310+
}
292311
}
293312

294313
} // namespace faiss

tests/test_fast_scan_ivf.py

+48-15
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ def sp(x):
8484
b = btab[0]
8585
dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b)
8686

87-
# print(a, b, dis_ref.sum())
8887
avg_realtive_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum()
89-
# print('a=', a, 'avg_relative_error=', avg_realtive_error)
9088
self.assertLess(avg_realtive_error, 0.0005)
9189

9290
def test_no_residual_ip(self):
@@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
228226

229227
m3 = three_metrics(Da, Ia, Db, Ib)
230228

231-
232-
# print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10)
233229
ref_results = {
234230
(True, 1): [0.985, 1.0, 9.872],
235231
(True, 0): [ 0.987, 1.0, 9.914],
@@ -261,36 +257,80 @@ class TestEquivPQ(unittest.TestCase):
261257

262258
def test_equiv_pq(self):
263259
ds = datasets.SyntheticDataset(32, 2000, 200, 4)
260+
xq = ds.get_queries()
264261

265262
index = faiss.index_factory(32, "IVF1,PQ16x4np")
266263
index.by_residual = False
267264
# force coarse quantizer
268265
index.quantizer.add(np.zeros((1, 32), dtype='float32'))
269266
index.train(ds.get_train())
270267
index.add(ds.get_database())
271-
Dref, Iref = index.search(ds.get_queries(), 4)
268+
Dref, Iref = index.search(xq, 4)
272269

273270
index_pq = faiss.index_factory(32, "PQ16x4np")
274271
index_pq.pq = index.pq
275272
index_pq.is_trained = True
276273
index_pq.codes = faiss. downcast_InvertedLists(
277274
index.invlists).codes.at(0)
278275
index_pq.ntotal = index.ntotal
279-
Dnew, Inew = index_pq.search(ds.get_queries(), 4)
276+
Dnew, Inew = index_pq.search(xq, 4)
280277

281278
np.testing.assert_array_equal(Iref, Inew)
282279
np.testing.assert_array_equal(Dref, Dnew)
283280

284281
index_pq2 = faiss.IndexPQFastScan(index_pq)
285282
index_pq2.implem = 12
286-
Dref, Iref = index_pq2.search(ds.get_queries(), 4)
283+
Dref, Iref = index_pq2.search(xq, 4)
287284

288285
index2 = faiss.IndexIVFPQFastScan(index)
289286
index2.implem = 12
290-
Dnew, Inew = index2.search(ds.get_queries(), 4)
287+
Dnew, Inew = index2.search(xq, 4)
291288
np.testing.assert_array_equal(Iref, Inew)
292289
np.testing.assert_array_equal(Dref, Dnew)
293290

291+
# test encode and decode
292+
293+
np.testing.assert_array_equal(
294+
index_pq.sa_encode(xq),
295+
index2.sa_encode(xq)
296+
)
297+
298+
np.testing.assert_array_equal(
299+
index_pq.sa_decode(index_pq.sa_encode(xq)),
300+
index2.sa_decode(index2.sa_encode(xq))
301+
)
302+
303+
np.testing.assert_array_equal(
304+
((index_pq.sa_decode(index_pq.sa_encode(xq)) - xq) ** 2).sum(1),
305+
((index2.sa_decode(index2.sa_encode(xq)) - xq) ** 2).sum(1)
306+
)
307+
308+
def test_equiv_pq_encode_decode(self):
309+
ds = datasets.SyntheticDataset(32, 1000, 200, 10)
310+
xq = ds.get_queries()
311+
312+
index_ivfpq = faiss.index_factory(ds.d, "IVF10,PQ8x4np")
313+
index_ivfpq.train(ds.get_train())
314+
315+
index_ivfpqfs = faiss.IndexIVFPQFastScan(index_ivfpq)
316+
317+
np.testing.assert_array_equal(
318+
index_ivfpq.sa_encode(xq),
319+
index_ivfpqfs.sa_encode(xq)
320+
)
321+
322+
np.testing.assert_array_equal(
323+
index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)),
324+
index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq))
325+
)
326+
327+
np.testing.assert_array_equal(
328+
((index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)) - xq) ** 2)
329+
.sum(1),
330+
((index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq)) - xq) ** 2)
331+
.sum(1)
332+
)
333+
294334

295335
class TestIVFImplem12(unittest.TestCase):
296336

@@ -463,7 +503,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
463503
Dnew, Inew = index2.search(ds.get_queries(), 10)
464504

465505
m3 = three_metrics(Dref, Iref, Dnew, Inew)
466-
# print((by_residual, metric, d), ":", m3)
467506
ref_m3_tab = {
468507
(True, 1, 32): (0.995, 1.0, 9.91),
469508
(True, 0, 32): (0.99, 1.0, 9.91),
@@ -554,7 +593,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
554593
recall_ref = (Iref == gt).sum() / nq
555594
recall1 = (I1 == gt).sum() / nq
556595

557-
print(aq, st, by_residual, implem, metric_type, recall_ref, recall1)
558596
assert abs(recall_ref - recall1) < 0.051
559597

560598
def xx_test_accuracy(self):
@@ -599,7 +637,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
599637
recall_ref = (Iref == gt).sum() / nq
600638
recall1 = (I1 == gt).sum() / nq
601639

602-
print(aq, st, by_residual, implem, recall_ref, recall1)
603640
assert abs(recall_ref - recall1) < 0.05
604641

605642
def xx_test_rescale_accuracy(self):
@@ -624,7 +661,6 @@ def subtest_from_ivfaq(self, implem):
624661
nq = Iref.shape[0]
625662
recall_ref = (Iref == gt).sum() / nq
626663
recall1 = (I1 == gt).sum() / nq
627-
print(recall_ref, recall1)
628664
assert abs(recall_ref - recall1) < 0.02
629665

630666
def test_from_ivfaq(self):
@@ -763,7 +799,6 @@ def subtest_accuracy(self, paq):
763799
recall_ref = (Iref == gt).sum() / nq
764800
recall1 = (I1 == gt).sum() / nq
765801

766-
print(paq, recall_ref, recall1)
767802
assert abs(recall_ref - recall1) < 0.05
768803

769804
def test_accuracy_PLSQ(self):
@@ -847,7 +882,6 @@ def do_test(self, metric=faiss.METRIC_L2):
847882
# find a reasonable radius
848883
D, I = index.search(ds.get_queries(), 10)
849884
radius = np.median(D[:, -1])
850-
# print("radius=", radius)
851885
lims1, D1, I1 = index.range_search(ds.get_queries(), radius)
852886

853887
index2 = faiss.IndexIVFPQFastScan(index)
@@ -860,7 +894,6 @@ def do_test(self, metric=faiss.METRIC_L2):
860894
for i in range(ds.nq):
861895
ref = set(I1[lims1[i]: lims1[i + 1]])
862896
new = set(I2[lims2[i]: lims2[i + 1]])
863-
print(ref, new)
864897
nmiss += len(ref - new)
865898
nextra += len(new - ref)
866899

0 commit comments

Comments
 (0)