# LLM Attribution Library
The LLM Attribution Library is designed to compute the contribution of each token in a prompt to the generated response of a language model.
Usage examples can be found in the `examples/` folder.
## API Design
The attributors are designed to compute the contribution made by each token in an input string to the tokens generated by a language model.
### BaseLLMAttributor
`BaseLLMAttributor` is an abstract base class that defines the interface for all LLM attributors. It declares the `compute_attributions` method, which must be implemented by any concrete attributor class. This method takes an input text and computes the attribution scores for each token.
```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch


class BaseLLMAttributor(ABC):
    @abstractmethod
    def compute_attributions(
        self, input_text: str, **kwargs
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        pass
```
### LocalLLMAttributor
`LocalLLMAttributor` uses gradient-based attribution to quantify the influence of input tokens on the output of a model. For each output token, it computes the gradients with respect to the input embeddings. The L1 norm of these gradients is then used as the attribution score, representing the total influence of each input token on the output.

The `compute_attributions` method generates tokens from the input string and computes the gradients of the output with respect to the input embeddings. These gradients are used to compute the attribution scores.
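
For illustration, here is a usage sketch adapted from an example that previously appeared in this README (that example used the name `Attributor`; the exact constructor and return values are assumptions, not a documented signature):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from attribution.attribution import Attributor  # import path from the earlier example

model_id = "google/gemma-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Assumed call: returns per-token attribution scores for the generated output.
attributor = Attributor(model=model, tokenizer=tokenizer)
attr_scores, token_ids = attributor.compute_attributions("The capital of France is")
```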
#### Cleaning Up
A convenience method is provided to clean up memory used by Python and Torch. This can be useful when running the library in a cloud notebook environment:
```python
local_attributor.cleanup()
```
### APILLMAttributor
`APILLMAttributor` uses the OpenAI API to compute attributions. Given that gradients are not accessible, the attributor perturbs the input with a given `PerturbationStrategy` and measures the magnitude of change of the generated output with an `attribution_strategy`.

The `compute_attributions` method (see the sketch after this list):

1. Sends a chat completion request to the OpenAI API.
2. Uses a `PerturbationStrategy` to modify the input prompt, and sends the perturbed input to OpenAI's API to generate a perturbed output. Each token of the input prompt is perturbed separately, so that each input token receives its own attribution score.
3. Uses an `attribution_strategy` to compute the magnitude of change between the original and perturbed output.
4. Logs attribution scores to an `ExperimentLogger`, if one is passed.
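
A minimal usage sketch follows; the import paths and keyword names are assumptions inferred from the description above, not a documented API:

```python
# All names here are assumptions based on the surrounding prose; the real
# import paths and signatures may differ.
from attribution.api_attribution import APILLMAttributor
from attribution.experiment_logger import ExperimentLogger

attributor = APILLMAttributor()
logger = ExperimentLogger()

scores = attributor.compute_attributions(
    "The capital of France is",
    # perturbation_strategy=...,  # a PerturbationStrategy instance (see below)
    attribution_strategy="cosine",  # or "prob_diff", "token_displacement"
    logger=logger,  # optional: log the run for later comparison
)
```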
`PerturbationStrategy` is an abstract base class that defines the interface for all perturbation strategies. It declares the `get_replacement_token` method, which must be implemented by any concrete perturbation strategy class. This method takes a token id and returns a replacement token id.
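
From that description, the interface is presumably along these lines (a sketch, not the library's exact code):

```python
from abc import ABC, abstractmethod


class PerturbationStrategy(ABC):
    @abstractmethod
    def get_replacement_token(self, token_id: int) -> int:
        """Return the id of the token that replaces `token_id` in the prompt."""
        ...
```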
The `attribution_strategy` parameter is a string that specifies the method to use for computing attributions. The available strategies are "cosine", "prob_diff", and "token_displacement".

- **Cosine Similarity Attribution**: Measures the cosine similarity between the embeddings of the original and perturbed outputs. The embeddings are obtained from a pre-trained model (e.g. GPT2). The cosine similarity is calculated for each pair of tokens in the same position in the original and perturbed outputs; for example, the token in position 0 of the original response is compared to the token in position 0 of the perturbed response. Additionally, the cosine similarity of the entire sentence embeddings is computed. The difference in sentence similarity and the per-token similarities are returned (a rough sketch follows this list).
- **Probability Difference Attribution**: Calculates the absolute difference in probabilities for each token in the original and perturbed outputs. The probabilities are obtained from the `top_logprobs` field of the tokens, which contains the most likely tokens for each token position. The mean of these differences is returned, as well as the probability difference for each token position.
- **Token Displacement Attribution**: Calculates the displacement of each token in the original output within the perturbed output's `top_logprobs` predicted tokens. The `top_logprobs` field contains the most likely tokens for each token position. If a token from the original output is not found in the `top_logprobs` of the perturbed output, a maximum displacement value is assigned. The mean of these displacements is returned, as well as the displacement of each original output token.
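
As a rough illustration of the cosine strategy's core computation (shapes and mean-pooling here are assumptions; the library's actual implementation may differ):

```python
import torch
import torch.nn.functional as F


def cosine_attribution(
    original: torch.Tensor, perturbed: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """original, perturbed: (seq_len, hidden_dim) output-token embeddings."""
    # Compare tokens at matching positions, truncating to the shorter output.
    n = min(original.shape[0], perturbed.shape[0])
    token_sim = F.cosine_similarity(original[:n], perturbed[:n], dim=-1)
    # Sentence-level similarity from mean-pooled embeddings (pooling assumed).
    sentence_sim = F.cosine_similarity(
        original.mean(dim=0), perturbed.mean(dim=0), dim=0
    )
    # Return distances, so larger values indicate a larger change.
    return 1.0 - sentence_sim, 1.0 - token_sim
```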
### ExperimentLogger
The `ExperimentLogger` class is used to log the results of different experiment runs. It provides methods for starting and stopping an experiment, logging the input and output tokens, and logging the attribution scores. The `api_llm_attribution.ipynb` notebook shows an example of how to use `ExperimentLogger` to compare the results of different attribution strategies.
## Limitations
### Batch dimensions
Currently, this library only supports models that take inputs with a batch dimension. This is common across most modern models, but not always the case (e.g. GPT2).
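
To make the assumption concrete, a quick check with a standard Hugging Face tokenizer (model name reused from the example above):

```python
from transformers import AutoTokenizer

# The tokenizer returns tensors with a leading batch dimension, which is
# the input shape this library expects the model to accept.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
inputs = tokenizer("Attribute this prompt", return_tensors="pt")
print(inputs.input_ids.shape)  # torch.Size([1, seq_len])
```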
### Input Embeddings
This library only supports models that provide a common interface for passing in embeddings and generating outputs without sampling, of the form:
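
A minimal sketch of the assumed interface, following the Hugging Face `inputs_embeds` convention and reusing `model` and `inputs` from the snippets above (the concrete call pattern is an assumption):

```python
# Assumed interface: the model accepts precomputed input embeddings and
# returns logits deterministically (no sampling involved).
embeddings = model.get_input_embeddings()(inputs.input_ids)
outputs = model(inputs_embeds=embeddings)
logits = outputs.logits
```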
To run the attribution process on a device of your choice, pass the device identifier into the `Attributor` class constructor:
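
For example (the `device` argument name is an assumption based on this paragraph; the class name `Attributor` follows the text):

```python
attributor = Attributor(
    model=model,
    tokenizer=tokenizer,
    device="cuda:0",  # assumed parameter name; any valid torch device identifier
)
```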
The device identifier must match the device used on the first embeddings layer.
If no device is specified, the model device will be used by default.
## Development
To run the integration tests:
```bash
python -m pytest tests/integration
```
## Research
Some preliminary exploration and research into using attribution, and into quantitatively measuring attribution success, can be found in the `research` folder of this repository. We'd be excited to see this small library expand, including algorithmic improvements, further attribution and perturbation methods, and more rigorous and exhaustive experimentation. We welcome pull requests and issues from external collaborators.