Skip to content

Commit b2c1201

Browse files
author
Sebastian Sosa
committed
Fix unit tests
1 parent ef209a3 commit b2c1201

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/integration/test_attribution.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import transformers
44

5-
from attribution.local_attribution import Attributor
5+
from attribution.local_attribution import LocalLLMAttributor
66

77

88
@pytest.fixture
@@ -28,12 +28,12 @@ def tokenizer():
2828

2929
@pytest.fixture
3030
def attributor(model, embeddings, tokenizer):
31-
return Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
31+
return LocalLLMAttributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
3232

3333

3434
def test_attribution(model, embeddings, tokenizer):
35-
attributor = Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
36-
attr_scores, token_ids = attributor.get_attributions(
35+
attributor = LocalLLMAttributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
36+
attr_scores, token_ids = attributor.compute_attributions(
3737
input_string="the five continents are asia, europe, afri",
3838
generation_length=7,
3939
)

tests/unit/test_attribution.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch import nn
66

7-
from attribution.local_attribution import Attributor
7+
from attribution.local_attribution import LocalLLMAttributor
88

99

1010
@pytest.fixture
@@ -20,7 +20,7 @@ def attributor():
2020
tokenizer = Mock()
2121
tokenizer.encode.return_value = [0, 1, 2, 3]
2222

23-
return Attributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
23+
return LocalLLMAttributor(model=model, embeddings=embeddings, tokenizer=tokenizer)
2424

2525

2626
def test_get_input_embeddings(attributor):

0 commit comments

Comments
 (0)