2
2
import pickle
3
3
import time
4
4
from typing import Any , Literal , Optional
5
- from IPython .display import HTML
5
+
6
+ import matplotlib .colors as mcolors
7
+ import matplotlib .pyplot as plt
6
8
import pandas as pd
7
9
from IPython .core .getipython import get_ipython
10
+ from IPython .display import HTML
8
11
9
12
from .token_perturbation import PerturbationStrategy , PerturbedLLMInput , combine_unit
10
- import matplotlib .pyplot as plt
11
- import matplotlib .colors as mcolors
12
13
13
14
14
15
class ExperimentLogger :
@@ -177,23 +178,26 @@ def log_perturbation(
177
178
def clean_tokens (self , tokens ):
178
179
return [token .replace ("Ġ" , " " ) for token in tokens ]
179
180
180
-
181
181
def score_to_color (self , score , vmin = - 1 , vmax = 1 ):
182
182
norm = mcolors .Normalize (vmin = vmin , vmax = vmax )
183
183
cmap = plt .cm .coolwarm
184
184
rgba_color = cmap (norm (score ))
185
185
color_hex = mcolors .to_hex (rgba_color )
186
186
return color_hex
187
187
188
- def print_text_total_attribution (self , exp_id : Optional [int ] = None , score_agg : Literal ["mean" , "sum" , "last" ] = "mean" ):
189
-
188
+ def print_text_total_attribution (
189
+ self , exp_id : Optional [int ] = None , score_agg : Literal ["mean" , "sum" , "last" ] = "mean"
190
+ ):
190
191
if exp_id == - 1 :
191
192
exp_id = self .df_experiments ["exp_id" ].max ()
192
193
193
- token_attrs_df = self .df_input_token_attribution .groupby (
194
- ["exp_id" , "attribution_strategy" ]) if exp_id is None else self .df_input_token_attribution [self .df_input_token_attribution ["exp_id" ] == exp_id ].groupby (
195
- ["exp_id" , "attribution_strategy" ])
196
-
194
+ token_attrs_df = (
195
+ self .df_input_token_attribution .groupby (["exp_id" , "attribution_strategy" ])
196
+ if exp_id is None
197
+ else self .df_input_token_attribution [
198
+ self .df_input_token_attribution ["exp_id" ] == exp_id
199
+ ].groupby (["exp_id" , "attribution_strategy" ])
200
+ )
197
201
198
202
for (exp_id , attr_strat ), exp_data in token_attrs_df :
199
203
if exp_data ["input_token_pos" ].duplicated ().any ():
@@ -221,40 +225,40 @@ def print_text_total_attribution(self, exp_id: Optional[int] = None, score_agg:
221
225
color = self .score_to_color (score , vmin , vmax )
222
226
html_str += f'<span style="text-decoration: underline; text-decoration-color: { color } ; text-decoration-thickness: 4px;">{ token } </span>'
223
227
224
-
225
- html_str += f' -> <b>{ output } </b>'
226
- html_str += '</div>'
228
+ html_str += f" -> <b>{ output } </b>"
229
+ html_str += "</div>"
227
230
228
231
# Display
229
-
232
+
230
233
if get_ipython () and "IPKernelApp" in get_ipython ().config :
231
234
from IPython .display import display
235
+
232
236
display (HTML (html_str ))
233
237
else :
234
238
self .pretty_print (df )
235
239
236
-
237
240
def print_text_attribution_matrix (self , exp_id : int = - 1 ):
238
-
239
241
if exp_id == - 1 :
240
242
exp_id = self .df_experiments ["exp_id" ].max ()
241
243
242
244
matrix = self .get_attribution_matrix (exp_id )
243
245
244
- input_tokens = [' ' .join (x .split (' ' )[:- 1 ]) for x in matrix .index ]
246
+ input_tokens = [" " .join (x .split (" " )[:- 1 ]) for x in matrix .index ]
245
247
token_dict = {f"token_{ i + 1 } " : t for i , t in enumerate (input_tokens )}
246
248
247
249
for oi , output_token in enumerate (matrix .columns ):
248
- prev_output_str = '' .join ([' ' .join (ot .split (' ' )[:- 1 ]) for ot in matrix .columns [:oi ]])
249
- following_output_str = '' .join ([' ' .join (ot .split (' ' )[:- 1 ]) for ot in matrix .columns [oi + 1 :]])
250
+ prev_output_str = "" .join ([" " .join (ot .split (" " )[:- 1 ]) for ot in matrix .columns [:oi ]])
251
+ following_output_str = "" .join (
252
+ [" " .join (ot .split (" " )[:- 1 ]) for ot in matrix .columns [oi + 1 :]]
253
+ )
250
254
attr_scores = matrix [output_token ].tolist ()
251
255
vmax = max (abs (min (attr_scores )), abs (max (attr_scores )))
252
256
vmin = - vmax
253
257
254
258
score_dict = {f"token_{ i + 1 } " : score for i , score in enumerate (attr_scores )}
255
259
256
260
df = pd .DataFrame ([token_dict , score_dict ], index = ["token" , "attr_score" ])
257
-
261
+
258
262
# Generating HTML
259
263
html_str = '<div style="font-family: monospace;">'
260
264
for col in df .columns :
@@ -263,27 +267,33 @@ def print_text_attribution_matrix(self, exp_id: int = -1):
263
267
color = self .score_to_color (score , vmin , vmax )
264
268
html_str += f'<span style="text-decoration: underline; text-decoration-color: { color } ; text-decoration-thickness: 4px;">{ token } </span>'
265
269
266
- clean_output_token = ' ' .join (output_token .split (' ' )[:- 1 ])
267
- html_str += f' -> { prev_output_str } <b>{ clean_output_token } </b>{ following_output_str } '
268
- html_str += ' </div>'
270
+ clean_output_token = " " .join (output_token .split (" " )[:- 1 ])
271
+ html_str += f" -> { prev_output_str } <b>{ clean_output_token } </b>{ following_output_str } "
272
+ html_str += " </div>"
269
273
270
274
# Display
271
-
275
+
272
276
if get_ipython () and "IPKernelApp" in get_ipython ().config :
273
277
from IPython .display import display
278
+
274
279
display (HTML (html_str ))
275
280
else :
276
281
self .pretty_print (df )
277
282
278
-
279
- def print_total_attribution (self , exp_id : Optional [int ] = None , score_agg : Literal ["mean" , "last" ] = "mean" ):
283
+ def print_total_attribution (
284
+ self , exp_id : Optional [int ] = None , score_agg : Literal ["mean" , "last" ] = "mean"
285
+ ):
280
286
totals = []
281
287
if exp_id == - 1 :
282
288
exp_id = self .df_experiments ["exp_id" ].max ()
283
289
284
- token_attrs_df = self .df_input_token_attribution .groupby (
285
- ["exp_id" , "attribution_strategy" ]) if exp_id is None else self .df_input_token_attribution [self .df_input_token_attribution ["exp_id" ] == exp_id ].groupby (
286
- ["exp_id" , "attribution_strategy" ])
290
+ token_attrs_df = (
291
+ self .df_input_token_attribution .groupby (["exp_id" , "attribution_strategy" ])
292
+ if exp_id is None
293
+ else self .df_input_token_attribution [
294
+ self .df_input_token_attribution ["exp_id" ] == exp_id
295
+ ].groupby (["exp_id" , "attribution_strategy" ])
296
+ )
287
297
288
298
for (exp_id , attr_strat ), exp_data in token_attrs_df :
289
299
if exp_data ["input_token_pos" ].duplicated ().any ():
@@ -324,11 +334,18 @@ def print_attribution_matrix(
324
334
):
325
335
if exp_id == - 1 :
326
336
exp_id = self .df_experiments ["exp_id" ].max ()
327
- matrix = self .get_attribution_matrix (exp_id , attribution_strategy , show_debug_cols , score_agg )
337
+ matrix = self .get_attribution_matrix (
338
+ exp_id , attribution_strategy , show_debug_cols , score_agg
339
+ )
328
340
329
341
if get_ipython () and "IPKernelApp" in get_ipython ().config :
330
342
from IPython .display import display
331
- display (matrix .style .background_gradient (cmap = "coolwarm" , vmin = - 1 , vmax = 1 ).set_properties (** {"white-space" : "pre-wrap" }))
343
+
344
+ display (
345
+ matrix .style .background_gradient (cmap = "coolwarm" , vmin = - 1 , vmax = 1 ).set_properties (
346
+ ** {"white-space" : "pre-wrap" }
347
+ )
348
+ )
332
349
else :
333
350
self .pretty_print (matrix )
334
351
@@ -343,13 +360,14 @@ def get_attribution_matrix(
343
360
exp_id = self .df_experiments ["exp_id" ].max ()
344
361
345
362
if attribution_strategy is None :
346
- strategies = self .df_token_attribution_matrix [(self .df_token_attribution_matrix ["exp_id" ] == exp_id )]["attribution_strategy" ].unique ()
363
+ strategies = self .df_token_attribution_matrix [
364
+ (self .df_token_attribution_matrix ["exp_id" ] == exp_id )
365
+ ]["attribution_strategy" ].unique ()
347
366
else :
348
367
strategies = [attribution_strategy ]
349
368
350
369
matrices = []
351
370
for attribution_strategy in strategies :
352
-
353
371
# Filter the data for the specific experiment and attribution strategy
354
372
exp_data = self .df_token_attribution_matrix [
355
373
(self .df_token_attribution_matrix ["exp_id" ] == exp_id )
@@ -401,14 +419,15 @@ def get_attribution_matrix(
401
419
)
402
420
additional_columns = additional_columns [["perturbed_input" , "perturbed_output" ]]
403
421
matrix = matrix .join (additional_columns )
404
-
422
+
405
423
matrices .append (matrix )
406
424
return pd .concat (matrices )
407
425
408
426
def pretty_print (self , df : pd .DataFrame ):
409
427
# Check if code is running in Jupyter notebook
410
428
if get_ipython () and "IPKernelApp" in get_ipython ().config :
411
429
from IPython .display import display
430
+
412
431
display (df .style .set_properties (** {"white-space" : "pre-wrap" }))
413
432
else :
414
433
print (df .to_string ())
0 commit comments