
Commit da19e3a

adding dynamic threshold and llm call counter
1 parent 95a2afd commit da19e3a

File tree

2 files changed: +33, -21 lines

attribution/api_attribution.py (+27, -20)
@@ -1,6 +1,7 @@
 import asyncio
 import itertools
 import os
+import statistics
 from copy import deepcopy
 from typing import Any, List, Optional
 

@@ -116,11 +117,15 @@ def get_perturbations(self, input_text, chunksize, **kwargs):
 
         return perturbations
 
-    async def hierarchical_perturbation(self, input_text: str, init_chunksize: int, stages: int, **kwargs):
+    async def hierarchical_perturbation(self, input_text: str, init_chunksize: int, stages: int, threshold: float = 0.5, **kwargs):
         perturbation_strategy: PerturbationStrategy = kwargs.get(
             "perturbation_strategy", FixedPerturbationStrategy()
         )
 
+        attribution_strategies: List[str] = kwargs.get(
+            "attribution_strategies", ["cosine", "prob_diff"]
+        )
+
         logger: ExperimentLogger = kwargs.get("logger", None)
         perturb_word_wise: bool = kwargs.get("perturb_word_wise", False)
 
@@ -139,21 +144,25 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
         process_chunks = None
         prev_perturbations = None
         prev_process_chunks = None
-        all_scores = []
+        total_llm_calls = 0
         for stage in range(stages):
 
             perturbations = self.get_perturbations(input_text, chunksize, **kwargs)
 
             if stage > 0:
-                process_chunks = []
-                for p, processed in zip(prev_perturbations, prev_process_chunks):
+                scores = []
+                for perturbation, processed in zip(prev_perturbations, prev_process_chunks):
                     if processed:
-                        score = chunk_scores.pop(0)
-                        decision = score["cosine"]["sentence_attribution"] > 0.5
+                        attr = chunk_scores.pop(0)
+                        scores.append(attr[attribution_strategies[0]]["sentence_attribution"])
                     else:
-                        decision = False
+                        scores.append(None)
 
-                    process_chunks.extend([decision] * (2 if chunksize > 1 else len(p["unit_tokens"])))
+                process_chunks = []
+                median_score = statistics.median([s for s in scores if s is not None])
+                for score in scores:
+                    decision = score is not None and (score > threshold or score > median_score)
+                    process_chunks.extend([decision] * (2 if chunksize > 1 else len(perturbation["unit_tokens"])))
             else:
                 process_chunks = [True] * len(perturbations)
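The heart of the change is the refinement rule above: instead of descending into every chunk whose cosine score clears a hard-coded 0.5, a chunk is now expanded when its score under the first configured attribution strategy beats either the new threshold argument or the median of the scores observed at that stage. A minimal, self-contained sketch of that decision rule (the function name and sample scores are illustrative, not from the repository):

    import statistics
    from typing import List, Optional

    def select_chunks(scores: List[Optional[float]], threshold: float = 0.5) -> List[bool]:
        # A chunk with no score (it was skipped in the previous stage) is never expanded.
        # A scored chunk is expanded if it beats the fixed threshold OR the median
        # of the chunks that were actually scored this stage.
        median_score = statistics.median([s for s in scores if s is not None])
        return [s is not None and (s > threshold or s > median_score) for s in scores]

    print(select_chunks([0.9, 0.2, None, 0.6]))  # [True, False, False, True]

The "or" makes the cutoff adaptive: the union keeps any chunk that is either absolutely strong (above the fixed threshold) or relatively strong (above the stage median), so a stage where every score is low still refines its strongest chunks.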
@@ -165,12 +174,12 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
             outputs = await self.compute_attribution_chunks(perturbations, **kwargs)
             chunk_scores = self.get_scores(outputs, original_output, **kwargs)
 
+            total_llm_calls += len(outputs)
             prev_process_chunks = process_chunks
 
-
             if logger:
-                for p, output, score in zip(perturbations, outputs, chunk_scores):
-                    for unit_token, token_id in zip(p["unit_tokens"], p["token_idx"]):
+                for perturbation, output, score in zip(perturbations, outputs, chunk_scores):
+                    for unit_token, token_id in zip(perturbation["unit_tokens"], perturbation["token_idx"]):
 
                         for attribution_strategy, attr_result in score.items():
 
@@ -187,22 +196,21 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
                                 j,
                                 attr_result["attributed_tokens"][j],
                                 attr_score.squeeze(),
-                                p["input"],
+                                perturbation["input"],
                                 output.message.content,
                             )
 
                 logger.log_perturbation(
                     0,  # TODO: Why is this here?
-                    self.tokenizer.decode(p["replaced_token_ids"], skip_special_tokens=True)[
+                    self.tokenizer.decode(perturbation["replaced_token_ids"], skip_special_tokens=True)[
                         0
                     ],
                     perturbation_strategy,
                     input_text,
                     original_output.message.content,
-                    p["input"],
+                    perturbation["input"],
                     output.message.content,
                 )
-            logger.stop_experiment()
 
         if stage == stages - 2:
             chunksize = 1
@@ -211,10 +219,9 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
             if chunksize == 0:
                 break
 
-        logger.df_token_attribution_matrix = logger.df_token_attribution_matrix.drop_duplicates(subset="input_token_pos", keep="last").sort_values(by="input_token_pos")
-        logger.df_input_token_attribution = logger.df_input_token_attribution.drop_duplicates(subset="input_token_pos", keep="last").sort_values(by="input_token_pos")
-
-        return all_scores
+        logger.df_token_attribution_matrix = logger.df_token_attribution_matrix.drop_duplicates(subset=["input_token_pos", "output_token"], keep="last").sort_values(by="input_token_pos")
+        logger.df_input_token_attribution = logger.df_input_token_attribution.drop_duplicates(subset=["input_token_pos"], keep="last").sort_values(by="input_token_pos")
+        logger.stop_experiment(num_llm_calls=total_llm_calls)
 
     def get_scores(self, perturbed_output, original_output, **kwargs):
         attribution_strategies: List[str] = kwargs.get(
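A note on the drop_duplicates fix folded into this hunk: df_token_attribution_matrix holds one row per (input token, output token) pair, so deduplicating on input_token_pos alone collapsed every output token's score for a given input position down to a single row. Widening the subset preserves the full matrix while still letting later stages overwrite earlier, coarser scores. A small pandas sketch (the values are made up; only the two subset arguments come from the diff):

    import pandas as pd

    df = pd.DataFrame({
        "input_token_pos": [0, 0, 0, 0],
        "output_token":    ["The", "cat", "The", "cat"],
        "attr_score":      [0.1, 0.2, 0.3, 0.4],  # later stages append refined scores
    })

    # Old behavior: one surviving row per input position, losing output tokens.
    old = df.drop_duplicates(subset="input_token_pos", keep="last")

    # New behavior: one surviving row per (input_token_pos, output_token) pair.
    new = df.drop_duplicates(subset=["input_token_pos", "output_token"], keep="last")

    print(len(old), len(new))  # 1 2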
@@ -463,4 +470,4 @@ async def compute_attributions(self, input_text: str, **kwargs):
                     perturbation["input"],
                     perturbed_output.message.content,
                 )
-        logger.stop_experiment()
+        logger.stop_experiment(num_llm_calls=len(outputs))
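With stop_experiment now accepting num_llm_calls, both code paths report how many perturbed generations they actually requested: the exhaustive path logs len(outputs) once, while the hierarchical path accumulates len(outputs) across stages into total_llm_calls. That count is the metric hierarchical perturbation is meant to improve, as a rough cost model suggests. A back-of-envelope sketch only: it assumes the chunk size halves between stages and that a fixed fraction of chunks survives the threshold test, neither of which is shown in these hunks:

    import math

    def estimated_calls(n_tokens: int, init_chunksize: int, stages: int, pass_rate: float = 0.5) -> int:
        # Illustrative cost model, not project code: each stage perturbs the
        # surviving chunks once, chunk size halves between stages (assumption),
        # and the second-to-last stage drops to single-token chunks (as in the diff).
        chunksize = init_chunksize
        surviving = 1.0  # fraction of the input still being refined
        total = 0
        for stage in range(stages):
            total += math.ceil(n_tokens * surviving / chunksize)
            surviving *= pass_rate  # assumed per-stage threshold pass rate
            chunksize = 1 if stage == stages - 2 else max(chunksize // 2, 1)
        return total

    # 64-token input, 8-token initial chunks, 3 stages -> 32 estimated calls,
    # versus 64 for exhaustive single-token perturbation.
    print(estimated_calls(n_tokens=64, init_chunksize=8, stages=3))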

attribution/experiment_logger.py (+6, -1)
@@ -20,6 +20,7 @@ def __init__(self, experiment_id=0):
                 "perturbation_strategy",
                 "perturb_word_wise",
                 "duration",
+                "num_llm_calls",
             ]
         )
         self.df_input_token_attribution = pd.DataFrame(
@@ -72,12 +73,16 @@ def start_experiment(
             "perturbation_strategy": str(perturbation_strategy),
             "perturb_word_wise": perturb_word_wise,
             "duration": None,
+            "num_llm_calls": None,
         }
 
-    def stop_experiment(self):
+    def stop_experiment(self, num_llm_calls: Optional[int] = None):
         self.df_experiments.loc[len(self.df_experiments) - 1, "duration"] = (
             time.time() - self.experiment_start_time
         )
+        self.df_experiments.loc[len(self.df_experiments) - 1, "num_llm_calls"] = (
+            num_llm_calls
+        )
 
     def log_input_token_attribution(
         self,
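On the logger side, the counter is threaded through the experiments table: the num_llm_calls column is declared in __init__, initialized to None in start_experiment, and filled in when stop_experiment is called. A stripped-down, runnable sketch of that lifecycle (MiniLogger is an illustrative stand-in, not the project's ExperimentLogger):

    import time
    from typing import Optional

    import pandas as pd

    class MiniLogger:
        # Reduced to the two columns this commit touches.
        def __init__(self) -> None:
            self.df_experiments = pd.DataFrame(columns=["duration", "num_llm_calls"])

        def start_experiment(self) -> None:
            # New row with both fields pending, mirroring start_experiment in the diff.
            self.df_experiments.loc[len(self.df_experiments)] = [None, None]
            self.experiment_start_time = time.time()

        def stop_experiment(self, num_llm_calls: Optional[int] = None) -> None:
            row = len(self.df_experiments) - 1
            self.df_experiments.loc[row, "duration"] = time.time() - self.experiment_start_time
            self.df_experiments.loc[row, "num_llm_calls"] = num_llm_calls

    logger = MiniLogger()
    logger.start_experiment()
    logger.stop_experiment(num_llm_calls=12)
    print(logger.df_experiments)  # one row: a small duration, num_llm_calls == 12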
