-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy path_mmr_helper.py
242 lines (189 loc) · 8.67 KB
/
_mmr_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import dataclasses
from typing import Dict, Iterable, List, Optional
import numpy as np
from numpy.typing import NDArray
from ragstack_knowledge_store.math import cosine_similarity
def _emb_to_ndarray(embedding: List[float]) -> NDArray[np.float32]:
    """Convert an embedding into a float32 numpy array.

    A flat (1-D) embedding is promoted to a single-row matrix of shape
    ``(1, dim)``. Inputs of any other rank keep their shape and are only
    converted to ``float32``.
    """
    as_array = np.array(embedding, dtype=np.float32)
    return as_array[np.newaxis, :] if as_array.ndim == 1 else as_array
# Negative infinity: compares below every finite score. Used as the default
# score threshold and to reset best-score tracking.
NEG_INF = -float("inf")
@dataclasses.dataclass
class _Candidate:
    """Bookkeeping for one node still under consideration.

    ``score`` is not a constructor argument; it is derived and always equals
    ``weighted_similarity - weighted_redundancy``.
    """

    id: str
    weighted_similarity: float
    weighted_redundancy: float
    score: float = dataclasses.field(init=False)

    def _refresh_score(self) -> None:
        # Keep the derived score in sync with the two weighted terms.
        self.score = self.weighted_similarity - self.weighted_redundancy

    def __post_init__(self) -> None:
        self._refresh_score()

    def update_redundancy(self, new_weighted_redundancy: float) -> None:
        """Raise the redundancy penalty if ``new_weighted_redundancy`` is larger.

        Values that are not strictly greater (including NaN) are ignored, so
        the stored penalty never decreases.
        """
        if not new_weighted_redundancy > self.weighted_redundancy:
            return
        self.weighted_redundancy = new_weighted_redundancy
        self._refresh_score()
