
Commit 9590ad2

mlomeli1 authored and facebook-github-bot committed
PQ with pytorch (#4116)
Summary:
Pull Request resolved: #4116

This diff implements Product Quantization using Pytorch only.

Reviewed By: mdouze

Differential Revision: D67766798

fbshipit-source-id: fe2d44a674fc2056f7e2082e9765052c98fdc8f8
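For context, product quantization splits each d-dimensional vector into M sub-vectors and encodes each sub-vector as the index of the nearest of 2**nbits learned centroids, so a vector is stored in roughly M * nbits / 8 bytes. The following back-of-the-envelope sketch uses the sizes from the new test (d=64, M=4, nbits=8); these values are illustrative, not mandated by the diff.

import math

# Illustrative PQ sizing, using the dimensions from the new test.
d, M, nbits = 64, 4, 8
sub_dim = d // M                             # 16 dimensions per sub-vector
n_centroids = 2 ** nbits                     # 256 centroids per sub-quantizer
code_size = int(math.ceil(M * nbits / 8))    # 4 bytes per encoded vector
raw_size = d * 4                             # 256 bytes per float32 vector
print(sub_dim, n_centroids, code_size, raw_size // code_size)  # 16 256 4 64

In other words, each vector shrinks from 256 bytes of float32 data to a 4-byte code, a 64x reduction.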
1 parent 0cbc2a8 commit 9590ad2

File tree: 3 files changed (+82, -13 lines)


contrib/torch/clustering.py (+1)
@@ -13,6 +13,7 @@
 # the kmeans can produce both torch and numpy centroids
 from faiss.contrib.clustering import kmeans
 
+
 class DatasetAssign:
     """Wrapper for a tensor that offers a function to assign the vectors
     to centroids. All other implementations offer the same interface"""

contrib/torch/quantization.py (+52, -9)
@@ -7,33 +7,47 @@
 This contrib module contains Pytorch code for quantization.
 """
 
-import numpy as np
 import torch
 import faiss
-
-from faiss.contrib import torch_utils
+import math
+from faiss.contrib.torch import clustering
+# the kmeans can produce both torch and numpy centroids
 
 
 class Quantizer:
 
     def __init__(self, d, code_size):
+        """
+        d: dimension of vectors
+        code_size: nb of bytes of the code (per vector)
+        """
         self.d = d
         self.code_size = code_size
 
     def train(self, x):
+        """
+        takes a n-by-d array and peforms training
+        """
         pass
 
     def encode(self, x):
+        """
+        takes a n-by-d float array, encodes to an n-by-code_size uint8 array
+        """
         pass
 
-    def decode(self, x):
+    def decode(self, codes):
+        """
+        takes a n-by-code_size uint8 array, returns a n-by-d array
+        """
         pass
 
 
 class VectorQuantizer(Quantizer):
 
     def __init__(self, d, k):
-        code_size = int(torch.ceil(torch.log2(k) / 8))
+
+        code_size = int(math.ceil(torch.log2(k) / 8))
         Quantizer.__init__(d, code_size)
         self.k = k
 
@@ -42,12 +56,41 @@ def train(self, x):
 
 
 class ProductQuantizer(Quantizer):
-
     def __init__(self, d, M, nbits):
-        code_size = int(torch.ceil(M * nbits / 8))
-        Quantizer.__init__(d, code_size)
+        """ M: number of subvectors, d%M == 0
+        nbits: number of bits that each vector is encoded into
+        """
+        assert d % M == 0
+        assert nbits == 8  # todo: implement other nbits values
+        code_size = int(math.ceil(M * nbits / 8))
+        Quantizer.__init__(self, d, code_size)
         self.M = M
         self.nbits = nbits
+        self.code_size = code_size
 
     def train(self, x):
-        pass
+        nc = 2 ** self.nbits
+        sd = self.d // self.M
+        dev = x.device
+        dtype = x.dtype
+        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+            data = clustering.DatasetAssign(xsub.contiguous())
+            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)
+
+    def encode(self, x):
+        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M:(m + 1) * self.d // self.M]
+            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
+            codes[:, m] = I.ravel()
+        return codes
+
+    def decode(self, codes):
+        idxs = [codes[:, m].long() for m in range(self.M)]
+        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+        stacked_vectors = torch.stack(vectors, dim=1)
+        cbd = self.codebook.shape[-1]
+        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
+        return x_rec
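For readers who want to try the new module, here is a minimal usage sketch modeled on the new test further down. It assumes a FAISS build that includes this diff (so faiss.contrib.torch.quantization exists) and, like the test, imports faiss.contrib.torch_utils so that faiss.knn accepts torch tensors; the data and parameters are illustrative.

import torch
import faiss
import faiss.contrib.torch_utils  # lets faiss.knn take torch tensors, as in the test
from faiss.contrib.torch import quantization

# illustrative sizes: 4 sub-vectors of 16 dims, 8 bits (1 byte) each
d, n, M, nbits = 64, 10000, 4, 8
xt = torch.rand(n, d, dtype=torch.float32)

pq = quantization.ProductQuantizer(d, M, nbits)
pq.train(xt)                 # k-means per sub-vector -> (M, 256, d // M) codebook
codes = pq.encode(xt)        # (n, code_size) uint8 codes
x_rec = pq.decode(codes)     # (n, d) reconstruction from codebook entries
print(codes.shape, ((xt - x_rec) ** 2).sum())

With nbits fixed to 8 by the assert in __init__, code_size equals M, so encode returns one uint8 index per sub-vector and decode simply looks those indices up in the per-sub-vector codebooks.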

tests/torch_test_contrib.py (+29, -4)
@@ -4,13 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch  # usort: skip
-import unittest  # usort: skip
-import numpy as np  # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
 
-import faiss  # usort: skip
+import faiss  # usort: skip
 import faiss.contrib.torch_utils  # usort: skip
 from faiss.contrib import datasets
-from faiss.contrib.torch import clustering
+from faiss.contrib.torch import clustering, quantization
+
 
 
 
@@ -400,3 +401,27 @@ def test_python_kmeans(self):
         # 33498.332 33380.477
         # print(err, err2) 1/0
         self.assertLess(err2, err * 1.1)
+
+
+class TestQuantization(unittest.TestCase):
+    def test_python_product_quantization(self):
+        """ Test the python implementation of product quantization """
+        d = 64
+        n = 10000
+        cs = 4
+        nbits = 8
+        M = 4
+        x = np.random.random(size=(n, d)).astype('float32')
+        pq = faiss.ProductQuantizer(d, cs, nbits)
+        pq.train(x)
+        codes = pq.compute_codes(x)
+        x2 = pq.decode(codes)
+        diff = ((x - x2)**2).sum()
+        # vs pure pytorch impl
+        xt = torch.from_numpy(x)
+        my_pq = quantization.ProductQuantizer(d, M, nbits)
+        my_pq.train(xt)
+        my_codes = my_pq.encode(xt)
+        xt2 = my_pq.decode(my_codes)
+        my_diff = ((xt - xt2)**2).sum()
+        self.assertLess(abs(diff - my_diff), 100)
