Skip to content

Commit 1a0020b

Browse files
authored
Merge pull request #327 from dianna-ai/fix-text-dashboard
Fix text visualization in dashboard
2 parents 1186563 + 0aff6c3 commit 1a0020b

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

dashboard/callbacks.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
folder_on_server = "app_data"
3030
os.makedirs(folder_on_server, exist_ok=True)
31+
tokenizer = SpacyTokenizer() # for now always use SpacyTokenizer, needs to be changed
3132

3233
# Build App
3334
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
@@ -376,7 +377,6 @@ def global_store_t(method_sel, model_runner, input_text):
376377
labels = tuple(class_name_text)
377378
pred_idx = labels.index(pred_class)
378379

379-
tokenizer = SpacyTokenizer() # for now always use SpacyTokenizer, needs to be changed
380380

381381
# expensive query
382382
relevances = dianna.explain_text(
@@ -450,6 +450,7 @@ def update_multi_options_t(fn_m, input_text, sel_methods, new_model, new_text):
450450
model_runner = MovieReviewsModelRunner(onnx_model_path, word_vector_path, max_filter_size=5)
451451

452452
try:
453+
input_tokens = tokenizer.tokenize(input_text)
453454
predictions = model_runner(input_text)
454455
class_name = class_name_text
455456
pred_class = class_name[np.argmax(predictions)]
@@ -463,7 +464,7 @@ def update_multi_options_t(fn_m, input_text, sel_methods, new_model, new_text):
463464
relevances_lime = global_store_t(
464465
m, model_runner, input_text)
465466

466-
output = _create_html(input_text, relevances_lime[0], max_opacity=0.8)
467+
output = _create_html(input_tokens, relevances_lime[0], max_opacity=0.8)
467468
hti = Html2Image()
468469
expl_path = 'text_expl.jpg'
469470

@@ -493,7 +494,7 @@ def update_multi_options_t(fn_m, input_text, sel_methods, new_model, new_text):
493494
relevances_rise = global_store_t(
494495
m, model_runner, input_text)
495496

496-
output = _create_html(input_text, relevances_rise[0], max_opacity=0.8)
497+
output = _create_html(input_tokens, relevances_rise[0], max_opacity=0.8)
497498
hti = Html2Image()
498499
expl_path = 'text_expl.jpg'
499500

dashboard/utilities.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,24 @@ def preprocess_function(image):
117117
"""For LIME: we divided the input data by 256 for the model (binary mnist) and LIME needs RGB values."""
118118
return (image / 256).astype(np.float32)
119119

120-
121-
def _create_html(original_text, explanation, max_opacity):
120+
def _create_html(input_tokens, explanation, max_opacity):
122121
"""Creates text explanation map using html format."""
123122
max_importance = max(abs(item[2]) for item in explanation)
124-
body = original_text
125-
words_in_reverse_order = sorted(explanation, key=lambda item: item[1], reverse=True)
126-
for word, word_start, importance in words_in_reverse_order:
127-
word_end = word_start + len(word)
128-
highlighted_word = _highlight_word(word, importance, max_importance, max_opacity)
129-
body = body[:word_start] + highlighted_word + body[word_end:]
130-
return '<html><body>' + body + '</body></html>'
123+
explained_indices = [index for _, index, _ in explanation]
124+
highlighted_words = []
125+
for index, word in enumerate(input_tokens):
126+
# if word has an explanation, highlight based on that, otherwise
127+
# make it grey
128+
try:
129+
explained_index = explained_indices.index(index)
130+
importance = explanation[explained_index][2]
131+
highlighted_words.append(
132+
_highlight_word(word, importance, max_importance, max_opacity)
133+
)
134+
except ValueError:
135+
highlighted_words.append(f'<span style="background:rgba(128, 128, 128, 0.3)">{word}</span>')
136+
137+
return '<html><body>' + ' '.join(highlighted_words) + '</body></html>'
131138

132139

133140
def _highlight_word(word, importance, max_importance, max_opacity):

0 commit comments

Comments
 (0)