Skip to content

Commit 45fd4ff

Browse files
authored
Merge pull request #844 from dianna-ai/838-create-tabular-tab-to-dashboard
838 create tabular tab to dashboard and redesign loaded data results #819
2 parents 47dfd99 + f38fd09 commit 45fd4ff

10 files changed

+357
-99
lines changed

dianna/dashboard/Home.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import importlib
22
import streamlit as st
3-
from _shared import add_sidebar_logo
43
from _shared import data_directory
54
from streamlit_option_menu import option_menu
65

76
st.set_page_config(page_title="Dianna's dashboard",
87
page_icon='📊',
9-
layout='centered',
8+
layout='wide',
109
initial_sidebar_state='auto',
1110
menu_items={
1211
'Get help':
@@ -22,6 +21,7 @@
2221
pages = {
2322
"Home": "home",
2423
"Images": "pages.Images",
24+
"Tabular": "pages.Tabular",
2525
"Text": "pages.Text",
2626
"Time series": "pages.Time_series"
2727
}
@@ -30,16 +30,14 @@
3030
selected = option_menu(
3131
menu_title=None,
3232
options=list(pages.keys()),
33-
icons=["house", "camera", "alphabet", "clock"],
33+
icons=["house", "camera", "table", "alphabet", "clock"],
3434
menu_icon="cast",
3535
default_index=0,
3636
orientation="horizontal"
3737
)
3838

3939
# Display the content of the selected page
4040
if selected == "Home":
41-
add_sidebar_logo()
42-
4341
st.image(str(data_directory / 'logo.png'))
4442

4543
st.markdown("""
@@ -50,9 +48,10 @@
5048
5149
### Pages
5250
53-
- <a href="/Images" target="_parent">Images</a>
54-
- <a href="/Text" target="_parent">Text</a>
55-
- <a href="/Time_series" target="_parent">Time series</a>
51+
- <a href="/Images" target="_parent">Image data</a>
52+
- <a href="/Tabular" target="_parent">Tabular data</a>
53+
- <a href="/Text" target="_parent">Text data</a>
54+
- <a href="/Time_series" target="_parent">Time series data</a>
5655
5756
5857
### More information
@@ -70,6 +69,10 @@
7069
for k in st.session_state.keys():
7170
if 'Image' in k:
7271
st.session_state.pop(k, None)
72+
if selected != 'Tabular':
73+
for k in st.session_state.keys():
74+
if 'Tabular' in k:
75+
st.session_state.pop(k, None)
7376
if selected != 'Text':
7477
for k in st.session_state.keys():
7578
if 'Text' in k:

dianna/dashboard/_model_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from pathlib import Path
22
import numpy as np
33
import onnx
4+
import pandas as pd
5+
6+
7+
def load_data(file):
8+
"""Open data from a file and returns it as pandas DataFrame."""
9+
df = pd.read_csv(file, parse_dates=True)
10+
# Add index column
11+
df.insert(0, 'Index', df.index)
12+
return df
413

514

615
def preprocess_function(image):
@@ -29,3 +38,7 @@ def load_labels(file):
2938
if labels is None or labels == ['']:
3039
raise ValueError(labels)
3140
return labels
41+
42+
43+
def load_training_data(file):
44+
return np.float32(np.load(file, allow_pickle=False))

dianna/dashboard/_models_tabular.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import tempfile
2+
import numpy as np
3+
import streamlit as st
4+
from dianna import explain_tabular
5+
from dianna.utils.onnx_runner import SimpleModelRunner
6+
7+
8+
@st.cache_data
9+
def predict(*, model, tabular_input):
10+
model_runner = SimpleModelRunner(model)
11+
predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
12+
return predictions
13+
14+
15+
@st.cache_data
16+
def _run_rise_tabular(_model, table, training_data, **kwargs):
17+
relevances = explain_tabular(
18+
_model,
19+
table,
20+
method='RISE',
21+
training_data=training_data,
22+
**kwargs,
23+
)
24+
return relevances
25+
26+
27+
@st.cache_data
28+
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
29+
relevances = explain_tabular(
30+
_model,
31+
table,
32+
method='LIME',
33+
training_data=training_data,
34+
feature_names=_feature_names,
35+
**kwargs,
36+
)
37+
return relevances
38+
39+
@st.cache_data
40+
def _run_kernelshap_tabular(model, table, training_data, **kwargs):
41+
# Kernelshap interface is different. Write model to temporary file.
42+
with tempfile.NamedTemporaryFile() as f:
43+
f.write(model)
44+
f.flush()
45+
relevances = explain_tabular(f.name,
46+
table,
47+
method='KernelSHAP',
48+
training_data=training_data,
49+
**kwargs)
50+
return relevances[0]
51+
52+
53+
explain_tabular_dispatcher = {
54+
'RISE': _run_rise_tabular,
55+
'LIME': _run_lime_tabular,
56+
'KernelSHAP': _run_kernelshap_tabular
57+
}

dianna/dashboard/_shared.py

+48-48
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import base64
22
import sys
3-
from typing import Any
4-
from typing import Dict
53
from typing import Sequence
64
import numpy as np
75
import streamlit as st
@@ -46,71 +44,67 @@ def build_markup_for_logo(
4644

4745

4846
def add_sidebar_logo():
49-
"""Based on: https://stackoverflow.com/a/73278825."""
50-
png_file = data_directory / 'logo.png'
51-
logo_markup = build_markup_for_logo(png_file)
52-
st.markdown(
53-
logo_markup,
54-
unsafe_allow_html=True,
55-
)
47+
"""Upload DIANNA logo to sidebar element."""
48+
st.sidebar.image(str(data_directory / 'logo.png'))
5649

5750

5851
def _methods_checkboxes(*, choices: Sequence, key):
59-
"""Get methods from a horizontal row of checkboxes."""
52+
"""Get methods from a horizontal row of checkboxes and the corresponding parameters."""
6053
n_choices = len(choices)
6154
methods = []
55+
method_params = {}
56+
57+
# Create a container for the message
58+
message_container = st.empty()
59+
6260
for col, method in zip(st.columns(n_choices), choices):
6361
with col:
64-
if st.checkbox(method, key=key + method):
62+
if st.checkbox(method, key=f'{key}_{method}'):
6563
methods.append(method)
64+
with st.expander(f'Click to modify {method} parameters'):
65+
method_params[method] = _get_params(method, key=f'{key}_param')
6666

6767
if not methods:
68-
st.info('Select a method to continue')
68+
# Put the message in the container above
69+
message_container.info('Select a method to continue')
6970
st.stop()
7071

71-
return methods
72+
return methods, method_params
7273

7374

7475
def _get_params(method: str, key):
7576
if method == 'RISE':
7677
return {
7778
'n_masks':
78-
st.number_input('Number of masks', value=1000, key=key + method + 'nmasks'),
79+
st.number_input('Number of masks', value=1000, key=f'{key}_{method}_nmasks'),
7980
'feature_res':
80-
st.number_input('Feature resolution', value=6, key=key + method + 'fr'),
81+
st.number_input('Feature resolution', value=6, key=f'{key}_{method}_fr'),
8182
'p_keep':
82-
st.number_input('Probability to be kept unmasked', value=0.1, key=key + method + 'pkeep'),
83+
st.number_input('Probability to be kept unmasked', value=0.1, key=f'{key}_{method}_pkeep'),
8384
}
8485

8586
elif method == 'KernelSHAP':
86-
return {
87-
'nsamples': st.number_input('Number of samples', value=1000, key=key + method + 'nsamp'),
88-
'background': st.number_input('Background', value=0, key=key + method + 'background'),
89-
'n_segments': st.number_input('Number of segments', value=200, key=key + method + 'nseg'),
90-
'sigma': st.number_input('σ', value=0, key=key + method + 'sigma'),
91-
}
87+
if 'Tabular' in key:
88+
return {'training_data_kmeans': st.number_input('Training data kmeans', value=5,
89+
key=f'{key}_{method}_training_data_kmeans'),
90+
}
91+
else:
92+
return {
93+
'nsamples': st.number_input('Number of samples', value=1000, key=f'{key}_{method}_nsamp'),
94+
'background': st.number_input('Background', value=0, key=f'{key}_{method}_background'),
95+
'n_segments': st.number_input('Number of segments', value=200, key=f'{key}_{method}_nseg'),
96+
'sigma': st.number_input('σ', value=0, key=f'{key}_{method}_sigma'),
97+
}
9298

9399
elif method == 'LIME':
94100
return {
95-
'random_state': st.number_input('Random state', value=2, key=key + method + 'rs'),
101+
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
96102
}
97103

98104
else:
99105
raise ValueError(f'No such method: {method}')
100106

101107

102-
def _get_method_params(methods: Sequence[str], key) -> Dict[str, Dict[str, Any]]:
103-
method_params = {}
104-
105-
with st.expander('Click to modify method parameters'):
106-
for method, col in zip(methods, st.columns(len(methods))):
107-
with col:
108-
st.header(method)
109-
method_params[method] = _get_params(method, key=key)
110-
111-
return method_params
112-
113-
114108
def _get_top_indices(predictions, n_top):
115109
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
116110
indices = indices[np.argsort(predictions[indices])]
@@ -119,29 +113,35 @@ def _get_top_indices(predictions, n_top):
119113

120114

121115
def _get_top_indices_and_labels(*, predictions, labels):
122-
c1, c2 = st.columns(2)
116+
cols = st.columns(4)
123117

124-
with c2:
125-
n_top = st.number_input('Number of top results to show',
126-
value=2,
127-
min_value=1,
128-
max_value=len(labels))
118+
if labels is not None:
119+
with cols[-1]:
120+
n_top = st.number_input('Number of top classes to show',
121+
value=1,
122+
min_value=1,
123+
max_value=len(labels))
129124

130-
top_indices = _get_top_indices(predictions, n_top)
131-
top_labels = [labels[i] for i in top_indices]
125+
top_indices = _get_top_indices(predictions, n_top)
126+
top_labels = [labels[i] for i in top_indices]
132127

133-
with c1:
134-
st.metric('Predicted class', top_labels[0])
128+
with cols[0]:
129+
st.metric('Predicted class:', top_labels[0])
130+
else:
131+
# If not a classifier, only return the predicted value
132+
top_indices = top_labels = " "
133+
with cols[0]:
134+
st.metric('Predicted value:', f"{predictions[0]:.2f}")
135135

136136
return top_indices, top_labels
137137

138138
def reset_method():
139139
# Clear selection
140140
for k in st.session_state.keys():
141-
if '_cb_' in k:
142-
st.session_state[k] = False
143-
if 'params' in k:
141+
if '_param' in k:
144142
st.session_state.pop(k)
143+
elif '_cb' in k:
144+
st.session_state[k] = False
145145

146146
def reset_example():
147147
# Clear selection

dianna/dashboard/pages/Images.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from _model_utils import load_model
55
from _models_image import explain_image_dispatcher
66
from _models_image import predict
7-
from _shared import _get_method_params
87
from _shared import _get_top_indices_and_labels
98
from _shared import _methods_checkboxes
109
from _shared import add_sidebar_logo
@@ -88,15 +87,23 @@
8887
labels = load_labels(image_label_file)
8988

9089
choices = ('RISE', 'KernelSHAP', 'LIME')
91-
methods = _methods_checkboxes(choices=choices, key='Image_cb_')
9290

93-
method_params = _get_method_params(methods, key='Image_params_')
91+
st.text("")
92+
st.text("")
9493

95-
with st.spinner('Predicting class'):
96-
predictions = predict(model=model, image=image)
94+
with st.container(border=True):
95+
prediction_placeholder = st.empty()
96+
methods, method_params = _methods_checkboxes(choices=choices, key='Image_cb')
9797

98-
top_indices, top_labels = _get_top_indices_and_labels(predictions=predictions,
99-
labels=labels)
98+
with st.spinner('Predicting class'):
99+
predictions = predict(model=model, image=image)
100+
101+
with prediction_placeholder:
102+
top_indices, top_labels = _get_top_indices_and_labels(
103+
predictions=predictions,labels=labels)
104+
105+
st.text("")
106+
st.text("")
100107

101108
# check which axis is color channel
102109
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
@@ -107,11 +114,11 @@
107114

108115
_, *columns = st.columns(column_spec)
109116
for col, method in zip(columns, methods):
110-
col.header(method)
117+
col.markdown(f"<h4 style='text-align: center; '>{method}</h4>", unsafe_allow_html=True)
111118

112119
for index, label in zip(top_indices, top_labels):
113120
index_col, *columns = st.columns(column_spec)
114-
index_col.markdown(f'##### {label}')
121+
index_col.markdown(f'##### Class: {label}')
115122

116123
for col, method in zip(columns, methods):
117124
kwargs = method_params[method].copy()

0 commit comments

Comments
 (0)