-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
Copy pathentity_extraction.py
121 lines (107 loc) · 4.23 KB
/
entity_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Orchestration Context Builders."""
from enum import Enum
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.retrieval.entities import (
get_entity_by_id,
get_entity_by_key,
get_entity_by_name,
)
from graphrag.vector_stores.base import BaseVectorStore
class EntityVectorStoreKey(str, Enum):
    """Keys naming which entity attribute serves as the document id in the entity embedding vectorstores."""

    # Vector-store document ids are entity ids.
    ID = "id"
    # Vector-store document ids are entity titles.
    TITLE = "title"

    @staticmethod
    def from_string(value: str) -> "EntityVectorStoreKey":
        """Parse ``value`` into the matching ``EntityVectorStoreKey`` member.

        Members are tried in declaration order; a member matches when
        ``value`` compares equal to its string value.

        Raises:
            ValueError: if ``value`` matches no member.
        """
        for member in EntityVectorStoreKey:
            if value == member.value:
                return member
        msg = f"Invalid EntityVectorStoreKey: {value}"
        raise ValueError(msg)
def map_query_to_entities(
    query: str,
    text_embedding_vectorstore: BaseVectorStore,
    text_embedder: EmbeddingModel,
    all_entities_dict: dict[str, Entity],
    embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
    include_entity_names: list[str] | None = None,
    exclude_entity_names: list[str] | None = None,
    k: int = 10,
    oversample_scaler: int = 2,
) -> list[Entity]:
    """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions.

    Args:
        query: Free-text query. When it is the empty string no vector search
            is run; the ``k`` entities with the highest ``rank`` are used instead.
        text_embedding_vectorstore: Store searched via
            ``similarity_search_by_text`` with the embedded query.
        text_embedder: Model whose ``embed`` method is used to embed the query.
        all_entities_dict: Entity lookup passed to ``get_entity_by_id``; its
            values are also the candidate pool for key/name lookups.
        embedding_vectorstore_key: Which entity attribute the store's document
            ids correspond to. With ``EntityVectorStoreKey.ID`` (and a string
            document id) ``get_entity_by_id`` is used; otherwise
            ``get_entity_by_key`` is queried with this key.
        include_entity_names: Entity titles to always include; they come first
            in the result and are not subject to ``exclude_entity_names``.
        exclude_entity_names: Entity titles removed from the matched entities.
        k: Number of entities requested.
        oversample_scaler: Multiplier on ``k`` for the vector search so that
            excluded entities do not starve the result.

    Returns:
        The included entities followed by the matched entities. Each entity
        object appears at most once (first occurrence wins).
    """
    if include_entity_names is None:
        include_entity_names = []
    if exclude_entity_names is None:
        exclude_entity_names = []
    all_entities = list(all_entities_dict.values())
    matched_entities: list[Entity] = []
    if query != "":
        # get entities with highest semantic similarity to query
        # oversample to account for excluded entities
        search_results = text_embedding_vectorstore.similarity_search_by_text(
            text=query,
            text_embedder=lambda t: text_embedder.embed(t),
            k=k * oversample_scaler,
        )
        for result in search_results:
            if embedding_vectorstore_key == EntityVectorStoreKey.ID and isinstance(
                result.document.id, str
            ):
                matched = get_entity_by_id(all_entities_dict, result.document.id)
            else:
                matched = get_entity_by_key(
                    entities=all_entities,
                    key=embedding_vectorstore_key,
                    value=result.document.id,
                )
            if matched:
                matched_entities.append(matched)
        # NOTE(review): the matched list is not cut back to ``k`` after the
        # exclusion filter below, so up to ``k * oversample_scaler`` entities
        # may be returned — confirm callers expect this.
    else:
        # No query: fall back to the top-k entities by rank (falsy rank -> 0).
        all_entities.sort(key=lambda x: x.rank if x.rank else 0, reverse=True)
        matched_entities = all_entities[:k]

    # filter out excluded entities (set built once for O(1) membership)
    if exclude_entity_names:
        excluded_names = set(exclude_entity_names)
        matched_entities = [
            entity
            for entity in matched_entities
            if entity.title not in excluded_names
        ]

    # add entities in the include_entity list
    included_entities: list[Entity] = []
    for entity_name in include_entity_names:
        included_entities.extend(get_entity_by_name(all_entities, entity_name))

    # Drop repeated entity objects (e.g. an entity both explicitly included and
    # returned by the search) so it is not emitted twice downstream. All
    # entities stay referenced by the lists above, so ``id()`` values are stable.
    seen_ids: set[int] = set()
    deduped: list[Entity] = []
    for entity in included_entities + matched_entities:
        if id(entity) in seen_ids:
            continue
        seen_ids.add(id(entity))
        deduped.append(entity)
    return deduped
def find_nearest_neighbors_by_entity_rank(
    entity_name: str,
    all_entities: list[Entity],
    all_relationships: list[Relationship],
    exclude_entity_names: list[str] | None = None,
    k: int | None = 10,
) -> list[Entity]:
    """Retrieve entities that have direct connections with the target entity, sorted by entity rank.

    A neighbor is the opposite endpoint of any relationship whose ``source`` or
    ``target`` equals ``entity_name``. The target entity itself is returned only
    when it has a self-loop relationship.

    Args:
        entity_name: Title of the entity whose neighbors are wanted.
        all_entities: Candidate entities, matched to neighbor names by ``title``.
        all_relationships: Relationships whose ``source``/``target`` are
            compared against entity titles.
        exclude_entity_names: Titles that must never be returned.
        k: Maximum number of neighbors to return. A falsy value (``None`` or
            ``0``) returns every neighbor.

    Returns:
        Neighbor entities ordered by descending ``rank`` (a falsy rank sorts
        as 0; ties keep their order in ``all_entities``).
    """
    excluded_names = set(exclude_entity_names or [])

    # Collect only the opposite endpoint of each incident relationship; taking
    # every source and target would always include ``entity_name`` itself.
    neighbor_names: set[str] = set()
    for rel in all_relationships:
        if rel.source == entity_name:
            neighbor_names.add(rel.target)
        if rel.target == entity_name:
            neighbor_names.add(rel.source)
    neighbor_names -= excluded_names

    top_relations = [
        entity for entity in all_entities if entity.title in neighbor_names
    ]
    top_relations.sort(key=lambda x: x.rank if x.rank else 0, reverse=True)
    if k:
        return top_relations[:k]
    return top_relations