@@ -217,14 +217,8 @@ def torch_replacement_train(self, x):
217
217
# CPU torch
218
218
self .train_c (n , x_ptr )
219
219
220
- def torch_replacement_search (self , x , k , D = None , I = None ):
221
- if type (x ) is np .ndarray :
222
- # forward to faiss __init__.py base method
223
- return self .search_numpy (x , k , D = D , I = I )
224
-
225
- assert type (x ) is torch .Tensor
220
+ def search_methods_common (x , k , D , I ):
226
221
n , d = x .shape
227
- assert d == self .d
228
222
x_ptr = swig_ptr_from_FloatTensor (x )
229
223
230
224
if D is None :
@@ -241,6 +235,19 @@ def torch_replacement_search(self, x, k, D=None, I=None):
241
235
assert I .shape == (n , k )
242
236
I_ptr = swig_ptr_from_IndicesTensor (I )
243
237
238
+ return x_ptr , D_ptr , I_ptr , D , I
239
+
240
+ def torch_replacement_search (self , x , k , D = None , I = None ):
241
+ if type (x ) is np .ndarray :
242
+ # forward to faiss __init__.py base method
243
+ return self .search_numpy (x , k , D = D , I = I )
244
+
245
+ assert type (x ) is torch .Tensor
246
+ n , d = x .shape
247
+ assert d == self .d
248
+
249
+ x_ptr , D_ptr , I_ptr , D , I = search_methods_common (x , k , D , I )
250
+
244
251
if x .is_cuda :
245
252
assert hasattr (self , 'getDevice' ), 'GPU tensor on CPU index not allowed'
246
253
@@ -261,21 +268,8 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None)
261
268
assert type (x ) is torch .Tensor
262
269
n , d = x .shape
263
270
assert d == self .d
264
- x_ptr = swig_ptr_from_FloatTensor (x )
265
271
266
- if D is None :
267
- D = torch .empty (n , k , device = x .device , dtype = torch .float32 )
268
- else :
269
- assert type (D ) is torch .Tensor
270
- assert D .shape == (n , k )
271
- D_ptr = swig_ptr_from_FloatTensor (D )
272
-
273
- if I is None :
274
- I = torch .empty (n , k , device = x .device , dtype = torch .int64 )
275
- else :
276
- assert type (I ) is torch .Tensor
277
- assert I .shape == (n , k )
278
- I_ptr = swig_ptr_from_IndicesTensor (I )
272
+ x_ptr , D_ptr , I_ptr , D , I = search_methods_common (x , k , D , I )
279
273
280
274
if R is None :
281
275
R = torch .empty (n , k , d , device = x .device , dtype = torch .float32 )
@@ -296,6 +290,40 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None)
296
290
297
291
return D , I , R
298
292
293
+ def torch_replacement_search_preassigned (self , x , k , Iq , Dq , * , D = None , I = None ):
294
+ if type (x ) is np .ndarray :
295
+ # forward to faiss __init__.py base method
296
+ return self .search_preassigned_numpy (x , k , Iq , Dq , D = D , I = I )
297
+
298
+ assert type (x ) is torch .Tensor
299
+ n , d = x .shape
300
+ assert d == self .d
301
+
302
+ x_ptr , D_ptr , I_ptr , D , I = search_methods_common (x , k , D , I )
303
+
304
+ assert Iq .shape == (n , self .nprobe )
305
+ Iq = Iq .contiguous ()
306
+ Iq_ptr = swig_ptr_from_IndicesTensor (Iq )
307
+
308
+ if Dq is not None :
309
+ Dq = Dq .contiguous ()
310
+ assert Dq .shape == Iq .shape
311
+ Dq_ptr = swig_ptr_from_FloatTensor (Dq )
312
+ else :
313
+ Dq_ptr = None
314
+
315
+ if x .is_cuda :
316
+ assert hasattr (self , 'getDevice' ), 'GPU tensor on CPU index not allowed'
317
+
318
+ # On the GPU, use proper stream ordering
319
+ with using_stream (self .getResources ()):
320
+ self .search_preassigned_c (n , x_ptr , k , Iq_ptr , Dq_ptr , D_ptr , I_ptr , False )
321
+ else :
322
+ # CPU torch
323
+ self .search_preassigned_c (n , x_ptr , k , Iq_ptr , Dq_ptr , D_ptr , I_ptr , False )
324
+
325
+ return D , I
326
+
299
327
def torch_replacement_remove_ids (self , x ):
300
328
# Not yet implemented
301
329
assert type (x ) is not torch .Tensor , 'remove_ids not yet implemented for torch'
@@ -495,6 +523,8 @@ def torch_replacement_sa_decode(self, codes, x=None):
495
523
ignore_missing = True )
496
524
torch_replace_method (the_class , 'search_and_reconstruct' ,
497
525
torch_replacement_search_and_reconstruct , ignore_missing = True )
526
+ torch_replace_method (the_class , 'search_preassigned' ,
527
+ torch_replacement_search_preassigned , ignore_missing = True )
498
528
torch_replace_method (the_class , 'sa_encode' , torch_replacement_sa_encode )
499
529
torch_replace_method (the_class , 'sa_decode' , torch_replacement_sa_decode )
500
530
0 commit comments