Skip to content

Commit dda365b

Browse files
addressed PR comments - thanks guys
1 parent db7f294 commit dda365b

File tree

3 files changed

+756
-38
lines changed

3 files changed

+756
-38
lines changed

attribution/api_attribution.py

+46-29
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313

1414
from .attribution_metrics import (
15+
NEAR_ZERO_PROB,
1516
cosine_similarity_attribution,
1617
token_prob_attribution,
1718
)
@@ -24,6 +25,8 @@
2425

2526
load_dotenv()
2627

28+
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
29+
2730

2831
class OpenAIAttributor(BaseLLMAttributor):
2932
def __init__(
@@ -35,7 +38,7 @@ def __init__(
3538
):
3639
openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
3740
self.openai_client = openai.OpenAI(api_key=openai_api_key)
38-
self.openai_model = openai_model or "gpt-3.5-turbo"
41+
self.openai_model = openai_model or DEFAULT_OPENAI_MODEL
3942

4043
self.tokenizer = tokenizer or GPT2Tokenizer.from_pretrained("gpt2")
4144
self.token_embeddings = token_embeddings or GPT2LMHeadModel.from_pretrained("gpt2").transformer.wte.weight.detach().numpy()
@@ -51,6 +54,46 @@ def get_chat_completion(self, input: str) -> openai.types.chat.chat_completion.C
5154
top_logprobs=20,
5255
)
5356
return response.choices[0]
57+
58+
def make_output_location_invariant(self, original_output, perturbed_output):
59+
# Making a copy of the original output, so we can update it with the perturbed output log probs, wherever a token from the unperturned output is found in the perturbed output.
60+
location_invariant_output = deepcopy(original_output)
61+
62+
# Get lists of all tokens and their logprobs (including top 20 in each output position) in the perturbed output
63+
all_top_logprobs = []
64+
all_tokens = []
65+
for perturbed_token in perturbed_output.logprobs.content:
66+
all_top_logprobs.extend([token_logprob.logprob for token_logprob in perturbed_token.top_logprobs])
67+
all_tokens.extend([token_logprob.token for token_logprob in perturbed_token.top_logprobs])
68+
69+
# Sorting the tokens and logprobs by logprob in descending order. This is because .index gets the first occurence of a token in the list, and we want to get the highest logprob for each token.
70+
sorted_indexes = sorted(range(len(all_top_logprobs)), key=all_top_logprobs.__getitem__, reverse=True)
71+
all_tokens_sorted = [all_tokens[s] for s in sorted_indexes]
72+
all_top_logprobs_sorted = [all_top_logprobs[s] for s in sorted_indexes]
73+
74+
# Now, for each token in the original output, if it is found in the perturbed output , update the logprob in the original output with the logprob from the perturbed output.
75+
# Otherwise, set the logprob to a near zero value.
76+
77+
for unperturbed_token in location_invariant_output.logprobs.content:
78+
if unperturbed_token.token in all_tokens_sorted:
79+
perturbed_logprob = all_top_logprobs_sorted[all_tokens_sorted.index(unperturbed_token.token)]
80+
else:
81+
perturbed_logprob = NEAR_ZERO_PROB
82+
83+
# Update the main token logprob
84+
unperturbed_token.logprob = perturbed_logprob
85+
86+
# Update the same token logprob in the top 20 logprobs (duplicate information, but for consistency with the original output structure / OpenAI format)
87+
for top_logprob in unperturbed_token.top_logprobs:
88+
if top_logprob.token == unperturbed_token.token:
89+
top_logprob.logprob = perturbed_logprob
90+
91+
# And update the message content
92+
location_invariant_output.message.content = perturbed_output.message.content
93+
94+
#Now the perturbed output contains the same tokens as the original output, but with the logprobs from the perturbed output.
95+
return location_invariant_output
96+
5497