class MmrHelper:
    """Helper for executing an MMR (maximal marginal relevance) traversal query.

    Candidates are added with :meth:`add_candidates` and selected one at a
    time with :meth:`pop_best`. Every candidate is scored as::

        lambda_mult * sim(candidate, query)
            - (1 - lambda_mult) * max(sim(candidate, s) for s in selected)

    where ``sim`` is ``cosine_similarity`` and the max term is 0 while
    nothing has been selected yet.

    Args:
        k: Expected number of selections. Used to pre-size the buffer of
            selected embeddings; the buffer grows if more are selected.
        query_embedding: The embedding of the query to use for scoring.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to -infinity.
    """

    dimensions: int
    """Dimensions of the embedding."""

    query_embedding: NDArray[np.float32]
    """Embedding of the query as a (1,dim) ndarray."""

    lambda_mult: float
    """Number between 0 and 1.

    Determines the degree of diversity among the results with 0 corresponding to
    maximum diversity and 1 to minimum diversity."""

    lambda_mult_complement: float
    """1 - lambda_mult."""

    score_threshold: float
    """Only documents with a score greater than or equal to this will be chosen."""

    selected_ids: List[str]
    """List of selected IDs (in selection order)."""

    selected_embeddings: NDArray[np.float32]
    """(N, dim) ndarray; the first len(selected_ids) rows hold selected nodes."""

    candidate_id_to_index: Dict[str, int]
    """Dictionary of candidate IDs to indices in candidates and candidate_embeddings."""

    candidates: List[_Candidate]
    """List containing information about candidates.

    Same order as rows in `candidate_embeddings`.
    """

    candidate_embeddings: NDArray[np.float32]
    """(N, dim) ndarray with a row for each candidate."""

    best_score: float
    best_id: Optional[str]

    def __init__(
        self,
        k: int,
        query_embedding: List[float],
        lambda_mult: float = 0.5,
        score_threshold: float = NEG_INF,
    ) -> None:
        self.query_embedding = _emb_to_ndarray(query_embedding)
        self.dimensions = self.query_embedding.shape[1]

        self.lambda_mult = lambda_mult
        self.lambda_mult_complement = 1 - lambda_mult
        self.score_threshold = score_threshold

        self.selected_ids = []
        # Pre-allocated buffer; rows past len(selected_ids) are uninitialized
        # and must only be read through _already_selected_embeddings().
        self.selected_embeddings = np.empty((k, self.dimensions), dtype=np.float32)

        self.candidate_id_to_index = {}
        self.candidates = []
        self.candidate_embeddings = np.empty((0, self.dimensions), dtype=np.float32)

        self.best_score = NEG_INF
        self.best_id = None

    def candidate_ids(self) -> Iterable[str]:
        """Return the IDs of the candidates (a live view of the index keys)."""
        return self.candidate_id_to_index.keys()

    def _already_selected_embeddings(self) -> NDArray[np.float32]:
        """Return the filled rows of ``selected_embeddings`` (a view, not a copy)."""
        return self.selected_embeddings[: len(self.selected_ids)]

    def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
        """Pop the candidate with the given ID.

        Returns:
            The embedding of the candidate.

        Raises:
            KeyError: If ``candidate_id`` is not under consideration.
            ValueError: If the internal index bookkeeping is inconsistent.
        """
        index = self.candidate_id_to_index.pop(candidate_id)
        if self.candidates[index].id != candidate_id:
            raise ValueError(
                "ID in self.candidate_id_to_index doesn't match the ID of the "
                "corresponding index in self.candidates"
            )
        embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()

        # Swap-remove: move the last candidate into the vacated slot so only
        # the final row has to be dropped (no shifting of later rows).
        last_index = self.candidate_embeddings.shape[0] - 1
        last_candidate = self.candidates.pop()
        if index != last_index:
            self.candidates[index] = last_candidate
            self.candidate_embeddings[index] = self.candidate_embeddings[last_index]
            self.candidate_id_to_index[last_candidate.id] = index
        self.candidate_embeddings = self.candidate_embeddings[:last_index]

        return embedding

    def pop_best(self) -> Optional[str]:
        """Select and pop the best item being considered.

        Updates the consideration set based on it: every remaining candidate's
        redundancy is raised to account for the newly selected node, and the
        best remaining candidate is recomputed.

        Returns:
            The ID of the selected item, or ``None`` if there is no candidate
            or the best score is below ``score_threshold``.
        """
        if self.best_id is None or self.best_score < self.score_threshold:
            return None

        # Get the selection and remove from candidates.
        selected_id = self.best_id
        selected_embedding = self._pop_candidate(selected_id)

        # Add the ID and embedding to the selected information, growing the
        # buffer if more than `k` items are selected (previously IndexError).
        selection_index = len(self.selected_ids)
        if selection_index >= self.selected_embeddings.shape[0]:
            grow_by = max(1, self.selected_embeddings.shape[0])
            self.selected_embeddings = np.vstack(
                (
                    self.selected_embeddings,
                    np.empty((grow_by, self.dimensions), dtype=np.float32),
                )
            )
        self.selected_ids.append(selected_id)
        self.selected_embeddings[selection_index] = selected_embedding

        # Reset the best score / best ID.
        self.best_score = NEG_INF
        self.best_id = None

        # Update the candidates redundancy, tracking the best node.
        if self.candidate_embeddings.shape[0] > 0:
            similarity = cosine_similarity(
                self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0)
            )
            for index, candidate in enumerate(self.candidates):
                # Weight the redundancy the same way add_candidates does;
                # `weighted_redundancy` must always hold a weighted value.
                candidate.update_redundancy(
                    self.lambda_mult_complement * similarity[index][0]
                )
                if candidate.score > self.best_score:
                    self.best_score = candidate.score
                    self.best_id = candidate.id

        return selected_id

    def add_candidates(self, candidates: Dict[str, List[float]]) -> None:
        """Add candidates to the consideration set.

        IDs that are already selected or already under consideration are
        skipped. New candidates are appended in the iteration order of
        ``candidates``, which keeps tie-breaking deterministic.

        Args:
            candidates: Mapping from candidate ID to its embedding.
        """
        selected = set(self.selected_ids)
        include_ids = [
            candidate_id
            for candidate_id in candidates
            if candidate_id not in selected
            and candidate_id not in self.candidate_id_to_index
        ]
        if not include_ids:
            return

        # Build a float32 matrix of the new embeddings (float32 matches
        # candidate_embeddings, so vstack below does not upcast).
        new_embeddings: NDArray[np.float32] = np.empty(
            (len(include_ids), self.dimensions), dtype=np.float32
        )
        for row, candidate_id in enumerate(include_ids):
            new_embeddings[row] = candidates[candidate_id]

        # Compute the similarity to the query.
        similarity = cosine_similarity(new_embeddings, self.query_embedding)

        # Similarity of each new candidate to every already-selected node.
        # Skipped when nothing is selected: the redundancy term is then 0.
        has_selected = len(self.selected_ids) > 0
        if has_selected:
            redundancy = cosine_similarity(
                new_embeddings, self._already_selected_embeddings()
            )

        # Only mutate state once all similarity computations have succeeded.
        offset = self.candidate_embeddings.shape[0]
        for index, candidate_id in enumerate(include_ids):
            max_redundancy = redundancy[index].max() if has_selected else 0.0
            candidate = _Candidate(
                id=candidate_id,
                weighted_similarity=self.lambda_mult * similarity[index][0],
                weighted_redundancy=self.lambda_mult_complement * max_redundancy,
            )
            self.candidate_id_to_index[candidate_id] = offset + index
            self.candidates.append(candidate)
            if candidate.score >= self.best_score:
                self.best_score = candidate.score
                self.best_id = candidate.id

        # Add the new embeddings to the candidate set.
        self.candidate_embeddings = np.vstack(
            (self.candidate_embeddings, new_embeddings)
        )