
Commit 209a0ed

Merge pull request #773 from dianna-ai/fix-lime-text-special-chars

Run text model inputs one-by-one through the model to avoid shape mismatch errors

2 parents 780f1c3 + 2fff22f

File tree: 5 files changed, +56 -55 lines

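For context: the old runners tokenized all sentences up front and passed them to the model as a single batch, but LIME's perturbations of a text with special characters can tokenize to different lengths, so stacking them fails. A minimal sketch of the failure mode and the fix (the token ids are toy values, not from the repo):

    import numpy as np

    # Two perturbed sentences that tokenize to different lengths: stacking
    # them into one batch gives a ragged array, which a model input with a
    # fixed sequence dimension cannot accept.
    batch = [[4, 17, 9], [4, 17]]
    try:
        np.array(batch, dtype=np.int64)
    except ValueError as err:
        print('batched:', err)  # inhomogeneous shape

    # Running each sentence as its own batch of one avoids the mismatch:
    for tokens in batch:
        single = np.array([tokens], dtype=np.int64)
        print('one-by-one:', single.shape)  # (1, 3) then (1, 2)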

dianna/dashboard/_movie_model.py (+17, -16)
@@ -26,25 +26,26 @@ def __call__(self, sentences):
         if isinstance(sentences, str):
             sentences = [sentences]
 
-        tokenized_sentences = [
-            self.tokenize(sentence) for sentence in sentences
-        ]
+        output = []
+        for sentence in sentences:
+            # tokenize and pad to minimum length
+            tokens = self.tokenizer.tokenize(sentence.lower())
+            if len(tokens) < self.max_filter_size:
+                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))
 
-        expected_length = len(tokenized_sentences[0])
-        if not all(
-                len(tokens) == expected_length
-                for tokens in tokenized_sentences):
-            raise ValueError(
-                'Mismatch in length of tokenized sentences.'
-                'This is a problem in the tokenizer:'
-                'https://github.com/dianna-ai/dianna/issues/531', )
+            # numericalize the tokens
+            tokens_numerical = [
+                self.vocab.stoi[token]
+                if token in self.vocab.stoi else self.vocab.stoi['<unk>']
+                for token in tokens
+            ]
 
-        # run the model, applying a sigmoid because the model outputs logits
-        logits = self.run_model(tokenized_sentences)
-        pred = np.apply_along_axis(sigmoid, 1, logits)
+            # run the model, applying a sigmoid because the model outputs logits, remove any remaining batch axis
+            pred = float(sigmoid(self.run_model([tokens_numerical])))
+            output.append(pred)
 
-        # output pos/neg
-        positivity = pred[:, 0]
+        # output two classes
+        positivity = np.array(output)
         negativity = 1 - positivity
         return np.transpose([negativity, positivity])
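A hypothetical call to the reworked runner, assuming the dashboard class mirrors the notebook's MovieReviewsModelRunner(model, word_vectors, max_filter_size) signature shown further down; the file paths are placeholders, not files shipped with this commit:

    # sketch only: model/vector paths are placeholders
    runner = MovieReviewsModelRunner('movie_review_model.onnx',
                                     'word_vectors.txt', max_filter_size=5)
    scores = runner(['such a bad movie "!?\'"', 'a great movie'])
    print(scores.shape)  # (2, 2): one [negativity, positivity] row per sentence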

tests/methods/test_lime_text.py (+5, -5)
@@ -31,11 +31,11 @@ def test_lime_text(self):
     def test_lime_text_special_chars(self):
         """Tests exact expected output given a text with special characters and model for Lime."""
         review = 'such a bad movie "!?\'"'
-        expected_words = ['bad', '?', '!', 'movie', 'such', 'a', "'", '"', '"']
-        expected_word_indices = [2, 6, 5, 3, 0, 1, 7, 4, 8]
+        expected_words = ['bad', 'movie', '?', 'such', '!', "'", '"', 'a', '"']
+        expected_word_indices = [2, 3, 6, 0, 5, 7, 8, 1, 4]
         expected_scores = [
-            0.50032869, 0.06458735, -0.05793979, 0.01413776, -0.01246357,
-            -0.00528022, 0.00305347, 0.00185159, -0.00165128
+            0.51140699, 0.02827488, 0.02657974, -0.02208464, -0.02140743,
+            0.00962419, 0.00746798, -0.00743376, -0.0012061
         ]
 
         explanation = dianna.explain_text(self.runner,
@@ -44,7 +44,7 @@ def test_lime_text_special_chars(self):
                                           labels=[0],
                                           method='LIME',
                                           random_state=42)[0]
-
+        print(explanation)
         assert_explanation_satisfies_expectations(explanation, expected_scores,
                                                   expected_word_indices,
                                                   expected_words)
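For readers outside the repo, one plausible shape of the assert_explanation_satisfies_expectations helper, assuming the explanation is a list of (word, index, score) tuples as dianna's text explainers produce; the real helper lives in tests/utils.py and may differ:

    import numpy as np

    def assert_explanation_satisfies_expectations(explanation, expected_scores,
                                                  expected_word_indices,
                                                  expected_words):
        # assumes each entry is a (word, word_index, score) tuple
        words, indices, scores = zip(*explanation)
        assert list(words) == expected_words
        assert list(indices) == expected_word_indices
        np.testing.assert_allclose(scores, expected_scores, atol=1e-6)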

tests/utils.py (+23, -21)
@@ -80,7 +80,7 @@ def __init__(self, model_path, word_vector_file, max_filter_size):
         self.max_filter_size = max_filter_size
 
     def __call__(self, sentences):
-        """Call function."""
+        """Call Runner."""
         # ensure the input has a batch axis
         if isinstance(sentences, str):
             sentences = [sentences]
@@ -89,26 +89,28 @@ def __call__(self, sentences):
         input_name = sess.get_inputs()[0].name
         output_name = sess.get_outputs()[0].name
 
-        tokenized_sentences = [
-            self.tokenize(sentence) for sentence in sentences
-        ]
-
-        expected_length = len(tokenized_sentences[0])
-        if not all(
-                len(tokens) == expected_length
-                for tokens in tokenized_sentences):
-            raise ValueError(
-                'Mismatch in length of tokenized sentences.'
-                'This is a problem in the tokenizer:'
-                'https://github.com/dianna-ai/dianna/issues/531', )
-
-        # run the model, applying a sigmoid because the model outputs logits
-        onnx_input = {input_name: tokenized_sentences}
-        logits = sess.run([output_name], onnx_input)[0]
-        pred = np.apply_along_axis(sigmoid, 1, logits)
-
-        # output pos/neg
-        positivity = pred[:, 0]
+        output = []
+        for sentence in sentences:
+            # tokenize and pad to minimum length
+            tokens = self.tokenizer.tokenize(sentence.lower())
+            if len(tokens) < self.max_filter_size:
+                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))
+
+            # numericalize the tokens
+            tokens_numerical = [
+                self.vocab.stoi[token]
+                if token in self.vocab.stoi else self.vocab.stoi['<unk>']
+                for token in tokens
+            ]
+
+            # run the model, applying a sigmoid because the model outputs logits, remove any remaining batch axis
+            onnx_input = {input_name: [tokens_numerical]}
+            logits = sess.run([output_name], onnx_input)[0]
+            pred = float(sigmoid(logits))
+            output.append(pred)
+
+        # output two classes
+        positivity = np.array(output)
         negativity = 1 - positivity
         return np.transpose([negativity, positivity])
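The same one-by-one pattern as a self-contained onnxruntime sketch; 'sentiment.onnx' and the token ids are placeholders:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('sentiment.onnx')  # placeholder path
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    # already tokenized, padded, and numericalized; lengths may differ
    sentences_numerical = [[4, 17, 9, 2, 0], [4, 17, 0, 0, 0, 0, 0]]
    preds = []
    for tokens in sentences_numerical:
        # a batch of one: shape (1, len(tokens)), so each sentence keeps
        # its own length without creating a ragged batch
        onnx_input = {input_name: np.array([tokens], dtype=np.int64)}
        logits = sess.run([output_name], onnx_input)[0]
        preds.append(float(1.0 / (1.0 + np.exp(-logits))))  # sigmoid
    print(preds)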

tutorials/explainers/LIME/lime_text.ipynb (+10, -12)
@@ -187,18 +187,18 @@
    "source": [
     "class MovieReviewsModelRunner:\n",
     "    def __init__(self, model, word_vectors, max_filter_size):\n",
-    "        self.run_model = utils.get_function(str(model))\n",
+    "        self.run_model = utils.get_function(model)\n",
     "        self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))\n",
     "        self.max_filter_size = max_filter_size\n",
     "        \n",
-    "        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')\n",
+    "        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')\n",
     "\n",
     "    def __call__(self, sentences):\n",
     "        # ensure the input has a batch axis\n",
     "        if isinstance(sentences, str):\n",
     "            sentences = [sentences]\n",
     "\n",
-    "        tokenized_sentences = []\n",
+    "        output = []\n",
     "        for sentence in sentences:\n",
     "            # tokenize and pad to minimum length\n",
     "            tokens = self.tokenizer.tokenize(sentence.lower())\n",
@@ -208,17 +208,15 @@
     "            # numericalize the tokens\n",
     "            tokens_numerical = [self.vocab.stoi[token] if token in self.vocab.stoi else self.vocab.stoi['<unk>']\n",
     "                                for token in tokens]\n",
-    "            tokenized_sentences.append(tokens_numerical)\n",
-    "        \n",
-    "        # run the model, applying a sigmoid because the model outputs logits\n",
-    "        logits = self.run_model(tokenized_sentences)\n",
-    "        pred = np.apply_along_axis(sigmoid, 1, logits)\n",
-    "        \n",
+    "\n",
+    "            # run the model, applying a sigmoid because the model outputs logits, remove any remaining batch axis\n",
+    "            pred = float(sigmoid(self.run_model([tokens_numerical])))\n",
+    "            output.append(pred)\n",
+    "\n",
     "        # output two classes\n",
-    "        positivity = pred[:, 0]\n",
+    "        positivity = np.array(output)\n",
     "        negativity = 1 - positivity\n",
-    "        return np.transpose([negativity, positivity])\n",
-    "        "
+    "        return np.transpose([negativity, positivity]) "
    ]
   },
   {
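The tokenize/pad/numericalize steps in isolation, with whitespace tokenization and a toy vocabulary standing in for the notebook's spaCy tokenizer and torchtext Vectors.stoi mapping:

    stoi = {'<pad>': 0, '<unk>': 1, 'such': 2, 'a': 3, 'bad': 4, 'movie': 5}
    max_filter_size = 5

    tokens = 'such a bad movie'.lower().split()
    if len(tokens) < max_filter_size:
        tokens += ['<pad>'] * (max_filter_size - len(tokens))
    tokens_numerical = [stoi[t] if t in stoi else stoi['<unk>'] for t in tokens]
    print(tokens_numerical)  # [2, 3, 4, 5, 0]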

tutorials/explainers/RISE/rise_text.ipynb (+1, -1)
@@ -169,7 +169,7 @@
     "        output = []\n",
     "        for sentence in sentences:\n",
     "            # tokenize and pad to minimum length\n",
-    "            tokens = self.tokenizer.tokenize(sentence)\n",
+    "            tokens = self.tokenizer.tokenize(sentence.lower())\n",
     "            if len(tokens) < self.max_filter_size:\n",
     "                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))\n",
     "            \n",
