
Commit 0bb8b2e

Author: Sebastian Sosa (committed)

Fix bug when printing attr matrix for cosine method, and parametrize whether input should be perturbed by token or by word

1 parent b2d7cbf commit 0bb8b2e

3 files changed: +841 -1805 lines

attribution/attribution_metrics.py (+16 -10)
@@ -109,34 +109,40 @@ def cosine_similarity_attribution(
     tokenizer: PreTrainedTokenizer,
 ) -> Tuple[float, np.ndarray]:
     # Extract embeddings
-    initial_sentence_emb, initial_token_embs = get_sentence_embeddings(
+    initial_output_sentence_emb, initial_output_token_embs = get_sentence_embeddings(
         original_output_choice.message.content, model, tokenizer
     )
-    perturbed_sentence_emb, perturbed_token_embs = get_sentence_embeddings(
-        perturbed_output_choice.message.content, model, tokenizer
+    perturbed_output_sentence_emb, perturbed_output_token_embs = (
+        get_sentence_embeddings(
+            perturbed_output_choice.message.content, model, tokenizer
+        )
     )
 
     # Reshape embeddings
-    initial_sentence_emb = initial_sentence_emb.reshape(1, -1)
-    perturbed_sentence_emb = perturbed_sentence_emb.reshape(1, -1)
+    initial_output_sentence_emb = initial_output_sentence_emb.reshape(1, -1)
+    perturbed_output_sentence_emb = perturbed_output_sentence_emb.reshape(1, -1)
 
     # Calculate similarities
     self_similarity = float(
-        cosine_similarity(initial_sentence_emb, initial_sentence_emb)
+        cosine_similarity(initial_output_sentence_emb, initial_output_sentence_emb)
     )
     sentence_similarity = float(
-        cosine_similarity(initial_sentence_emb, perturbed_sentence_emb)
+        cosine_similarity(initial_output_sentence_emb, perturbed_output_sentence_emb)
     )
 
     # Calculate token similarities for shared length
-    shared_length = min(initial_token_embs.shape[0], perturbed_token_embs.shape[0])
+    shared_length = min(
+        initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0]
+    )
     token_similarities_shared = cosine_similarity(
-        initial_token_embs[:shared_length], perturbed_token_embs[:shared_length]
+        initial_output_token_embs[:shared_length],
+        perturbed_output_token_embs[:shared_length],
     ).diagonal()
 
     # Pad token similarities to match initial token embeddings shape
     token_similarities = np.pad(
-        token_similarities_shared, (0, initial_token_embs.shape[0] - shared_length)
+        token_similarities_shared,
+        (0, initial_output_token_embs.shape[0] - shared_length),
    )
 
     # Return difference in sentence similarity and token similarities
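The hunk above mostly renames the embedding variables to make clear they refer to model outputs; the truncate-and-pad step it touches is what keeps the token similarity vector aligned with the original output length. Below is a minimal, self-contained sketch of that step, with random arrays standing in for the real token embeddings returned by get_sentence_embeddings (the shapes and values are illustrative, not from this commit):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Stand-ins for token embeddings: 5 tokens in the original output,
# 3 tokens in the perturbed output.
initial_output_token_embs = np.random.rand(5, 768)
perturbed_output_token_embs = np.random.rand(3, 768)

# Compare token i of the original output with token i of the perturbed
# output, but only over the overlap of the two sequences.
shared_length = min(
    initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0]
)
token_similarities_shared = cosine_similarity(
    initial_output_token_embs[:shared_length],
    perturbed_output_token_embs[:shared_length],
).diagonal()

# Zero-pad so the result always has one entry per original output token,
# which keeps the printed attribution matrix rectangular.
token_similarities = np.pad(
    token_similarities_shared,
    (0, initial_output_token_embs.shape[0] - shared_length),
)
print(token_similarities.shape)  # (5,)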

attribution/experiment_logger.py (+11 -1)
@@ -15,6 +15,7 @@ def __init__(self, experiment_id=0):
                 "original_input",
                 "original_output",
                 "perturbation_strategy",
+                "perturb_word_wise",
                 "duration",
             ]
         )
@@ -51,7 +52,11 @@ def __init__(self, experiment_id=0):
         )
 
     def start_experiment(
-        self, original_input: str, original_output: str, perturbation_strategy: str
+        self,
+        original_input: str,
+        original_output: str,
+        perturbation_strategy: str,
+        perturb_word_wise: bool,
     ):
         self.experiment_id += 1
         self.experiment_start_time = time.time()
@@ -60,6 +65,7 @@ def start_experiment(
             "original_input": original_input,
             "original_output": original_output,
             "perturbation_strategy": perturbation_strategy,
+            "perturb_word_wise": perturb_word_wise,
             "duration": None,
         }
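A hypothetical call site for the widened signature; only the parameter list comes from the diff, while the logger class name (assumed from experiment_logger.py) and the argument values are illustrative:

logger = ExperimentLogger()  # class name assumed, not shown in this diff
logger.start_experiment(
    original_input="The quick brown fox",
    original_output="jumps over the lazy dog",
    perturbation_strategy="fixed",
    perturb_word_wise=True,  # newly logged alongside the strategy
)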

@@ -140,11 +146,15 @@ def print_sentence_attribution(self):
             perturbation_strategy = self.df_experiments.loc[
                 self.df_experiments["exp_id"] == exp_id, "perturbation_strategy"
             ].values[0]
+            perturb_word_wise = self.df_experiments.loc[
+                self.df_experiments["exp_id"] == exp_id, "perturb_word_wise"
+            ].values[0]
 
             sentence_data = {
                 "exp_id": exp_id,
                 "attribution_strategy": attr_strat,
                 "perturbation_strategy": perturbation_strategy,
+                "perturb_word_wise": perturb_word_wise,
             }
             sentence_data.update(
                 {f"token_{i+1}": token_attr for i, token_attr in enumerate(token_attrs)}
