Skip to content

Commit b467d5b

Browse files
author
Sebastian Sosa
committed
Refactor attribution metrics, logger and perturbation method. Show example usage on notebook
1 parent 0f52d31 commit b467d5b

6 files changed

+3951
-2075
lines changed

attribution/attribution_metrics.py

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import math
2+
from typing import List, Tuple
3+
4+
import numpy as np
5+
import openai
6+
from sklearn.metrics.pairwise import cosine_similarity
7+
from transformers import PreTrainedModel, PreTrainedTokenizer
8+
9+
10+
def token_prob_difference(
    initial_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
    perturbed_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
) -> Tuple[float, List[str], np.ndarray]:
    """Measure how far each initially generated token's probability moved.

    For every position ``i`` of the initial output, the probability of the
    initially generated token is compared with the probability the perturbed
    output assigns to that same token at position ``i``.

    Args:
        initial_logprobs: Logprobs of the unperturbed completion.
        perturbed_logprobs: Logprobs of the completion produced from the
            perturbed input.

    Returns:
        A tuple ``(mean_difference, initial_tokens, difference_per_token)``.
        ``difference_per_token[i]`` is the absolute probability difference
        for ``initial_tokens[i]``. When the initial output has no tokens the
        mean is ``0.0`` and the array is empty.
    """
    NEAR_ZERO_PROB = -100  # Logprob constant for near zero probability

    # ``content`` may be None; treat that the same as "no tokens".
    initial_content = initial_logprobs.content or []
    perturbed_content = perturbed_logprobs.content or []

    initial_tokens = [entry.token for entry in initial_content]
    if not initial_tokens:
        # Avoid the NaN (and RuntimeWarning) that the mean of an empty array gives.
        return 0.0, initial_tokens, np.zeros(0)

    prob_difference_per_token = np.zeros(len(initial_tokens))
    for i, initial_entry in enumerate(initial_content):
        candidate_logprobs = {}
        if i < len(perturbed_content):
            perturbed_entry = perturbed_content[i]
            candidate_logprobs = {
                top_logprob.token: top_logprob.logprob
                for top_logprob in perturbed_entry.top_logprobs or []
            }
            # The token actually generated at this position has a known
            # logprob even when it is absent from the top alternatives.
            candidate_logprobs.setdefault(
                perturbed_entry.token, perturbed_entry.logprob
            )

        # Tokens with no known perturbed logprob (including positions past
        # the end of the perturbed output) count as near-zero probability.
        perturbed_logprob = candidate_logprobs.get(
            initial_entry.token, NEAR_ZERO_PROB
        )
        prob_difference_per_token[i] = abs(
            math.exp(initial_entry.logprob) - math.exp(perturbed_logprob)
        )

    # Note: Different length outputs shift the mean upwards. This may or may not be desired behaviour.
    return (
        float(prob_difference_per_token.mean()),
        initial_tokens,
        prob_difference_per_token,
    )
49+
50+
51+
def token_displacement(
    initial_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
    perturbed_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
) -> Tuple[float, List[str], np.ndarray]:
    """Measure how far each initial token slid down the perturbed ranking.

    For every position of the initial output, the displacement is the index
    of the initially generated token within the perturbed output's
    ``top_logprobs`` list at the same position (0 when it is still the first
    entry). Tokens that are missing from that list, or whose position lies
    beyond the end of the perturbed output, get ``MAX_TOKEN_DISPLACEMENT``.

    NOTE(review): this assumes ``top_logprobs`` is ordered most-likely
    first, so that the list index is the token's rank - confirm against the
    API contract.

    Args:
        initial_logprobs: Logprobs of the unperturbed completion.
        perturbed_logprobs: Logprobs of the completion produced from the
            perturbed input.

    Returns:
        A tuple ``(mean_displacement, initial_tokens, displacement_per_token)``.
        When the initial output has no tokens the mean is ``0.0`` and the
        array is empty.
    """
    MAX_TOKEN_DISPLACEMENT = 20  # TODO: Revise

    # ``content`` may be None; treat that the same as "no tokens".
    initial_tokens = [entry.token for entry in initial_logprobs.content or []]
    perturbed_top_tokens = [
        [top_logprob.token for top_logprob in entry.top_logprobs or []]
        for entry in perturbed_logprobs.content or []
    ]

    if not initial_tokens:
        # Avoid the NaN (and RuntimeWarning) that the mean of an empty array gives.
        return 0.0, initial_tokens, np.zeros(0)

    # Start every token at the maximum; only positions where the token is
    # found among the perturbed alternatives are lowered below.
    displacement_per_token = np.full(
        len(initial_tokens), MAX_TOKEN_DISPLACEMENT, dtype=float
    )
    for i, (token, ranked_tokens) in enumerate(
        zip(initial_tokens, perturbed_top_tokens)
    ):
        if token in ranked_tokens:
            # The displacement is the rank itself. The sequence position
            # ``i`` must not enter the value, otherwise an unchanged
            # top-ranked token late in the output scores as displaced.
            displacement_per_token[i] = ranked_tokens.index(token)

    return (
        float(displacement_per_token.mean()),
        initial_tokens,
        displacement_per_token,
    )
