Skip to content

Commit 2e6551f

Browse files
mdouzefacebook-github-bot
authored andcommitted
Support search_preassigned in torch (facebookresearch#3916)
Summary: Pull Request resolved: facebookresearch#3916 Adding missing wrapper to the torch wrappers in Faiss + test it. Also factorized a bit of code between search functions. Reviewed By: algoriddle Differential Revision: D63974821 fbshipit-source-id: a0415a57a763e2d1896956c503e503615c167860
1 parent be4fc8e commit 2e6551f

File tree

2 files changed

+77
-21
lines changed

2 files changed

+77
-21
lines changed

contrib/torch_utils.py

+51-21
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,8 @@ def torch_replacement_train(self, x):
217217
# CPU torch
218218
self.train_c(n, x_ptr)
219219

220-
def torch_replacement_search(self, x, k, D=None, I=None):
221-
if type(x) is np.ndarray:
222-
# forward to faiss __init__.py base method
223-
return self.search_numpy(x, k, D=D, I=I)
224-
225-
assert type(x) is torch.Tensor
220+
def search_methods_common(x, k, D, I):
226221
n, d = x.shape
227-
assert d == self.d
228222
x_ptr = swig_ptr_from_FloatTensor(x)
229223

230224
if D is None:
@@ -241,6 +235,19 @@ def torch_replacement_search(self, x, k, D=None, I=None):
241235
assert I.shape == (n, k)
242236
I_ptr = swig_ptr_from_IndicesTensor(I)
243237

238+
return x_ptr, D_ptr, I_ptr, D, I
239+
240+
def torch_replacement_search(self, x, k, D=None, I=None):
241+
if type(x) is np.ndarray:
242+
# forward to faiss __init__.py base method
243+
return self.search_numpy(x, k, D=D, I=I)
244+
245+
assert type(x) is torch.Tensor
246+
n, d = x.shape
247+
assert d == self.d
248+
249+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
250+
244251
if x.is_cuda:
245252
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
246253

@@ -261,21 +268,8 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None)
261268
assert type(x) is torch.Tensor
262269
n, d = x.shape
263270
assert d == self.d
264-
x_ptr = swig_ptr_from_FloatTensor(x)
265271

266-
if D is None:
267-
D = torch.empty(n, k, device=x.device, dtype=torch.float32)
268-
else:
269-
assert type(D) is torch.Tensor
270-
assert D.shape == (n, k)
271-
D_ptr = swig_ptr_from_FloatTensor(D)
272-
273-
if I is None:
274-
I = torch.empty(n, k, device=x.device, dtype=torch.int64)
275-
else:
276-
assert type(I) is torch.Tensor
277-
assert I.shape == (n, k)
278-
I_ptr = swig_ptr_from_IndicesTensor(I)
272+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
279273

280274
if R is None:
281275
R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
@@ -296,6 +290,40 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None)
296290

297291
return D, I, R
298292

293+
def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
294+
if type(x) is np.ndarray:
295+
# forward to faiss __init__.py base method
296+
return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)
297+
298+
assert type(x) is torch.Tensor
299+
n, d = x.shape
300+
assert d == self.d
301+
302+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
303+
304+
assert Iq.shape == (n, self.nprobe)
305+
Iq = Iq.contiguous()
306+
Iq_ptr = swig_ptr_from_IndicesTensor(Iq)
307+
308+
if Dq is not None:
309+
Dq = Dq.contiguous()
310+
assert Dq.shape == Iq.shape
311+
Dq_ptr = swig_ptr_from_FloatTensor(Dq)
312+
else:
313+
Dq_ptr = None
314+
315+
if x.is_cuda:
316+
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
317+
318+
# On the GPU, use proper stream ordering
319+
with using_stream(self.getResources()):
320+
self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
321+
else:
322+
# CPU torch
323+
self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
324+
325+
return D, I
326+
299327
def torch_replacement_remove_ids(self, x):
300328
# Not yet implemented
301329
assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
@@ -495,6 +523,8 @@ def torch_replacement_sa_decode(self, codes, x=None):
495523
ignore_missing=True)
496524
torch_replace_method(the_class, 'search_and_reconstruct',
497525
torch_replacement_search_and_reconstruct, ignore_missing=True)
526+
torch_replace_method(the_class, 'search_preassigned',
527+
torch_replacement_search_preassigned, ignore_missing=True)
498528
torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
499529
torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)
500530

tests/torch_test_contrib.py

+26
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,32 @@ def test_search_and_reconstruct(self):
291291
self.assertTrue(torch.equal(I, I_input))
292292
self.assertTrue(torch.equal(R, R_input))
293293

294+
def test_search_preassigned(self):
295+
ds = datasets.SyntheticDataset(32, 1000, 100, 10)
296+
index = faiss.index_factory(32, "IVF20,PQ4np")
297+
index.train(ds.get_train())
298+
index.add(ds.get_database())
299+
index.nprobe = 4
300+
Dref, Iref = index.search(ds.get_queries(), 10)
301+
quantizer = faiss.clone_index(index.quantizer)
302+
303+
# mutilate the index' quantizer
304+
index.quantizer.reset()
305+
index.quantizer.add(np.zeros((20, 32), dtype='float32'))
306+
307+
# test numpy codepath
308+
Dq, Iq = quantizer.search(ds.get_queries(), 4)
309+
Dref2, Iref2 = index.search_preassigned(ds.get_queries(), 10, Iq, Dq)
310+
np.testing.assert_array_equal(Iref, Iref2)
311+
np.testing.assert_array_equal(Dref, Dref2)
312+
313+
# test torch codepath
314+
xq = torch.from_numpy(ds.get_queries())
315+
Dq, Iq = quantizer.search(xq, 4)
316+
Dref2, Iref2 = index.search_preassigned(xq, 10, Iq, Dq)
317+
np.testing.assert_array_equal(Iref, Iref2.numpy())
318+
np.testing.assert_array_equal(Dref, Dref2.numpy())
319+
294320
# tests sa_encode, sa_decode
295321
def test_sa_encode_decode(self):
296322
d = 16

0 commit comments

Comments
 (0)