Skip to content

Commit 62fb876

Browse files
fix ruff errors
1 parent 52cbdfd commit 62fb876

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

attribution/experiment_logger.py

+53 -34
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,14 @@
22
import pickle
33
import time
44
from typing import Any, Literal, Optional
5-
from IPython.display import HTML
5+
6+
import matplotlib.colors as mcolors
7+
import matplotlib.pyplot as plt
68
import pandas as pd
79
from IPython.core.getipython import get_ipython
10+
from IPython.display import HTML
811

912
from .token_perturbation import PerturbationStrategy, PerturbedLLMInput, combine_unit
10-
import matplotlib.pyplot as plt
11-
import matplotlib.colors as mcolors
1213

1314

1415
class ExperimentLogger:
@@ -177,23 +178,26 @@ def log_perturbation(
177178
def clean_tokens(self, tokens):
178179
return [token.replace("Ġ", " ") for token in tokens]
179180

180-
181181
def score_to_color(self, score, vmin=-1, vmax=1):
182182
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
183183
cmap = plt.cm.coolwarm
184184
rgba_color = cmap(norm(score))
185185
color_hex = mcolors.to_hex(rgba_color)
186186
return color_hex
187187

188-
def print_text_total_attribution(self, exp_id: Optional[int] = None, score_agg: Literal["mean", "sum", "last"] = "mean"):
189-
188+
def print_text_total_attribution(
189+
self, exp_id: Optional[int] = None, score_agg: Literal["mean", "sum", "last"] = "mean"
190+
):
190191
if exp_id == -1:
191192
exp_id = self.df_experiments["exp_id"].max()
192193

193-
token_attrs_df = self.df_input_token_attribution.groupby(
194-
["exp_id", "attribution_strategy"]) if exp_id is None else self.df_input_token_attribution[self.df_input_token_attribution["exp_id"] == exp_id].groupby(
195-
["exp_id", "attribution_strategy"])
196-
194+
token_attrs_df = (
195+
self.df_input_token_attribution.groupby(["exp_id", "attribution_strategy"])
196+
if exp_id is None
197+
else self.df_input_token_attribution[
198+
self.df_input_token_attribution["exp_id"] == exp_id
199+
].groupby(["exp_id", "attribution_strategy"])
200+
)
197201

198202
for (exp_id, attr_strat), exp_data in token_attrs_df:
199203
if exp_data["input_token_pos"].duplicated().any():
@@ -221,40 +225,40 @@ def print_text_total_attribution(self, exp_id: Optional[int] = None, score_agg:
221225
color = self.score_to_color(score, vmin, vmax)
222226
html_str += f'<span style="text-decoration: underline; text-decoration-color: {color}; text-decoration-thickness: 4px;">{token}</span>'
223227

224-
225-
html_str += f' -> <b>{output}</b>'
226-
html_str += '</div>'
228+
html_str += f" -> <b>{output}</b>"
229+
html_str += "</div>"
227230

228231
# Display
229-
232+
230233
if get_ipython() and "IPKernelApp" in get_ipython().config:
231234
from IPython.display import display
235+
232236
display(HTML(html_str))
233237
else:
234238
self.pretty_print(df)
235239

236-
237240
def print_text_attribution_matrix(self, exp_id: int = -1):
238-
239241
if exp_id == -1:
240242
exp_id = self.df_experiments["exp_id"].max()
241243

242244
matrix = self.get_attribution_matrix(exp_id)
243245

244-
input_tokens = [' '.join(x.split(' ')[:-1]) for x in matrix.index]
246+
input_tokens = [" ".join(x.split(" ")[:-1]) for x in matrix.index]
245247
token_dict = {f"token_{i+1}": t for i, t in enumerate(input_tokens)}
246248

247249
for oi, output_token in enumerate(matrix.columns):
248-
prev_output_str = ''.join([' '.join(ot.split(' ')[:-1]) for ot in matrix.columns[:oi]])
249-
following_output_str = ''.join([' '.join(ot.split(' ')[:-1]) for ot in matrix.columns[oi+1:]])
250+
prev_output_str = "".join([" ".join(ot.split(" ")[:-1]) for ot in matrix.columns[:oi]])
251+
following_output_str = "".join(
252+
[" ".join(ot.split(" ")[:-1]) for ot in matrix.columns[oi + 1 :]]
253+
)
250254
attr_scores = matrix[output_token].tolist()
251255
vmax = max(abs(min(attr_scores)), abs(max(attr_scores)))
252256
vmin = -vmax
253257

254258
score_dict = {f"token_{i+1}": score for i, score in enumerate(attr_scores)}
255259

256260
df = pd.DataFrame([token_dict, score_dict], index=["token", "attr_score"])
257-
261+
258262
# Generating HTML
259263
html_str = '<div style="font-family: monospace;">'
260264
for col in df.columns:
@@ -263,27 +267,33 @@ def print_text_attribution_matrix(self, exp_id: int = -1):
263267
color = self.score_to_color(score, vmin, vmax)
264268
html_str += f'<span style="text-decoration: underline; text-decoration-color: {color}; text-decoration-thickness: 4px;">{token}</span>'
265269

266-
clean_output_token = ' '.join(output_token.split(' ')[:-1])
267-
html_str += f' -> {prev_output_str}<b>{clean_output_token}</b>{following_output_str}'
268-
html_str += '</div>'
270+
clean_output_token = " ".join(output_token.split(" ")[:-1])
271+
html_str += f" -> {prev_output_str}<b>{clean_output_token}</b>{following_output_str}"
272+
html_str += "</div>"
269273

270274
# Display
271-
275+
272276
if get_ipython() and "IPKernelApp" in get_ipython().config:
273277
from IPython.display import display
278+
274279
display(HTML(html_str))
275280
else:
276281
self.pretty_print(df)
277282

278-
279-
def print_total_attribution(self, exp_id: Optional[int] = None, score_agg: Literal["mean", "last"] = "mean"):
283+
def print_total_attribution(
284+
self, exp_id: Optional[int] = None, score_agg: Literal["mean", "last"] = "mean"
285+
):
280286
totals = []
281287
if exp_id == -1:
282288
exp_id = self.df_experiments["exp_id"].max()
283289

284-
token_attrs_df = self.df_input_token_attribution.groupby(
285-
["exp_id", "attribution_strategy"]) if exp_id is None else self.df_input_token_attribution[self.df_input_token_attribution["exp_id"] == exp_id].groupby(
286-
["exp_id", "attribution_strategy"])
290+
token_attrs_df = (
291+
self.df_input_token_attribution.groupby(["exp_id", "attribution_strategy"])
292+
if exp_id is None
293+
else self.df_input_token_attribution[
294+
self.df_input_token_attribution["exp_id"] == exp_id
295+
].groupby(["exp_id", "attribution_strategy"])
296+
)
287297

288298
for (exp_id, attr_strat), exp_data in token_attrs_df:
289299
if exp_data["input_token_pos"].duplicated().any():
@@ -324,11 +334,18 @@ def print_attribution_matrix(
324334
):
325335
if exp_id == -1:
326336
exp_id = self.df_experiments["exp_id"].max()
327-
matrix = self.get_attribution_matrix(exp_id, attribution_strategy, show_debug_cols, score_agg)
337+
matrix = self.get_attribution_matrix(
338+
exp_id, attribution_strategy, show_debug_cols, score_agg
339+
)
328340

329341
if get_ipython() and "IPKernelApp" in get_ipython().config:
330342
from IPython.display import display
331-
display(matrix.style.background_gradient(cmap="coolwarm", vmin=-1, vmax=1).set_properties(**{"white-space": "pre-wrap"}))
343+
344+
display(
345+
matrix.style.background_gradient(cmap="coolwarm", vmin=-1, vmax=1).set_properties(
346+
**{"white-space": "pre-wrap"}
347+
)
348+
)
332349
else:
333350
self.pretty_print(matrix)
334351

@@ -343,13 +360,14 @@ def get_attribution_matrix(
343360
exp_id = self.df_experiments["exp_id"].max()
344361

345362
if attribution_strategy is None:
346-
strategies = self.df_token_attribution_matrix[(self.df_token_attribution_matrix["exp_id"] == exp_id)]["attribution_strategy"].unique()
363+
strategies = self.df_token_attribution_matrix[
364+
(self.df_token_attribution_matrix["exp_id"] == exp_id)
365+
]["attribution_strategy"].unique()
347366
else:
348367
strategies = [attribution_strategy]
349368

350369
matrices = []
351370
for attribution_strategy in strategies:
352-
353371
# Filter the data for the specific experiment and attribution strategy
354372
exp_data = self.df_token_attribution_matrix[
355373
(self.df_token_attribution_matrix["exp_id"] == exp_id)
@@ -401,14 +419,15 @@ def get_attribution_matrix(
401419
)
402420
additional_columns = additional_columns[["perturbed_input", "perturbed_output"]]
403421
matrix = matrix.join(additional_columns)
404-
422+
405423
matrices.append(matrix)
406424
return pd.concat(matrices)
407425

408426
def pretty_print(self, df: pd.DataFrame):
409427
# Check if code is running in Jupyter notebook
410428
if get_ipython() and "IPKernelApp" in get_ipython().config:
411429
from IPython.display import display
430+
412431
display(df.style.set_properties(**{"white-space": "pre-wrap"}))
413432
else:
414433
print(df.to_string())

examples/example_PIZZA.ipynb

+2 -3
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,12 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18+
"import warnings\n",
19+
"\n",
1820
"# Set your open ai API key\n",
1921
"# BEWARE: This will cost you API credits!\n",
2022
"YOUR_OPENAI_API_KEY = \"your-api-key\"\n",
2123
"\n",
22-
"\n",
23-
"import warnings\n",
24-
"\n",
2524
"# Suppress annoying FutureWarning from huggingface_hub\n",
2625
"warnings.filterwarnings(\"ignore\", category=FutureWarning, module=\"huggingface_hub\")"
2726
]

0 commit comments

Comments
 (0)