1
1
import asyncio
2
2
import itertools
3
3
import os
4
+ import statistics
4
5
from copy import deepcopy
5
6
from typing import Any , List , Optional
6
7
@@ -116,11 +117,15 @@ def get_perturbations(self, input_text, chunksize, **kwargs):
116
117
117
118
return perturbations
118
119
119
- async def hierarchical_perturbation (self , input_text : str , init_chunksize : int , stages : int , ** kwargs ):
120
+ async def hierarchical_perturbation (self , input_text : str , init_chunksize : int , stages : int , threshold : float = 0.5 , ** kwargs ):
120
121
perturbation_strategy : PerturbationStrategy = kwargs .get (
121
122
"perturbation_strategy" , FixedPerturbationStrategy ()
122
123
)
123
124
125
+ attribution_strategies : List [str ] = kwargs .get (
126
+ "attribution_strategies" , ["cosine" , "prob_diff" ]
127
+ )
128
+
124
129
logger : ExperimentLogger = kwargs .get ("logger" , None )
125
130
perturb_word_wise : bool = kwargs .get ("perturb_word_wise" , False )
126
131
@@ -139,21 +144,25 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
139
144
process_chunks = None
140
145
prev_perturbations = None
141
146
prev_process_chunks = None
142
- all_scores = []
147
+ total_llm_calls = 0
143
148
for stage in range (stages ):
144
149
145
150
perturbations = self .get_perturbations (input_text , chunksize , ** kwargs )
146
151
147
152
if stage > 0 :
148
- process_chunks = []
149
- for p , processed in zip (prev_perturbations , prev_process_chunks ):
153
+ scores = []
154
+ for perturbation , processed in zip (prev_perturbations , prev_process_chunks ):
150
155
if processed :
151
- score = chunk_scores .pop (0 )
152
- decision = score [ "cosine" ] ["sentence_attribution" ] > 0.5
156
+ attr = chunk_scores .pop (0 )
157
+ scores . append ( attr [ attribution_strategies [ 0 ]] ["sentence_attribution" ])
153
158
else :
154
- decision = False
159
+ scores . append ( None )
155
160
156
- process_chunks .extend ([decision ] * (2 if chunksize > 1 else len (p ["unit_tokens" ])))
161
+ process_chunks = []
162
+ median_score = statistics .median ([s for s in scores if s is not None ])
163
+ for score in scores :
164
+ decision = score is not None and (score > threshold or score > median_score )
165
+ process_chunks .extend ([decision ] * (2 if chunksize > 1 else len (perturbation ["unit_tokens" ])))
157
166
else :
158
167
process_chunks = [True ] * len (perturbations )
159
168
@@ -165,12 +174,12 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
165
174
outputs = await self .compute_attribution_chunks (perturbations , ** kwargs )
166
175
chunk_scores = self .get_scores (outputs , original_output , ** kwargs )
167
176
177
+ total_llm_calls += len (outputs )
168
178
prev_process_chunks = process_chunks
169
179
170
-
171
180
if logger :
172
- for p , output , score in zip (perturbations , outputs , chunk_scores ):
173
- for unit_token , token_id in zip (p ["unit_tokens" ], p ["token_idx" ]):
181
+ for perturbation , output , score in zip (perturbations , outputs , chunk_scores ):
182
+ for unit_token , token_id in zip (perturbation ["unit_tokens" ], perturbation ["token_idx" ]):
174
183
175
184
for attribution_strategy , attr_result in score .items ():
176
185
@@ -187,22 +196,21 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
187
196
j ,
188
197
attr_result ["attributed_tokens" ][j ],
189
198
attr_score .squeeze (),
190
- p ["input" ],
199
+ perturbation ["input" ],
191
200
output .message .content ,
192
201
)
193
202
194
203
logger .log_perturbation (
195
204
0 , # TODO: Why is this here?
196
- self .tokenizer .decode (p ["replaced_token_ids" ], skip_special_tokens = True )[
205
+ self .tokenizer .decode (perturbation ["replaced_token_ids" ], skip_special_tokens = True )[
197
206
0
198
207
],
199
208
perturbation_strategy ,
200
209
input_text ,
201
210
original_output .message .content ,
202
- p ["input" ],
211
+ perturbation ["input" ],
203
212
output .message .content ,
204
213
)
205
- logger .stop_experiment ()
206
214
207
215
if stage == stages - 2 :
208
216
chunksize = 1
@@ -211,10 +219,9 @@ async def hierarchical_perturbation(self, input_text: str, init_chunksize: int,
211
219
if chunksize == 0 :
212
220
break
213
221
214
- logger .df_token_attribution_matrix = logger .df_token_attribution_matrix .drop_duplicates (subset = "input_token_pos" , keep = "last" ).sort_values (by = "input_token_pos" )
215
- logger .df_input_token_attribution = logger .df_input_token_attribution .drop_duplicates (subset = "input_token_pos" , keep = "last" ).sort_values (by = "input_token_pos" )
216
-
217
- return all_scores
222
+ logger .df_token_attribution_matrix = logger .df_token_attribution_matrix .drop_duplicates (subset = ["input_token_pos" , "output_token" ], keep = "last" ).sort_values (by = "input_token_pos" )
223
+ logger .df_input_token_attribution = logger .df_input_token_attribution .drop_duplicates (subset = ["input_token_pos" ], keep = "last" ).sort_values (by = "input_token_pos" )
224
+ logger .stop_experiment (num_llm_calls = total_llm_calls )
218
225
219
226
def get_scores (self , perturbed_output , original_output , ** kwargs ):
220
227
attribution_strategies : List [str ] = kwargs .get (
@@ -463,4 +470,4 @@ async def compute_attributions(self, input_text: str, **kwargs):
463
470
perturbation ["input" ],
464
471
perturbed_output .message .content ,
465
472
)
466
- logger .stop_experiment ()
473
+ logger .stop_experiment (num_llm_calls = len ( outputs ) )
0 commit comments