@@ -102,7 +102,7 @@ def add_nodes(
102
102
self ,
103
103
nodes : Iterable [Node ],
104
104
** kwargs : Any ,
105
- ) -> List [str ]:
105
+ ) -> Iterable [str ]:
106
106
"""Add nodes to the knowledge store
107
107
108
108
Args:
@@ -113,13 +113,19 @@ async def aadd_nodes(
113
113
self ,
114
114
nodes : Iterable [Node ],
115
115
** kwargs : Any ,
116
- ) -> List [str ]:
116
+ ) -> AsyncIterable [str ]:
117
117
"""Add nodes to the knowledge store
118
118
119
119
Args:
120
120
nodes: the nodes to add.
121
121
"""
122
- return await run_in_executor (None , self .add_nodes , nodes , ** kwargs )
122
+ iterator = iter (await run_in_executor (None , self .add_nodes , nodes , ** kwargs ))
123
+ done = object ()
124
+ while True :
125
+ doc = await run_in_executor (None , next , iterator , done )
126
+ if doc is done :
127
+ break
128
+ yield doc
123
129
124
130
def add_texts (
125
131
self ,
@@ -130,7 +136,7 @@ def add_texts(
130
136
** kwargs : Any ,
131
137
) -> List [str ]:
132
138
nodes = _texts_to_nodes (texts , metadatas , ids )
133
- return self .add_nodes (nodes , ** kwargs )
139
+ return list ( self .add_nodes (nodes , ** kwargs ) )
134
140
135
141
async def aadd_texts (
136
142
self ,
@@ -141,7 +147,7 @@ async def aadd_texts(
141
147
** kwargs : Any ,
142
148
) -> List [str ]:
143
149
nodes = _texts_to_nodes (texts , metadatas , ids )
144
- return await self .aadd_nodes (nodes , ** kwargs )
150
+ return [ _id async for _id in self .aadd_nodes (nodes , ** kwargs )]
145
151
146
152
def add_documents (
147
153
self ,
@@ -151,7 +157,7 @@ def add_documents(
151
157
** kwargs : Any ,
152
158
) -> List [str ]:
153
159
nodes = _documents_to_nodes (documents , ids )
154
- return self .add_nodes (nodes , ** kwargs )
160
+ return list ( self .add_nodes (nodes , ** kwargs ) )
155
161
156
162
async def aadd_documents (
157
163
self ,
@@ -161,7 +167,7 @@ async def aadd_documents(
161
167
** kwargs : Any ,
162
168
) -> List [str ]:
163
169
nodes = _documents_to_nodes (documents , ids )
164
- return await self .aadd_nodes (nodes , ** kwargs )
170
+ return [ _id async for _id in self .aadd_nodes (nodes , ** kwargs )]
165
171
166
172
@abstractmethod
167
173
def traversal_search (
@@ -209,9 +215,16 @@ async def atraversal_search(
209
215
Returns:
210
216
Retrieved documents.
211
217
"""
212
- for doc in await run_in_executor (
213
- None , self .traversal_search , query , k = k , depth = depth , ** kwargs
214
- ):
218
+ iterator = iter (
219
+ await run_in_executor (
220
+ None , self .traversal_search , query , k = k , depth = depth , ** kwargs
221
+ )
222
+ )
223
+ done = object ()
224
+ while True :
225
+ doc = await run_in_executor (None , next , iterator , done )
226
+ if doc is done :
227
+ break
215
228
yield doc
216
229
217
230
@abstractmethod
@@ -284,17 +297,24 @@ async def ammr_traversal_search(
284
297
score_threshold: Only documents with a score greater than or equal
285
298
this threshold will be chosen. Defaults to negative infinity.
286
299
"""
287
- for doc in await run_in_executor (
288
- None ,
289
- self .traversal_search ,
290
- query ,
291
- k = k ,
292
- fetch_k = fetch_k ,
293
- depth = depth ,
294
- lambda_mult = lambda_mult ,
295
- score_threshold = score_threshold ,
296
- ** kwargs ,
297
- ):
300
+ iterator = iter (
301
+ await run_in_executor (
302
+ None ,
303
+ self .mmr_traversal_search ,
304
+ query ,
305
+ k = k ,
306
+ fetch_k = fetch_k ,
307
+ depth = depth ,
308
+ lambda_mult = lambda_mult ,
309
+ score_threshold = score_threshold ,
310
+ ** kwargs ,
311
+ )
312
+ )
313
+ done = object ()
314
+ while True :
315
+ doc = await run_in_executor (None , next , iterator , done )
316
+ if doc is done :
317
+ break
298
318
yield doc
299
319
300
320
def similarity_search (
0 commit comments