Skip to content

Commit 610dfbf

Browse files
authored
feat: Limit items retrieved during MMR traversal (#514)
* feat: Limit items retrieved during MMR traversal * remove unused method / unneeded test * remove unused query
1 parent a5e1c0d commit 610dfbf

File tree

2 files changed

+37
-64
lines changed

2 files changed

+37
-64
lines changed

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

+37-16
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,13 @@ def __init__(
260260
"""
261261
)
262262

263-
self._query_targets_embeddings_by_kind_and_tag = session.prepare(
263+
self._query_targets_embeddings_by_kind_and_tag_and_embedding = session.prepare(
264264
f"""
265265
SELECT target_content_id, target_text_embedding, tag
266266
FROM {keyspace}.{targets_table}
267267
WHERE kind = ? AND tag = ?
268+
ORDER BY target_text_embedding ANN of ?
269+
LIMIT ?
268270
"""
269271
)
270272

@@ -317,6 +319,14 @@ def _apply_schema(self):
317319
"""
318320
)
319321

322+
# Index on target_text_embedding (for similarity search)
323+
self._session.execute(
324+
f"""CREATE CUSTOM INDEX IF NOT EXISTS {self._targets_table}_target_text_embedding_index
325+
ON {self._keyspace}.{self._targets_table}(target_text_embedding)
326+
USING 'StorageAttachedIndex';
327+
"""
328+
)
329+
320330
def _concurrent_queries(self) -> ConcurrentQueries:
321331
return ConcurrentQueries(self._session)
322332

@@ -393,20 +403,14 @@ def add_nodes(rows):
393403

394404
return [results[id] for id in ids]
395405

396-
def _linked_ids(
397-
self,
398-
source_id: str,
399-
) -> Iterable[str]:
400-
adjacent = self._get_adjacent([source_id])
401-
return [edge.target_content_id for edge in adjacent]
402-
403406
def mmr_traversal_search(
404407
self,
405408
query: str,
406409
*,
407410
k: int = 4,
408411
depth: int = 2,
409412
fetch_k: int = 100,
413+
adjacent_k: int = 10,
410414
lambda_mult: float = 0.5,
411415
score_threshold: float = float("-inf"),
412416
) -> Iterable[Node]:
@@ -423,7 +427,9 @@ def mmr_traversal_search(
423427
Args:
424428
query: The query string to search for.
425429
k: Number of Documents to return. Defaults to 4.
426-
fetch_k: Number of Documents to fetch via similarity.
430+
fetch_k: Number of initial Documents to fetch via similarity.
431+
Defaults to 100.
432+
adjacent_k: Number of adjacent Documents to fetch for each outgoing tag.
427433
Defaults to 10.
428434
depth: Maximum depth of a node (number of edges) from a node
429435
retrieved via similarity. Defaults to 2.
@@ -446,9 +452,9 @@ def mmr_traversal_search(
446452
(query_embedding, fetch_k),
447453
)
448454

449-
query_embedding = emb_to_ndarray(query_embedding)
455+
query_embedding_ndarray = emb_to_ndarray(query_embedding)
450456
unselected = {
451-
row.content_id: _Candidate(row.text_embedding, lambda_mult, query_embedding)
457+
row.content_id: _Candidate(row.text_embedding, lambda_mult, query_embedding_ndarray)
452458
for row in fetched
453459
}
454460
best_score, next_id = max(
@@ -479,7 +485,9 @@ def mmr_traversal_search(
479485
# Add unselected edges if reached nodes are within `depth`:
480486
next_depth = next_selected.distance + 1
481487
if next_depth < depth:
482-
adjacents = self._get_adjacent([selected_id])
488+
adjacents = self._get_adjacent([selected_id],
489+
query_embedding=query_embedding,
490+
k_per_tag=adjacent_k)
483491
for adjacent in adjacents:
484492
target_id = adjacent.target_content_id
485493
if target_id in selected_set:
@@ -494,7 +502,7 @@ def mmr_traversal_search(
494502
continue
495503

496504
candidate = _Candidate(
497-
adjacent.target_text_embedding, lambda_mult, query_embedding
505+
adjacent.target_text_embedding, lambda_mult, query_embedding_ndarray
498506
)
499507
for selected_embedding in selected_embeddings:
500508
candidate.update_for_selection(lambda_mult, selected_embedding)
@@ -621,8 +629,19 @@ def similarity_search(
621629
def _get_adjacent(
622630
self,
623631
source_ids: Iterable[str],
632+
query_embedding: List[float],
633+
k_per_tag: Optional[int] = None,
624634
) -> Iterable[_Edge]:
625-
"""Return the target nodes adjacent to any of the source nodes."""
635+
"""Return the target nodes adjacent to any of the source nodes.
636+
637+
Args:
638+
source_ids: The source IDs to start from when retrieving adjacent nodes.
639+
query_embedding: The query embedding. Used to rank target nodes.
640+
k_per_tag: The number of target nodes to fetch for each outgoing tag. Falls back to 10 when None (or 0).
641+
642+
Returns:
643+
List of adjacent edges.
644+
"""
626645

627646
link_to_tags = set()
628647
targets = dict()
@@ -632,9 +651,10 @@ def add_sources(rows):
632651
for new_tag in row.link_to_tags or []:
633652
if new_tag not in link_to_tags:
634653
link_to_tags.add(new_tag)
654+
635655
cq.execute(
636-
self._query_targets_embeddings_by_kind_and_tag,
637-
new_tag,
656+
self._query_targets_embeddings_by_kind_and_tag_and_embedding,
657+
parameters = (new_tag[0], new_tag[1], query_embedding, k_per_tag or 10),
638658
callback=add_targets,
639659
)
640660
link_to_tags.add(new_tag)
@@ -653,6 +673,7 @@ def add_targets(rows):
653673
self._query_source_tags_by_id, (source_id,), callback=add_sources
654674
)
655675

676+
# TODO: Consider a combined limit based on the similarity and/or predicted MMR score?
656677
return [
657678
_Edge(target_content_id=content_id, target_text_embedding=embedding)
658679
for (content_id, embedding) in targets.items()

libs/langchain/tests/integration_tests/test_graph_store.py

-48
Original file line numberDiff line numberDiff line change
@@ -112,54 +112,6 @@ def _result_ids(docs: Iterable[Document]) -> List[str]:
112112
return list(map(lambda d: d.metadata[METADATA_CONTENT_ID_KEY], docs))
113113

114114

115-
def test_link_directed(cassandra: GraphStoreFactory) -> None:
116-
a = Document(
117-
page_content="A",
118-
metadata={
119-
METADATA_CONTENT_ID_KEY: "a",
120-
METADATA_LINKS_KEY: {
121-
Link.incoming(kind="hyperlink", tag="http://a"),
122-
},
123-
},
124-
)
125-
b = Document(
126-
page_content="B",
127-
metadata={
128-
METADATA_CONTENT_ID_KEY: "b",
129-
METADATA_LINKS_KEY: {
130-
Link.incoming(kind="hyperlink", tag="http://b"),
131-
Link.outgoing(kind="hyperlink", tag="http://a"),
132-
},
133-
},
134-
)
135-
c = Document(
136-
page_content="C",
137-
metadata={
138-
METADATA_CONTENT_ID_KEY: "c",
139-
METADATA_LINKS_KEY: {
140-
Link.outgoing(kind="hyperlink", tag="http://a"),
141-
},
142-
},
143-
)
144-
d = Document(
145-
page_content="D",
146-
metadata={
147-
METADATA_CONTENT_ID_KEY: "d",
148-
METADATA_LINKS_KEY: {
149-
Link.outgoing(kind="hyperlink", tag="http://a"),
150-
Link.outgoing(kind="hyperlink", tag="http://b"),
151-
},
152-
},
153-
)
154-
155-
store = cassandra.store([a, b, c, d])
156-
157-
assert list(store.store._linked_ids("a")) == []
158-
assert list(store.store._linked_ids("b")) == ["a"]
159-
assert list(store.store._linked_ids("c")) == ["a"]
160-
assert sorted(store.store._linked_ids("d")) == ["a", "b"]
161-
162-
163115
@pytest.mark.parametrize("gs_factory", ["cassandra", "astra_db"])
164116
def test_mmr_traversal(request, gs_factory: str):
165117
"""

0 commit comments

Comments
 (0)