Skip to content

Commit a77e66b

Browse files
authored
Merge pull request #799 from dianna-ai/738-text-lime-in-dashboard-returns-an-error
738 text lime in dashboard returns an error
2 parents ae3612a + 82c0028 commit a77e66b

File tree

4 files changed

+34
-17
lines changed

4 files changed

+34
-17
lines changed

dianna/dashboard/_models_image.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import tempfile
22
import streamlit as st
3-
from _model_utils import fill_segmentation
43
from _model_utils import preprocess_function
54
from onnx_tf.backend import prepare
65
from dianna import explain_image
@@ -42,12 +41,11 @@ def _run_kernelshap_image(model, image, i, **kwargs):
4241
with tempfile.NamedTemporaryFile() as f:
4342
f.write(model)
4443
f.flush()
45-
shap_values, segments_slic = explain_image(f.name,
46-
image,
47-
method='KernelSHAP',
48-
**kwargs)
49-
50-
return fill_segmentation(shap_values[i][0], segments_slic)
44+
relevances = explain_image(f.name,
45+
image,
46+
method='KernelSHAP',
47+
**kwargs)
48+
return relevances[0]
5149

5250

5351
explain_image_dispatcher = {

dianna/dashboard/_shared.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _get_params(method: str):
9090

9191
elif method == 'LIME':
9292
return {
93-
'rand_state': st.number_input('Random state', value=2),
93+
'random_state': st.number_input('Random state', value=2),
9494
}
9595

9696
else:

dianna/methods/lime_timeseries.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
verbose=False,
2424
preprocess_function=None,
2525
feature_selection='auto',
26+
random_state = None
2627
):
2728
"""Initializes Lime explainer for timeseries.
2829
@@ -32,6 +33,7 @@ def __init__(
3233
feature_selection (str): Feature selection method to be used by explainer.
3334
preprocess_function (callable, optional): Function to preprocess the time series data before passing it
3435
to the explainer. Defaults to None.
36+
random_state (int or np.RandomState, optional): seed or random state. Unused variable for current ts method
3537
"""
3638

3739
def kernel(d):

tests/test_dashboard.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,24 @@ def test_text_page(page: Page):
103103
expect(page.get_by_text('Select a method to continue')).to_be_visible()
104104

105105
page.locator('label').filter(has_text='RISE').locator('span').click()
106+
page.locator('label').filter(has_text='LIME').locator('span').click()
106107

107-
page.get_by_text('Running...').wait_for(state='detached', timeout=45_000)
108+
page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
108109

109110
for selector in (
110111
page.get_by_role('heading', name='RISE').get_by_text('RISE'),
111-
# first image
112+
page.get_by_role('heading', name='LIME').get_by_text('LIME'),
113+
# Images for positive (RISE/LIME)
112114
page.get_by_role('heading',
113115
name='positive').get_by_text('positive'),
114116
page.get_by_role('img', name='0').first,
115-
# second image
117+
page.get_by_role('img', name='0').nth(1),
118+
119+
# Images for negative (RISE/LIME)
116120
page.get_by_role('heading',
117121
name='negative').get_by_text('negative'),
118-
page.get_by_role('img', name='0').nth(1),
122+
page.get_by_role('img', name='0').nth(2),
123+
page.get_by_role('img', name='0').nth(3),
119124
):
120125
print(selector)
121126
expect(selector).to_be_visible()
@@ -131,27 +136,35 @@ def test_image_page(page: Page):
131136

132137
expect(
133138
page.get_by_text('Add your input data in the left panel to continue')
134-
).to_be_visible(timeout=30_000)
139+
).to_be_visible(timeout=100_000)
135140

136141
page.locator('label').filter(
137142
has_text='Load example data').locator('span').click()
138143

139144
expect(page.get_by_text('Select a method to continue')).to_be_visible()
140145

141146
page.locator('label').filter(has_text='RISE').locator('span').click()
147+
page.locator('label').filter(has_text='KernelSHAP').locator('span').click()
148+
page.locator('label').filter(has_text='LIME').locator('span').click()
142149

143150
page.get_by_text('Running...').wait_for(state='detached', timeout=45_000)
144151

145152
for selector in (
146153
page.get_by_role('heading', name='RISE').get_by_text('RISE'),
154+
page.get_by_role('heading', name='KernelSHAP').get_by_text('KernelSHAP'),
155+
page.get_by_role('heading', name='LIME').get_by_text('LIME'),
147156
# first image
148157
page.get_by_role('heading', name='0').get_by_text('0'),
149158
page.get_by_role('img', name='0').first,
159+
page.get_by_role('img', name='0').nth(1),
160+
page.get_by_role('img', name='0').nth(2),
150161
# second image
151162
page.get_by_role('heading', name='1').get_by_text('1'),
152-
page.get_by_role('img', name='0').nth(1),
163+
page.get_by_role('img', name='0').nth(3),
164+
page.get_by_role('img', name='0').nth(4),
165+
page.get_by_role('img', name='0').nth(5),
153166
):
154-
expect(selector).to_be_visible(timeout=45_000)
167+
expect(selector).to_be_visible(timeout=100_000)
155168

156169

157170
def test_timeseries_page(page: Page):
@@ -171,17 +184,21 @@ def test_timeseries_page(page: Page):
171184

172185
expect(page.get_by_text('Select a method to continue')).to_be_visible()
173186

187+
page.locator('label').filter(has_text='LIME').locator('span').click()
174188
page.locator('label').filter(has_text='RISE').locator('span').click()
175189

176-
page.get_by_text('Running...').wait_for(state='detached', timeout=45_000)
190+
page.get_by_text('Running...').wait_for(state='detached', timeout=100_000)
177191

178192
for selector in (
193+
page.get_by_role('heading', name='LIME').get_by_text('LIME'),
179194
page.get_by_role('heading', name='RISE').get_by_text('RISE'),
180195
# first image
181196
page.get_by_role('heading', name='winter').get_by_text('winter'),
182197
page.get_by_role('img', name='0').first,
198+
page.get_by_role('img', name='0').nth(1),
183199
# second image
184200
page.get_by_role('heading', name='summer').get_by_text('summer'),
185-
page.get_by_role('img', name='0').nth(1),
201+
page.get_by_role('img', name='0').nth(2),
202+
page.get_by_role('img', name='0').nth(3),
186203
):
187204
expect(selector).to_be_visible()

0 commit comments

Comments
 (0)