71+
72+
73+
def max_logprob_difference(
    initial_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
    perturbed_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
) -> float:
    """Return the largest logprob change of any initially generated token.

    Both outputs are reduced to a ``token -> logprob`` mapping built from
    the generated tokens (not from the ``top_logprobs`` alternatives). If a
    token occurs more than once, the logprob of its last occurrence wins.

    Args:
        initial_logprobs: Logprobs of the unperturbed completion.
        perturbed_logprobs: Logprobs of the completion produced from the
            perturbed input.

    Returns:
        The maximum absolute logprob difference over the initial tokens, or
        ``0.0`` when the initial output has no tokens.
    """
    # A logprob of 0 means probability 1, so it cannot stand in for a token
    # that was not generated; use the same near-zero floor as
    # token_prob_difference instead.
    NEAR_ZERO_PROB = -100

    # ``content`` may be None; treat that the same as "no tokens".
    initial_token_logprobs = {
        entry.token: entry.logprob for entry in initial_logprobs.content or []
    }
    perturbed_token_logprobs = {
        entry.token: entry.logprob for entry in perturbed_logprobs.content or []
    }

    return max(
        (
            abs(initial_logprob - perturbed_token_logprobs.get(token, NEAR_ZERO_PROB))
            for token, initial_logprob in initial_token_logprobs.items()
        ),
        default=0.0,
    )
92+
93+
94+
def get_sentence_embeddings(
    sentence: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer
) -> Tuple[np.ndarray, np.ndarray]:
    """Embed a sentence with the model's input token-embedding layer.

    Args:
        sentence: Text to embed.
        model: Model whose input embedding layer is used. The layer is
            obtained through ``get_input_embeddings()`` when the model
            offers it, otherwise through ``model.transformer.wte``.
        tokenizer: Tokenizer called as ``tokenizer(sentence,
            return_tensors="pt")``; its ``"input_ids"`` entry is embedded.

    Returns:
        A tuple ``(sentence_embedding, token_embeddings)``: the mean over
        the token axis, and the per-token embedding matrix with one row per
        token.
    """
    inputs = tokenizer(sentence, return_tensors="pt")

    # Prefer the architecture-independent accessor; fall back to the
    # GPT-2 style attribute path for objects that do not provide it.
    get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if callable(get_input_embeddings):
        embedding_layer = get_input_embeddings()
    else:
        embedding_layer = model.transformer.wte

    embeddings = embedding_layer(inputs["input_ids"])  # Get the embeddings
    # Drop only the leading batch axis. A bare squeeze() would also remove
    # the token axis of a single-token sentence and turn the mean below
    # into a scalar.
    embeddings = embeddings.detach().cpu().numpy()[0]
    return embeddings.mean(axis=0), embeddings
101+
102+
103+
def _paired_cosine_similarity(first: np.ndarray, second: np.ndarray) -> np.ndarray:
    """Return the cosine similarity of each row of *first* with the same row of *second*."""
    first_norms = np.linalg.norm(first, axis=1)
    second_norms = np.linalg.norm(second, axis=1)
    # A zero vector has no direction; dividing by 1 instead of 0 makes its
    # similarity 0 rather than NaN.
    first_norms[first_norms == 0] = 1.0
    second_norms[second_norms == 0] = 1.0
    return np.sum(first * second, axis=1) / (first_norms * second_norms)


