Skip to content

Commit b160b6f

Browse files
committed
Add debug printing and clean up
1 parent 0e6b58f commit b160b6f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

attribution/api_attribution.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import itertools
33
import os
4-
import statistics
54
from copy import deepcopy
65
from typing import Any, List, Optional
76

@@ -30,6 +29,8 @@
3029
load_dotenv()
3130

3231
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
32+
REQUEST_DELAY = 0.1
33+
MIN_MIDRANGE_THRESHOLD = 0.01
3334

3435

3536
class OpenAIAttributor(BaseAsyncLLMAttributor):
@@ -156,11 +157,14 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
156157

157158
final_scores = np.zeros(token_count)
158159
total_llm_calls = 1
160+
stage = 0
159161

160162
while masks:
163+
print(f"Stage {stage}")
161164
new_masks = []
162165
perturbation_scores = []
163166
perturbations = []
167+
masked_out = []
164168
for mask in masks:
165169
perturbed_units = [token if not mask[i] else perturbation_strategy.replacement_token for i, token in enumerate(tokens)]
166170
# TODO: Check this is correct unit > token conversion
@@ -173,7 +177,9 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
173177
"token_idx": np.where(mask)[0].tolist(),
174178
}
175179
)
180+
masked_out.append([self.tokenizer.convert_tokens_to_string(list(itertools.chain.from_iterable(itertools.compress(tokens, mask)))).strip()])
176181

182+
print("Masked out tokens/words:", *masked_out, sep="\n")
177183
outputs = await self.compute_attribution_chunks(perturbations)
178184
chunk_scores = self.get_scores(outputs, original_output, **kwargs)
179185
total_llm_calls += len(outputs)
@@ -217,7 +223,7 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
217223
final_scores[mask] = attr_score
218224

219225
midrange_score = (np.max(perturbation_scores) + np.min(perturbation_scores)) / 2
220-
if midrange_score < 0.01:
226+
if midrange_score < MIN_MIDRANGE_THRESHOLD:
221227
break
222228

223229
for mask, score in zip(masks, perturbation_scores):
@@ -236,6 +242,7 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
236242
new_masks.append(mask2)
237243

238244
masks = new_masks
245+
stage += 1
239246

240247
if logger:
241248
logger.df_token_attribution_matrix = logger.df_token_attribution_matrix.drop_duplicates(subset=["exp_id", "input_token_pos", "output_token"], keep="last").sort_values(by=["input_token_pos", "output_token_pos"]).reset_index(drop=True)
@@ -301,7 +308,7 @@ async def compute_attribution_chunks(self, perturbations: list[dict[str, Any]])
301308
tasks[i] for i in range(idx, min(idx + self.request_chunksize, len(tasks)))
302309
]
303310
outputs.extend(await asyncio.gather(*batch))
304-
await asyncio.sleep(0.1)
311+
await asyncio.sleep(REQUEST_DELAY)
305312
else:
306313
outputs = await asyncio.gather(*tasks)
307314

0 commit comments

Comments (0)