This contrib module contains Pytorch code for quantization.
"""

- import numpy as np
import torch
import faiss
-
- from faiss.contrib import torch_utils
+ import math
+ from faiss.contrib.torch import clustering
+ # the kmeans can produce both torch and numpy centroids


class Quantizer:

    def __init__(self, d, code_size):
+         """
+         d: dimension of vectors
+         code_size: number of bytes of the code (per vector)
+         """
        self.d = d
        self.code_size = code_size

    def train(self, x):
+         """
+         takes an n-by-d array and performs training
+         """
        pass

    def encode(self, x):
+         """
+         takes an n-by-d float array, encodes to an n-by-code_size uint8 array
+         """
        pass

-     def decode(self, x):
+     def decode(self, codes):
+         """
+         takes an n-by-code_size uint8 array, returns an n-by-d array
+         """
        pass


class VectorQuantizer(Quantizer):

    def __init__(self, d, k):
-         code_size = int(torch.ceil(torch.log2(k) / 8))
+
+         code_size = int(math.ceil(math.log2(k) / 8))
        Quantizer.__init__(self, d, code_size)
        self.k = k

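A quick worked check of the code_size computation just above; k = 65536 is an assumed example value, not one taken from the patch:

import math
k = 65536                                # assumed number of centroids
print(int(math.ceil(math.log2(k) / 8)))  # 16-bit centroid ids -> prints 2 (bytes per code)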
@@ -42,12 +56,41 @@ def train(self, x):


class ProductQuantizer(Quantizer):
-
    def __init__(self, d, M, nbits):
-         code_size = int(torch.ceil(M * nbits / 8))
-         Quantizer.__init__(d, code_size)
+         """ M: number of subvectors, d % M == 0
+         nbits: number of bits that each subvector is encoded into
+         """
+         assert d % M == 0
+         assert nbits == 8  # TODO: implement other nbits values
+         code_size = int(math.ceil(M * nbits / 8))
+         Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits
+         self.code_size = code_size

    def train(self, x):
-         pass
+         nc = 2 ** self.nbits  # number of centroids per sub-quantizer
+         sd = self.d // self.M  # dimension of each subvector
+         dev = x.device
+         dtype = x.dtype
+         self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+         for m in range(self.M):
+             xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+             data = clustering.DatasetAssign(xsub.contiguous())
+             self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)  # k-means centroids for subvector m
+
+     def encode(self, x):
+         codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
+         for m in range(self.M):
+             xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+             _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)  # nearest centroid id per subvector
+             codes[:, m] = I.ravel()
+         return codes
+
+     def decode(self, codes):
+         idxs = [codes[:, m].long() for m in range(self.M)]
+         vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+         stacked_vectors = torch.stack(vectors, dim=1)
+         cbd = self.codebook.shape[-1]
+         x_rec = stacked_vectors.reshape(-1, cbd * self.M)  # concatenate subvector reconstructions
+         return x_rec
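A minimal usage sketch of the ProductQuantizer added above; the import path faiss.contrib.torch.quantization and the sizes d=64, M=8, nbits=8 are assumptions for illustration, not part of the patch:

import torch
from faiss.contrib.torch.quantization import ProductQuantizer  # assumed module path for this file

d, M, nbits = 64, 8, 8              # 8 sub-quantizers of 8 bits each
xt = torch.rand(10000, d)           # training vectors
xb = torch.rand(1000, d)            # vectors to encode

pq = ProductQuantizer(d, M, nbits)
pq.train(xt)                        # per-subvector k-means fills pq.codebook
codes = pq.encode(xb)               # (1000, 8) uint8 codes
x_rec = pq.decode(codes)            # (1000, 64) approximate reconstruction
print(codes.shape, float(((xb - x_rec) ** 2).mean()))

With nbits == 8 the code_size formula gives M * 8 / 8 = M bytes, which is why the codes come out as an n-by-M uint8 array.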