Skip to content

Commit 7d4f13d

Browse files
committed
Use more generic types in API
1 parent 3427b98 commit 7d4f13d

File tree

2 files changed

+43
-23
lines changed

2 files changed

+43
-23
lines changed

libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _concurrent_queries(self) -> ConcurrentQueries:
325325
def add_nodes(
326326
self,
327327
nodes: Iterable[Node] = None,
328-
):
328+
) -> Iterable[str]:
329329
texts = []
330330
metadatas = []
331331
for node in nodes:
@@ -483,7 +483,7 @@ def add_edges_for_targets(
483483
def _query_by_ids(
484484
self,
485485
ids: Iterable[str],
486-
) -> Iterable[TextNode]:
486+
) -> List[TextNode]:
487487
results = []
488488
with self._concurrent_queries() as cq:
489489

libs/knowledge-store/ragstack_knowledge_store/langchain/base.py

+41-21
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def add_nodes(
102102
self,
103103
nodes: Iterable[Node],
104104
**kwargs: Any,
105-
) -> List[str]:
105+
) -> Iterable[str]:
106106
"""Add nodes to the knowledge store
107107
108108
Args:
@@ -113,13 +113,19 @@ async def aadd_nodes(
113113
self,
114114
nodes: Iterable[Node],
115115
**kwargs: Any,
116-
) -> List[str]:
116+
) -> AsyncIterable[str]:
117117
"""Add nodes to the knowledge store
118118
119119
Args:
120120
nodes: the nodes to add.
121121
"""
122-
return await run_in_executor(None, self.add_nodes, nodes, **kwargs)
122+
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
123+
done = object()
124+
while True:
125+
doc = await run_in_executor(None, next, iterator, done)
126+
if doc is done:
127+
break
128+
yield doc
123129

124130
def add_texts(
125131
self,
@@ -130,7 +136,7 @@ def add_texts(
130136
**kwargs: Any,
131137
) -> List[str]:
132138
nodes = _texts_to_nodes(texts, metadatas, ids)
133-
return self.add_nodes(nodes, **kwargs)
139+
return list(self.add_nodes(nodes, **kwargs))
134140

135141
async def aadd_texts(
136142
self,
@@ -141,7 +147,7 @@ async def aadd_texts(
141147
**kwargs: Any,
142148
) -> List[str]:
143149
nodes = _texts_to_nodes(texts, metadatas, ids)
144-
return await self.aadd_nodes(nodes, **kwargs)
150+
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
145151

146152
def add_documents(
147153
self,
@@ -151,7 +157,7 @@ def add_documents(
151157
**kwargs: Any,
152158
) -> List[str]:
153159
nodes = _documents_to_nodes(documents, ids)
154-
return self.add_nodes(nodes, **kwargs)
160+
return list(self.add_nodes(nodes, **kwargs))
155161

156162
async def aadd_documents(
157163
self,
@@ -161,7 +167,7 @@ async def aadd_documents(
161167
**kwargs: Any,
162168
) -> List[str]:
163169
nodes = _documents_to_nodes(documents, ids)
164-
return await self.aadd_nodes(nodes, **kwargs)
170+
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
165171

166172
@abstractmethod
167173
def traversal_search(
@@ -209,9 +215,16 @@ async def atraversal_search(
209215
Returns:
210216
Retrieved documents.
211217
"""
212-
for doc in await run_in_executor(
213-
None, self.traversal_search, query, k=k, depth=depth, **kwargs
214-
):
218+
iterator = iter(
219+
await run_in_executor(
220+
None, self.traversal_search, query, k=k, depth=depth, **kwargs
221+
)
222+
)
223+
done = object()
224+
while True:
225+
doc = await run_in_executor(None, next, iterator, done)
226+
if doc is done:
227+
break
215228
yield doc
216229

217230
@abstractmethod
@@ -284,17 +297,24 @@ async def ammr_traversal_search(
284297
score_threshold: Only documents with a score greater than or equal
285298
this threshold will be chosen. Defaults to negative infinity.
286299
"""
287-
for doc in await run_in_executor(
288-
None,
289-
self.traversal_search,
290-
query,
291-
k=k,
292-
fetch_k=fetch_k,
293-
depth=depth,
294-
lambda_mult=lambda_mult,
295-
score_threshold=score_threshold,
296-
**kwargs,
297-
):
300+
iterator = iter(
301+
await run_in_executor(
302+
None,
303+
self.mmr_traversal_search,
304+
query,
305+
k=k,
306+
fetch_k=fetch_k,
307+
depth=depth,
308+
lambda_mult=lambda_mult,
309+
score_threshold=score_threshold,
310+
**kwargs,
311+
)
312+
)
313+
done = object()
314+
while True:
315+
doc = await run_in_executor(None, next, iterator, done)
316+
if doc is done:
317+
break
298318
yield doc
299319

300320
def similarity_search(

0 commit comments

Comments
 (0)