
Commit 5e9540f

make agnostic across huggingface models
1 parent e49e4e9 · commit 5e9540f

File tree: 4 files changed, +35 −12 lines changed


README.md (+12 −1)

````diff
@@ -53,9 +53,10 @@ model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
 if not isinstance(model, transformers.GPT2LMHeadModel):
     raise ValueError("model not found")
 
+embeddings = model.transformer.wte.weight.detach()
 model.eval()
 
-attributor = Attributor(model=model, tokenizer=tokenizer)
+attributor = Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
 attr_scores, token_ids = attributor.get_attributions(
     input_string="the five continents are asia, europe, afri",
     generation_length=7,
@@ -71,6 +72,16 @@ attributor.print_attributions(
 
 You can run this script with `example.py`.
 
+### Limitations
+
+This library only supports models that expose a common interface for passing in embeddings and generating outputs without sampling, of the form:
+
+```python
+outputs = model(inputs_embeds=input_embeddings)
+```
+
+This format is common across HuggingFace models.
+
 ### GPU Acceleration
 
 To run the attribution process on a device of your choice, pass the device identifier into the `Attributor` class constructor:
````
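With the embedding matrix now passed in explicitly, the same wiring should carry over to other HuggingFace causal LMs that accept `inputs_embeds`. A minimal sketch, assuming a checkpoint that exposes the standard `get_input_embeddings()` accessor (the `distilgpt2` name and the `attribution.attribution` import path are illustrative assumptions, not part of this commit):

```python
import transformers

from attribution.attribution import Attributor  # assumed import path

# Illustrative checkpoint; any causal LM accepting inputs_embeds should fit.
model = transformers.AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("distilgpt2")

# Architecture-agnostic equivalent of model.transformer.wte.weight.detach():
# HuggingFace PreTrainedModel subclasses expose get_input_embeddings().
embeddings = model.get_input_embeddings().weight.detach()
model.eval()

attributor = Attributor(
    model=model,
    embeddings=embeddings,
    tokenizer=tokenizer,
    # device="cuda",  # optional, per the GPU Acceleration section
)
```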

attribution/attribution.py (+14 −8)

```diff
@@ -1,32 +1,39 @@
 import gc
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, cast
 import torch
+from torch import nn
 import transformers
 from attribution.visualization import RichTablePrinter
 
 
 class Attributor:
     device: str
-    model: transformers.GPT2LMHeadModel
+    model: nn.Module
     tokenizer: transformers.PreTrainedTokenizerBase
+    embeddings: torch.Tensor
 
     def __init__(
         self,
-        model: transformers.GPT2LMHeadModel,
+        model: nn.Module,
         tokenizer: transformers.PreTrainedTokenizerBase,
+        embeddings: torch.Tensor,
         device: Optional[str] = None,
         log_level: int = logging.WARNING,
     ):
         logging.basicConfig(level=log_level)
 
         if device is None:
-            device = model.device.type
+            if model.device:
+                device = cast(str, model.device.type)
+            else:
+                device = "cpu"
 
         logging.info(f"Using device: {device}")
         self.device = device
 
         self.model = model
+        self.embeddings = embeddings
         self.tokenizer = tokenizer
 
     def get_attributions(
@@ -58,13 +65,12 @@ def get_attributions(
         token_ids: torch.Tensor = torch.tensor(
             self.tokenizer(input_string).input_ids
         ).to(self.device)
-        embeddings: torch.Tensor = self.model.transformer.wte.weight.detach()
         input_length: int = token_ids.shape[0]
 
         attr_scores = torch.zeros(generation_length, generation_length + len(token_ids))
 
         for it in range(generation_length):
-            input_embeddings = self._get_input_embeddings(embeddings, token_ids)
+            input_embeddings = self._get_input_embeddings(self.embeddings, token_ids)
             output = self.model(inputs_embeds=input_embeddings)
 
             gen_tokens, next_token_id = self._generate_tokens(
@@ -126,7 +132,7 @@ def _get_input_embeddings(
 
     def _generate_tokens(
         self,
-        model: transformers.GPT2LMHeadModel,
+        model: nn.Module,
         token_ids: torch.Tensor,
         tokenizer: transformers.PreTrainedTokenizerBase,
     ):
@@ -166,7 +172,7 @@ def _get_attr_scores_next_token(
 
     def _validate_inputs(
         self,
-        model: transformers.GPT2LMHeadModel,
+        model: nn.Module,
         tokenizer: transformers.PreTrainedTokenizerBase,
         input_string: str,
         generation_length: int,
```
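The loop in `get_attributions` is what keeps the class model-agnostic: it never touches architecture internals, only gathers rows of the stored embedding matrix and feeds them through `inputs_embeds`. A rough sketch of that step, under the assumption that `_get_input_embeddings` performs a plain row gather (its actual body is not shown in this diff):

```python
import torch

def gather_input_embeddings(
    embeddings: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
    # One embedding row per token id, shape (seq_len, hidden_dim).
    # requires_grad_ lets gradients flow back to these inputs so that
    # attribution scores can be computed against them.
    return embeddings[token_ids].clone().requires_grad_(True)

# Any nn.Module that accepts inputs_embeds can then run the forward pass:
# output = model(inputs_embeds=gather_input_embeddings(embeddings, token_ids))
```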

example.py (+2 −1)

```diff
@@ -9,9 +9,10 @@
 if not isinstance(model, transformers.GPT2LMHeadModel):
     raise ValueError("model not found")
 
+embeddings = model.transformer.wte.weight.detach()
 model.eval()
 
-attributor = Attributor(model=model, tokenizer=tokenizer)
+attributor = Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
 attr_scores, token_ids = attributor.get_attributions(
     input_string="the five continents are asia, europe, afri",
     generation_length=7,
```

tests/test_attribution.py (+7 −2)

```diff
@@ -13,6 +13,11 @@ def model():
     return model
 
 
+@pytest.fixture
+def embeddings(model):
+    return model.transformer.wte.weight.detach()
+
+
 @pytest.fixture
 def tokenizer():
     tokenizer = transformers.GPT2Tokenizer.from_pretrained(
@@ -23,8 +28,8 @@ def tokenizer():
 
 
 @pytest.fixture
-def attributor(model, tokenizer):
-    return Attributor(model=model, tokenizer=tokenizer)
+def attributor(model, embeddings, tokenizer):
+    return Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
 
 
 def test_get_input_embeddings(attributor, model, tokenizer):
```
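If the suite were later extended to cover a second architecture, the new `embeddings` fixture composes naturally with a parametrized model fixture. A hypothetical sketch (checkpoint names and fixture names are illustrative, not part of this commit):

```python
import pytest
import transformers

@pytest.fixture(params=["gpt2", "distilgpt2"])  # illustrative checkpoints
def any_model(request):
    model = transformers.AutoModelForCausalLM.from_pretrained(request.param)
    model.eval()
    return model

@pytest.fixture
def any_embeddings(any_model):
    # Architecture-agnostic lookup of the token-embedding matrix.
    return any_model.get_input_embeddings().weight.detach()
```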
