|
1 | 1 | import gc
|
2 | 2 | import logging
|
3 |
| -from typing import Optional, Tuple |
| 3 | +from typing import Optional, Tuple, cast |
4 | 4 | import torch
|
| 5 | +from torch import nn |
5 | 6 | import transformers
|
6 | 7 | from attribution.visualization import RichTablePrinter
|
7 | 8 |
|
8 | 9 |
|
9 | 10 | class Attributor:
|
10 | 11 | device: str
|
11 |
| - model: transformers.GPT2LMHeadModel |
| 12 | + model: nn.Module |
12 | 13 | tokenizer: transformers.PreTrainedTokenizerBase
|
| 14 | + embeddings: torch.Tensor |
13 | 15 |
|
14 | 16 | def __init__(
|
15 | 17 | self,
|
16 |
| - model: transformers.GPT2LMHeadModel, |
| 18 | + model: nn.Module, |
17 | 19 | tokenizer: transformers.PreTrainedTokenizerBase,
|
| 20 | + embeddings: torch.Tensor, |
18 | 21 | device: Optional[str] = None,
|
19 | 22 | log_level: int = logging.WARNING,
|
20 | 23 | ):
|
21 | 24 | logging.basicConfig(level=log_level)
|
22 | 25 |
|
23 | 26 | if device is None:
|
24 |
| - device = model.device.type |
| 27 | +        device_attr = getattr(model, "device", None) |
| 28 | +        device = ( |
| 29 | +            cast(str, device_attr.type) if device_attr is not None else "cpu" |
| 30 | +        ) |
25 | 31 |
|
26 | 32 | logging.info(f"Using device: {device}")
|
27 | 33 | self.device = device
|
28 | 34 |
|
29 | 35 | self.model = model
|
| 36 | + self.embeddings = embeddings |
30 | 37 | self.tokenizer = tokenizer
|
31 | 38 |
|
32 | 39 | def get_attributions(
|
@@ -58,13 +65,12 @@ def get_attributions(
|
58 | 65 | token_ids: torch.Tensor = torch.tensor(
|
59 | 66 | self.tokenizer(input_string).input_ids
|
60 | 67 | ).to(self.device)
|
61 |
| - embeddings: torch.Tensor = self.model.transformer.wte.weight.detach() |
62 | 68 | input_length: int = token_ids.shape[0]
|
63 | 69 |
|
64 | 70 | attr_scores = torch.zeros(generation_length, generation_length + len(token_ids))
|
65 | 71 |
|
66 | 72 | for it in range(generation_length):
|
67 |
| - input_embeddings = self._get_input_embeddings(embeddings, token_ids) |
| 73 | + input_embeddings = self._get_input_embeddings(self.embeddings, token_ids) |
68 | 74 | output = self.model(inputs_embeds=input_embeddings)
|
69 | 75 |
|
70 | 76 | gen_tokens, next_token_id = self._generate_tokens(
|
@@ -126,7 +132,7 @@ def _get_input_embeddings(
|
126 | 132 |
|
127 | 133 | def _generate_tokens(
|
128 | 134 | self,
|
129 |
| - model: transformers.GPT2LMHeadModel, |
| 135 | + model: nn.Module, |
130 | 136 | token_ids: torch.Tensor,
|
131 | 137 | tokenizer: transformers.PreTrainedTokenizerBase,
|
132 | 138 | ):
|
@@ -166,7 +172,7 @@ def _get_attr_scores_next_token(
|
166 | 172 |
|
167 | 173 | def _validate_inputs(
|
168 | 174 | self,
|
169 |
| - model: transformers.GPT2LMHeadModel, |
| 175 | + model: nn.Module, |
170 | 176 | tokenizer: transformers.PreTrainedTokenizerBase,
|
171 | 177 | input_string: str,
|
172 | 178 | generation_length: int,
|
|
0 commit comments