5598
def compute_attributions(self, input_text: str, **kwargs):
5699
perturbation_strategy: PerturbationStrategy = kwargs.get(
@@ -64,7 +107,6 @@ def compute_attributions(self, input_text: str, **kwargs):
64107
ignore_output_token_location: bool = kwargs.get("ignore_output_token_location", True)
65108

66109
original_output = self.get_chat_completion(input_text)
67-
remaining_output = deepcopy(original_output)
68110

69111
if logger:
70112
logger.start_experiment(
@@ -114,34 +156,9 @@ def compute_attributions(self, input_text: str, **kwargs):
114156
# Get the output logprobs for the perturbed input
115157
perturbed_output = self.get_chat_completion(perturbed_input)
116158

117-
118159
if ignore_output_token_location:
119-
120-
all_top_logprobs = []
121-
all_toks = []
122-
for ptl in perturbed_output.logprobs.content:
123-
all_top_logprobs.extend([tl.logprob for tl in ptl.top_logprobs])
124-
all_toks.extend([tl.token for tl in ptl.top_logprobs])
125-
126-
sorted_indexes = sorted(range(len(all_top_logprobs)), key=all_top_logprobs.__getitem__, reverse=True)
127-
all_toks = [all_toks[s] for s in sorted_indexes]
128-
all_top_logprobs = [all_top_logprobs[s] for s in sorted_indexes]
129-
130-
for otl in remaining_output.logprobs.content:
131-
if otl.token in all_toks:
132-
new_lp = all_top_logprobs[all_toks.index(otl.token)]
133-
134-
else:
135-
new_lp = -100
136-
137-
otl.logprob = new_lp
138-
for tl in otl.top_logprobs:
139-
if tl.token == otl.token:
140-
tl.logprob = new_lp
141-
142-
remaining_output.message.content = perturbed_output.message.content
143-
perturbed_output = remaining_output
144-
160+
perturbed_output = self.make_output_location_invariant(original_output, perturbed_output)
161+
145162
for attribution_strategy in attribution_strategies:
146163
if attribution_strategy == "cosine":
147164
sentence_attr, attributed_tokens, token_attributions = cosine_similarity_attribution(

attribution/attribution_metrics.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.metrics.pairwise import cosine_similarity
77
from transformers import PreTrainedTokenizer
88

9+
NEAR_ZERO_PROB = -100 # Logprob constant for near zero probability
910

1011
def token_prob_attribution(
1112
initial_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
@@ -25,7 +26,6 @@ def token_prob_attribution(
2526

2627
# Probability change for each input token
2728
prob_difference_per_token = np.zeros(len(initial_tokens))
28-
NEAR_ZERO_PROB = -100 # Logprob constant for near zero probability
2929

3030
# Calculate the absolute difference in probabilities for each token
3131
for i, initial_token in enumerate(initial_token_logprobs):
@@ -44,18 +44,18 @@ def cosine_similarity_attribution(
4444
perturbed_output_str: str,
4545
token_embeddings: np.ndarray,
4646
tokenizer: PreTrainedTokenizer,
47-
) -> Tuple[float, np.ndarray]:
47+
) -> Tuple[float, list[str], np.ndarray]:
4848
# Extract embeddings
4949

50-
original_token_ix = tokenizer.encode(original_output_str, return_tensors="pt", add_special_tokens=False)
51-
perturbed_token_ix = tokenizer.encode(perturbed_output_str, return_tensors="pt", add_special_tokens=False)
52-
initial_tokens = [tokenizer.decode(t) for t in original_token_ix.squeeze(axis=0)]
50+
original_token_id = tokenizer.encode(original_output_str, return_tensors="pt", add_special_tokens=False)
51+
perturbed_token_id = tokenizer.encode(perturbed_output_str, return_tensors="pt", add_special_tokens=False)
52+
initial_tokens = [tokenizer.decode(t) for t in original_token_id.squeeze(axis=0)]
5353

54-
original_output_emb = token_embeddings[original_token_ix].reshape(-1, token_embeddings.shape[-1])
55-
perturbed_output_emb = token_embeddings[perturbed_token_ix].reshape(-1, token_embeddings.shape[-1])
54+
original_output_emb = token_embeddings[original_token_id].reshape(-1, token_embeddings.shape[-1])
55+
perturbed_output_emb = token_embeddings[perturbed_token_id].reshape(-1, token_embeddings.shape[-1])
5656

57-
cd = 1-cosine_similarity(original_output_emb, perturbed_output_emb)
58-
token_distance = cd.min(axis=-1)
57+
cosine_distance = 1-cosine_similarity(original_output_emb, perturbed_output_emb)
58+
token_distance = cosine_distance.min(axis=-1)
5959
return token_distance.mean(), initial_tokens, token_distance
6060

6161

0 commit comments

Comments
 (0)