@@ -260,11 +260,13 @@ def __init__(
260
260
"""
261
261
)
262
262
263
- self ._query_targets_embeddings_by_kind_and_tag = session .prepare (
263
+ self ._query_targets_embeddings_by_kind_and_tag_and_embedding = session .prepare (
264
264
f"""
265
265
SELECT target_content_id, target_text_embedding, tag
266
266
FROM { keyspace } .{ targets_table }
267
267
WHERE kind = ? AND tag = ?
268
+ ORDER BY target_text_embedding ANN of ?
269
+ LIMIT ?
268
270
"""
269
271
)
270
272
@@ -317,6 +319,14 @@ def _apply_schema(self):
317
319
"""
318
320
)
319
321
322
+ # Index on target_text_embedding (for similarity search)
323
+ self ._session .execute (
324
+ f"""CREATE CUSTOM INDEX IF NOT EXISTS { self ._targets_table } _target_text_embedding_index
325
+ ON { self ._keyspace } .{ self ._targets_table } (target_text_embedding)
326
+ USING 'StorageAttachedIndex';
327
+ """
328
+ )
329
+
320
330
def _concurrent_queries (self ) -> ConcurrentQueries :
321
331
return ConcurrentQueries (self ._session )
322
332
@@ -393,20 +403,14 @@ def add_nodes(rows):
393
403
394
404
return [results [id ] for id in ids ]
395
405
396
- def _linked_ids (
397
- self ,
398
- source_id : str ,
399
- ) -> Iterable [str ]:
400
- adjacent = self ._get_adjacent ([source_id ])
401
- return [edge .target_content_id for edge in adjacent ]
402
-
403
406
def mmr_traversal_search (
404
407
self ,
405
408
query : str ,
406
409
* ,
407
410
k : int = 4 ,
408
411
depth : int = 2 ,
409
412
fetch_k : int = 100 ,
413
+ adjacent_k : int = 10 ,
410
414
lambda_mult : float = 0.5 ,
411
415
score_threshold : float = float ("-inf" ),
412
416
) -> Iterable [Node ]:
@@ -423,7 +427,9 @@ def mmr_traversal_search(
423
427
Args:
424
428
query: The query string to search for.
425
429
k: Number of Documents to return. Defaults to 4.
426
- fetch_k: Number of Documents to fetch via similarity.
430
+ fetch_k: Number of initial Documents to fetch via similarity.
431
+ Defaults to 100.
432
+ adjacent_k: Number of adjacent Documents to fetch.
427
433
Defaults to 10.
428
434
depth: Maximum depth of a node (number of edges) from a node
429
435
retrieved via similarity. Defaults to 2.
@@ -446,9 +452,9 @@ def mmr_traversal_search(
446
452
(query_embedding , fetch_k ),
447
453
)
448
454
449
- query_embedding = emb_to_ndarray (query_embedding )
455
+ query_embedding_ndarray = emb_to_ndarray (query_embedding )
450
456
unselected = {
451
- row .content_id : _Candidate (row .text_embedding , lambda_mult , query_embedding )
457
+ row .content_id : _Candidate (row .text_embedding , lambda_mult , query_embedding_ndarray )
452
458
for row in fetched
453
459
}
454
460
best_score , next_id = max (
@@ -479,7 +485,9 @@ def mmr_traversal_search(
479
485
# Add unselected edges if reached nodes are within `depth`:
480
486
next_depth = next_selected .distance + 1
481
487
if next_depth < depth :
482
- adjacents = self ._get_adjacent ([selected_id ])
488
+ adjacents = self ._get_adjacent ([selected_id ],
489
+ query_embedding = query_embedding ,
490
+ k_per_tag = adjacent_k )
483
491
for adjacent in adjacents :
484
492
target_id = adjacent .target_content_id
485
493
if target_id in selected_set :
@@ -494,7 +502,7 @@ def mmr_traversal_search(
494
502
continue
495
503
496
504
candidate = _Candidate (
497
- adjacent .target_text_embedding , lambda_mult , query_embedding
505
+ adjacent .target_text_embedding , lambda_mult , query_embedding_ndarray
498
506
)
499
507
for selected_embedding in selected_embeddings :
500
508
candidate .update_for_selection (lambda_mult , selected_embedding )
@@ -621,8 +629,19 @@ def similarity_search(
621
629
def _get_adjacent (
622
630
self ,
623
631
source_ids : Iterable [str ],
632
+ query_embedding : List [float ],
633
+ k_per_tag : Optional [int ] = None ,
624
634
) -> Iterable [_Edge ]:
625
- """Return the target nodes adjacent to any of the source nodes."""
635
+ """Return the target nodes adjacent to any of the source nodes.
636
+
637
+ Args:
638
+ source_ids: The source IDs to start from when retrieving adjacent nodes.
639
+ query_embedding: The query embedding. Used to rank target nodes.
640
+ k_per_tag: The number of target nodes to fetch for each outgoing tag.
641
+
642
+ Returns:
643
+ List of adjacent edges.
644
+ """
626
645
627
646
link_to_tags = set ()
628
647
targets = dict ()
@@ -632,9 +651,10 @@ def add_sources(rows):
632
651
for new_tag in row .link_to_tags or []:
633
652
if new_tag not in link_to_tags :
634
653
link_to_tags .add (new_tag )
654
+
635
655
cq .execute (
636
- self ._query_targets_embeddings_by_kind_and_tag ,
637
- new_tag ,
656
+ self ._query_targets_embeddings_by_kind_and_tag_and_embedding ,
657
+ parameters = ( new_tag [ 0 ], new_tag [ 1 ], query_embedding , k_per_tag or 10 ) ,
638
658
callback = add_targets ,
639
659
)
640
660
link_to_tags .add (new_tag )
@@ -653,6 +673,7 @@ def add_targets(rows):
653
673
self ._query_source_tags_by_id , (source_id ,), callback = add_sources
654
674
)
655
675
676
+ # TODO: Consider a combined limit based on the similarity and/or predicated MMR score?
656
677
return [
657
678
_Edge (target_content_id = content_id , target_text_embedding = embedding )
658
679
for (content_id , embedding ) in targets .items ()
0 commit comments