# knowledge_graph.py
import json
import re
from itertools import repeat
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast

from cassandra.cluster import ResponseFuture, Session
from cassandra.query import BatchStatement
from cassio.config import check_resolve_keyspace, check_resolve_session
from langchain_core.embeddings import Embeddings
from yaml import dump

from .traverse import Node, Relation, atraverse, traverse
from .utils import batched


def _serialize_md_dict(md_dict: Dict[str, Any]) -> str:
    """Serialize a metadata dict to compact, deterministic JSON."""
    return json.dumps(md_dict, separators=(",", ":"), sort_keys=True)


def _deserialize_md_dict(md_string: str) -> Dict[str, Any]:
    """Deserialize a metadata JSON string back into a dict."""
    return cast(Dict[str, Any], json.loads(md_string))
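
# Round-trip example (illustrative): compact separators and sorted keys make
# the serialized form deterministic, so equal dicts serialize identically.
#
#     _serialize_md_dict({"b": 1, "a": {"c": 2}})  # -> '{"a":{"c":2},"b":1}'
#     _deserialize_md_dict('{"a":{"c":2},"b":1}')  # -> {"a": {"c": 2}, "b": 1}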
def _parse_node(row: Any) -> Node:
    """Build a `Node` from a row of (name, type, properties_json)."""
    return Node(
name=row.name,
type=row.type,
properties=_deserialize_md_dict(row.properties_json)
if row.properties_json
else {},
)
_CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*")
class CassandraKnowledgeGraph:
"""Cassandra Knowledge Graph.
Args:
node_table: Name of the table containing nodes. Defaults to `"entities"`.
edge_table: Name of the table containing edges. Defaults to `
"relationships"`.
text_embeddings: Name of the embeddings to use, if any.
session: The Cassandra `Session` to use. If not specified, uses the default
`cassio` session, which requires `cassio.init` has been called.
keyspace: The Cassandra keyspace to use. If not specified, uses the default
`cassio` keyspace, which requires `cassio.init` has been called.
apply_schema: If true, the node table and edge table are created.
"""
def __init__(
self,
node_table: str = "entities",
edge_table: str = "relationships",
text_embeddings: Optional[Embeddings] = None,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
apply_schema: bool = True,
) -> None:
session = check_resolve_session(session)
keyspace = check_resolve_keyspace(keyspace)
if not _CQL_IDENTIFIER_PATTERN.fullmatch(keyspace):
raise ValueError(f"Invalid keyspace: {keyspace}")
if not _CQL_IDENTIFIER_PATTERN.fullmatch(node_table):
raise ValueError(f"Invalid node table name: {node_table}")
if not _CQL_IDENTIFIER_PATTERN.fullmatch(edge_table):
raise ValueError(f"Invalid edge table name: {edge_table}")
self._text_embeddings = text_embeddings
self._text_embeddings_dim = (
# Embedding vectors must have dimension:
# > 0 to be created at all.
# > 1 to support cosine distance.
# So we default to 2.
len(text_embeddings.embed_query("test string")) if text_embeddings else 2
)
self._session = session
self._keyspace = keyspace
self._node_table = node_table
self._edge_table = edge_table
if apply_schema:
self._apply_schema()
self._insert_node = self._session.prepare(
f"""INSERT INTO {keyspace}.{node_table} (
name, type, text_embedding, properties_json
) VALUES (?, ?, ?, ?)
""" # noqa: S608
)
self._insert_relationship = self._session.prepare(
f"""
INSERT INTO {keyspace}.{edge_table} (
source_name, source_type, target_name, target_type, edge_type
) VALUES (?, ?, ?, ?, ?)
""" # noqa: S608
)
        self._query_node = self._session.prepare(
f"""
SELECT name, type, properties_json
FROM {keyspace}.{node_table}
WHERE name = ? AND type = ?
""" # noqa: S608
)
self._query_nodes_by_embedding = self._session.prepare(
f"""
SELECT name, type, properties_json
FROM {keyspace}.{node_table}
ORDER BY text_embedding ANN OF ?
LIMIT ?
""" # noqa: S608
)
def _apply_schema(self) -> None:
# Partition by `name` and cluster by `type`.
# Each `(name, type)` pair is a unique node.
# We can enumerate all `type` values for a given `name` to identify ambiguous
# terms.
self._session.execute(
f"""
CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} (
name TEXT,
type TEXT,
properties_json TEXT,
text_embedding VECTOR<FLOAT, {self._text_embeddings_dim}>,
PRIMARY KEY (name, type)
);
"""
)
self._session.execute(
f"""
CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._edge_table} (
source_name TEXT,
source_type TEXT,
target_name TEXT,
target_type TEXT,
edge_type TEXT,
PRIMARY KEY ((source_name, source_type), target_name, target_type, edge_type)
);
""" # noqa: E501
)
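        # A storage-attached index (SAI) on the vector column is what allows the
        # `ORDER BY text_embedding ANN OF ?` query prepared in `__init__`.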
self._session.execute(
f"""
CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index
ON {self._keyspace}.{self._node_table} (text_embedding)
USING 'StorageAttachedIndex';
"""
)
self._session.execute(
f"""
CREATE CUSTOM INDEX IF NOT EXISTS {self._edge_table}_type_index
ON {self._keyspace}.{self._edge_table} (edge_type)
USING 'StorageAttachedIndex';
"""
)
def _send_query_nearest_node(
self, embeddings: Embeddings, node: str, k: int = 1
) -> ResponseFuture:
return self._session.execute_async(
self._query_nodes_by_embedding,
(
embeddings.embed_query(node),
k,
),
)
# TODO: Allow filtering by node predicates and/or minimum similarity.
def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node]:
"""For each node, return the nearest nodes in the table.
Args:
nodes: The strings to search for in the list of nodes.
k: The number of similar nodes to retrieve for each string.
"""
if self._text_embeddings is None:
raise ValueError("Unable to query for nearest nodes without embeddings")
node_futures: Iterable[ResponseFuture] = [
self._send_query_nearest_node(self._text_embeddings, n, k) for n in nodes
]
return {
_parse_node(n) for node_future in node_futures for n in node_future.result()
}
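
    # Example (illustrative; requires `text_embeddings` to have been provided):
    #
    #     graph.query_nearest_nodes(["Marie Curie", "radioactivity"], k=2)
    #     # -> up to four `Node`s; the set deduplicates overlapping results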
# TODO: Introduce `ainsert` for async insertions.
def insert(
self,
elements: Iterable[Union[Node, Relation]],
) -> None:
"""Insert the given elements into the graph."""
        for batch in batched(elements, n=4):
            # Embed the YAML representation of each node in the batch; edges are
            # not embedded. Without an embeddings model, fall back to a fixed
            # placeholder vector matching the default dimension of 2.
            text_embeddings = (
                iter(
                    self._text_embeddings.embed_documents(
                        [dump(n) for n in batch if isinstance(n, Node)]
                    )
                )
                if self._text_embeddings
                else repeat([0.0, 1.0])
            )
batch_statement = BatchStatement()
for element in batch:
if isinstance(element, Node):
properties_json = _serialize_md_dict(element.properties)
batch_statement.add(
self._insert_node,
(
element.name,
element.type,
next(text_embeddings),
properties_json,
),
)
elif isinstance(element, Relation):
batch_statement.add(
self._insert_relationship,
(
element.source.name,
element.source.type,
element.target.name,
element.target.type,
element.type,
),
)
else:
                raise TypeError(f"Unsupported element type: {type(element)}")
# TODO: Support concurrent execution of these statements.
self._session.execute(batch_statement)
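
    # Example (illustrative):
    #
    #     a = Node(name="a", type="Thing")
    #     b = Node(name="b", type="Thing")
    #     graph.insert([a, b, Relation(source=a, target=b, type="linked-to")])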
def subgraph(
self,
start: Union[Node, Sequence[Node]],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Tuple[Iterable[Node], Iterable[Relation]]:
"""Retrieve the sub-graph from the given starting nodes."""
edges = self.traverse(start, edge_filters, steps)
# Create the set of nodes.
nodes = {n for e in edges for n in (e.source, e.target)}
# Retrieve the set of nodes to get the properties.
# TODO: We really should have a NodeKey separate from Node. Otherwise, we end
# up in a state where two nodes can be the "same" but with different properties,
# etc.
node_futures: Iterable[ResponseFuture] = [
            self._session.execute_async(self._query_node, (n.name, n.type))
for n in nodes
]
graph_nodes = [
_parse_node(n) for future in node_futures for n in future.result()
]
return graph_nodes, edges
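
    # Unlike `traverse`, whose relation endpoints carry only `name` and `type`,
    # `subgraph` re-queries the node table so the returned nodes also carry
    # their stored `properties`.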
def traverse(
self,
start: Union[Node, Sequence[Node]],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Iterable[Relation]:
"""Traverse the graph from the given starting nodes.
Returns the resulting sub-graph.
Args:
start: The starting node or nodes.
edge_filters: Filters to apply to the edges being traversed.
steps: The number of steps of edges to follow from a start node.
Returns:
An iterable over relations in the traversed sub-graph.
"""
return traverse(
start=start,
edge_table=self._edge_table,
edge_source_name="source_name",
edge_source_type="source_type",
edge_target_name="target_name",
edge_target_type="target_type",
edge_type="edge_type",
edge_filters=edge_filters,
steps=steps,
session=self._session,
keyspace=self._keyspace,
)
async def atraverse(
self,
start: Union[Node, Sequence[Node]],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Iterable[Relation]:
"""Traverse the graph from the given starting nodes.
Returns the resulting sub-graph.
Args:
start: The starting node or nodes.
edge_filters: Filters to apply to the edges being traversed.
steps: The number of steps of edges to follow from a start node.
Returns:
An iterable over relations in the traversed sub-graph.
"""
return await atraverse(
start=start,
edge_table=self._edge_table,
edge_source_name="source_name",
edge_source_type="source_type",
edge_target_name="target_name",
edge_target_type="target_type",
edge_type="edge_type",
edge_filters=edge_filters,
steps=steps,
session=self._session,
keyspace=self._keyspace,
)
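
# End-to-end sketch (illustrative only; assumes `cassio.init` has been pointed
# at a reachable cluster, e.g. `cassio.init(auto=True)`, and `my_embeddings`
# is any `Embeddings` implementation):
#
#     graph = CassandraKnowledgeGraph(text_embeddings=my_embeddings)
#     marie = Node(name="Marie Curie", type="Person")
#     polonium = Node(name="Polonium", type="Element")
#     graph.insert(
#         [marie, polonium, Relation(source=marie, target=polonium, type="Discovered")]
#     )
#
#     # Resolve a fuzzy mention to a stored node, then pull its neighborhood.
#     [start] = graph.query_nearest_nodes(["Marie Curie"], k=1)
#     nodes, relations = graph.subgraph(start, steps=2)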