def cosine_similarity_attribution(
    original_output_choice: openai.types.chat.chat_completion.Choice,
    perturbed_output_choice: openai.types.chat.chat_completion.Choice,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[float, np.ndarray]:
    """Score how far the perturbed output drifted from the original in embedding space.

    Args:
        original_output_choice: Choice holding the unperturbed completion;
            its ``message.content`` is embedded.
        perturbed_output_choice: Choice holding the completion produced from
            the perturbed input.
        model: Model passed through to ``get_sentence_embeddings``.
        tokenizer: Tokenizer passed through to ``get_sentence_embeddings``.

    Returns:
        A tuple ``(sentence_difference, token_differences)``.
        ``sentence_difference`` is the original sentence's self-similarity
        minus its cosine similarity to the perturbed sentence.
        ``token_differences`` has one entry per original token:
        ``1 - cosine similarity`` with the perturbed token at the same
        position, or ``1`` where the perturbed output has no token at that
        position.
    """
    # Extract embeddings
    initial_sentence_emb, initial_token_embs = get_sentence_embeddings(
        original_output_choice.message.content, model, tokenizer
    )
    perturbed_sentence_emb, perturbed_token_embs = get_sentence_embeddings(
        perturbed_output_choice.message.content, model, tokenizer
    )

    # Reshape embeddings into single-row matrices for cosine_similarity
    initial_sentence_emb = initial_sentence_emb.reshape(1, -1)
    perturbed_sentence_emb = perturbed_sentence_emb.reshape(1, -1)

    # Calculate similarities. ``.item()`` extracts the single element;
    # calling float() on a 2-D array is deprecated in recent NumPy releases.
    self_similarity = cosine_similarity(
        initial_sentence_emb, initial_sentence_emb
    ).item()
    sentence_similarity = cosine_similarity(
        initial_sentence_emb, perturbed_sentence_emb
    ).item()

    # Calculate token similarities for the shared length. Only same-position
    # pairs are needed, so compare row by row instead of building the full
    # pairwise matrix and reading its diagonal.
    shared_length = min(initial_token_embs.shape[0], perturbed_token_embs.shape[0])
    token_similarities_shared = _paired_cosine_similarity(
        initial_token_embs[:shared_length], perturbed_token_embs[:shared_length]
    )

    # Pad with zeros so every original token has an entry; the zero turns
    # into a difference of 1 in the return value below.
    token_similarities = np.pad(
        token_similarities_shared, (0, initial_token_embs.shape[0] - shared_length)
    )

    # Return difference in sentence similarity and token similarities
    return self_similarity - sentence_similarity, 1 - token_similarities
142+
143+
144+
def _is_token_in_top_20(
    token: str,
    top_logprobs: List[openai.types.chat.chat_completion_token_logprob.TopLogprob],
):
    """Tell whether *token* is one of the alternatives listed in *top_logprobs*.

    Args:
        token: Token text to look for.
        top_logprobs: Candidate entries; only each entry's ``token``
            attribute is inspected. The list length is not checked here, so
            the "top 20" in the name presumably reflects how many
            alternatives callers request - confirm.

    Returns:
        ``True`` if any entry's token equals *token*, otherwise ``False``.
    """
    candidate_tokens = {entry.token for entry in top_logprobs}
    if token in candidate_tokens:
        return True
    return False
150+
151+
152+
def any_tokens_in_top_20(
    initial_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
    new_logprobs: openai.types.chat.chat_completion.ChoiceLogprobs,
) -> bool:
    """Check whether any initial token still appears among the new alternatives.

    Positions are compared pairwise up to the length of the shorter output:
    the token generated at position ``i`` of the initial output is looked up
    in the ``top_logprobs`` of position ``i`` of the new output.

    Args:
        initial_logprobs: Logprobs of the initial completion, or ``None``.
        new_logprobs: Logprobs of the new completion, or ``None``.

    Returns:
        ``True`` as soon as one position matches. ``False`` when nothing
        matches or when either argument (or its ``content``) is ``None``.
    """
    # Guard clauses: without both token lists there is nothing to compare.
    if initial_logprobs is None or new_logprobs is None:
        return False
    if initial_logprobs.content is None or new_logprobs.content is None:
        return False

    for initial_entry, new_entry in zip(
        initial_logprobs.content, new_logprobs.content
    ):
        alternatives = {candidate.token for candidate in new_entry.top_logprobs}
        if initial_entry.token in alternatives:
            return True
    return False

0 commit comments

Comments
 (0)