12
12
)
13
13
14
14
from .attribution_metrics import (
15
+ NEAR_ZERO_PROB ,
15
16
cosine_similarity_attribution ,
16
17
token_prob_attribution ,
17
18
)
24
25
25
26
# Pull environment variables (e.g. OPENAI_API_KEY) from a local .env file
# before any attributor is constructed.
load_dotenv()
26
27
28
# Fallback chat model used by OpenAIAttributor when the caller does not
# supply an explicit `openai_model`.
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
29
+
27
30
28
31
class OpenAIAttributor (BaseLLMAttributor ):
29
32
def __init__ (
@@ -35,7 +38,7 @@ def __init__(
35
38
):
36
39
openai_api_key = openai_api_key or os .getenv ("OPENAI_API_KEY" )
37
40
self .openai_client = openai .OpenAI (api_key = openai_api_key )
38
- self .openai_model = openai_model or "gpt-3.5-turbo"
41
+ self .openai_model = openai_model or DEFAULT_OPENAI_MODEL
39
42
40
43
self .tokenizer = tokenizer or GPT2Tokenizer .from_pretrained ("gpt2" )
41
44
self .token_embeddings = token_embeddings or GPT2LMHeadModel .from_pretrained ("gpt2" ).transformer .wte .weight .detach ().numpy ()
@@ -51,6 +54,46 @@ def get_chat_completion(self, input: str) -> openai.types.chat.chat_completion.C
51
54
top_logprobs = 20 ,
52
55
)
53
56
return response .choices [0 ]
57
+
58
+ def make_output_location_invariant (self , original_output , perturbed_output ):
59
+ # Making a copy of the original output, so we can update it with the perturbed output log probs, wherever a token from the unperturned output is found in the perturbed output.
60
+ location_invariant_output = deepcopy (original_output )
61
+
62
+ # Get lists of all tokens and their logprobs (including top 20 in each output position) in the perturbed output
63
+ all_top_logprobs = []
64
+ all_tokens = []
65
+ for perturbed_token in perturbed_output .logprobs .content :
66
+ all_top_logprobs .extend ([token_logprob .logprob for token_logprob in perturbed_token .top_logprobs ])
67
+ all_tokens .extend ([token_logprob .token for token_logprob in perturbed_token .top_logprobs ])
68
+
69
+ # Sorting the tokens and logprobs by logprob in descending order. This is because .index gets the first occurence of a token in the list, and we want to get the highest logprob for each token.
70
+ sorted_indexes = sorted (range (len (all_top_logprobs )), key = all_top_logprobs .__getitem__ , reverse = True )
71
+ all_tokens_sorted = [all_tokens [s ] for s in sorted_indexes ]
72
+ all_top_logprobs_sorted = [all_top_logprobs [s ] for s in sorted_indexes ]
73
+
74
+ # Now, for each token in the original output, if it is found in the perturbed output , update the logprob in the original output with the logprob from the perturbed output.
75
+ # Otherwise, set the logprob to a near zero value.
76
+
77
+ for unperturbed_token in location_invariant_output .logprobs .content :
78
+ if unperturbed_token .token in all_tokens_sorted :
79
+ perturbed_logprob = all_top_logprobs_sorted [all_tokens_sorted .index (unperturbed_token .token )]
80
+ else :
81
+ perturbed_logprob = NEAR_ZERO_PROB
82
+
83
+ # Update the main token logprob
84
+ unperturbed_token .logprob = perturbed_logprob
85
+
86
+ # Update the same token logprob in the top 20 logprobs (duplicate information, but for consistency with the original output structure / OpenAI format)
87
+ for top_logprob in unperturbed_token .top_logprobs :
88
+ if top_logprob .token == unperturbed_token .token :
89
+ top_logprob .logprob = perturbed_logprob
90
+
91
+ # And update the message content
92
+ location_invariant_output .message .content = perturbed_output .message .content
93
+
94
+ #Now the perturbed output contains the same tokens as the original output, but with the logprobs from the perturbed output.
95
+ return location_invariant_output
96
+
54
97
55
98
def compute_attributions (self , input_text : str , ** kwargs ):
56
99
perturbation_strategy : PerturbationStrategy = kwargs .get (
@@ -64,7 +107,6 @@ def compute_attributions(self, input_text: str, **kwargs):
64
107
ignore_output_token_location : bool = kwargs .get ("ignore_output_token_location" , True )
65
108
66
109
original_output = self .get_chat_completion (input_text )
67
- remaining_output = deepcopy (original_output )
68
110
69
111
if logger :
70
112
logger .start_experiment (
@@ -114,34 +156,9 @@ def compute_attributions(self, input_text: str, **kwargs):
114
156
# Get the output logprobs for the perturbed input
115
157
perturbed_output = self .get_chat_completion (perturbed_input )
116
158
117
-
118
159
if ignore_output_token_location :
119
-
120
- all_top_logprobs = []
121
- all_toks = []
122
- for ptl in perturbed_output .logprobs .content :
123
- all_top_logprobs .extend ([tl .logprob for tl in ptl .top_logprobs ])
124
- all_toks .extend ([tl .token for tl in ptl .top_logprobs ])
125
-
126
- sorted_indexes = sorted (range (len (all_top_logprobs )), key = all_top_logprobs .__getitem__ , reverse = True )
127
- all_toks = [all_toks [s ] for s in sorted_indexes ]
128
- all_top_logprobs = [all_top_logprobs [s ] for s in sorted_indexes ]
129
-
130
- for otl in remaining_output .logprobs .content :
131
- if otl .token in all_toks :
132
- new_lp = all_top_logprobs [all_toks .index (otl .token )]
133
-
134
- else :
135
- new_lp = - 100
136
-
137
- otl .logprob = new_lp
138
- for tl in otl .top_logprobs :
139
- if tl .token == otl .token :
140
- tl .logprob = new_lp
141
-
142
- remaining_output .message .content = perturbed_output .message .content
143
- perturbed_output = remaining_output
144
-
160
+ perturbed_output = self .make_output_location_invariant (original_output , perturbed_output )
161
+
145
162
for attribution_strategy in attribution_strategies :
146
163
if attribution_strategy == "cosine" :
147
164
sentence_attr , attributed_tokens , token_attributions = cosine_similarity_attribution (
0 commit comments