1
1
import base64
2
2
import sys
3
- from typing import Any
4
- from typing import Dict
5
3
from typing import Sequence
6
4
import numpy as np
7
5
import streamlit as st
@@ -46,71 +44,67 @@ def build_markup_for_logo(
46
44
47
45
48
46
def add_sidebar_logo ():
49
- """Based on: https://stackoverflow.com/a/73278825."""
50
- png_file = data_directory / 'logo.png'
51
- logo_markup = build_markup_for_logo (png_file )
52
- st .markdown (
53
- logo_markup ,
54
- unsafe_allow_html = True ,
55
- )
47
+ """Upload DIANNA logo to sidebar element."""
48
+ st .sidebar .image (str (data_directory / 'logo.png' ))
56
49
57
50
58
51
def _methods_checkboxes (* , choices : Sequence , key ):
59
- """Get methods from a horizontal row of checkboxes."""
52
+ """Get methods from a horizontal row of checkboxes and the corresponding parameters ."""
60
53
n_choices = len (choices )
61
54
methods = []
55
+ method_params = {}
56
+
57
+ # Create a container for the message
58
+ message_container = st .empty ()
59
+
62
60
for col , method in zip (st .columns (n_choices ), choices ):
63
61
with col :
64
- if st .checkbox (method , key = key + method ):
62
+ if st .checkbox (method , key = f' { key } _ { method } ' ):
65
63
methods .append (method )
64
+ with st .expander (f'Click to modify { method } parameters' ):
65
+ method_params [method ] = _get_params (method , key = f'{ key } _param' )
66
66
67
67
if not methods :
68
- st .info ('Select a method to continue' )
68
+ # Put the message in the container above
69
+ message_container .info ('Select a method to continue' )
69
70
st .stop ()
70
71
71
- return methods
72
+ return methods , method_params
72
73
73
74
74
75
def _get_params (method : str , key ):
75
76
if method == 'RISE' :
76
77
return {
77
78
'n_masks' :
78
- st .number_input ('Number of masks' , value = 1000 , key = key + method + 'nmasks ' ),
79
+ st .number_input ('Number of masks' , value = 1000 , key = f' { key } _ { method } _nmasks ' ),
79
80
'feature_res' :
80
- st .number_input ('Feature resolution' , value = 6 , key = key + method + 'fr ' ),
81
+ st .number_input ('Feature resolution' , value = 6 , key = f' { key } _ { method } _fr ' ),
81
82
'p_keep' :
82
- st .number_input ('Probability to be kept unmasked' , value = 0.1 , key = key + method + 'pkeep ' ),
83
+ st .number_input ('Probability to be kept unmasked' , value = 0.1 , key = f' { key } _ { method } _pkeep ' ),
83
84
}
84
85
85
86
elif method == 'KernelSHAP' :
86
- return {
87
- 'nsamples' : st .number_input ('Number of samples' , value = 1000 , key = key + method + 'nsamp' ),
88
- 'background' : st .number_input ('Background' , value = 0 , key = key + method + 'background' ),
89
- 'n_segments' : st .number_input ('Number of segments' , value = 200 , key = key + method + 'nseg' ),
90
- 'sigma' : st .number_input ('σ' , value = 0 , key = key + method + 'sigma' ),
91
- }
87
+ if 'Tabular' in key :
88
+ return {'training_data_kmeans' : st .number_input ('Training data kmeans' , value = 5 ,
89
+ key = f'{ key } _{ method } _training_data_kmeans' ),
90
+ }
91
+ else :
92
+ return {
93
+ 'nsamples' : st .number_input ('Number of samples' , value = 1000 , key = f'{ key } _{ method } _nsamp' ),
94
+ 'background' : st .number_input ('Background' , value = 0 , key = f'{ key } _{ method } _background' ),
95
+ 'n_segments' : st .number_input ('Number of segments' , value = 200 , key = f'{ key } _{ method } _nseg' ),
96
+ 'sigma' : st .number_input ('σ' , value = 0 , key = f'{ key } _{ method } _sigma' ),
97
+ }
92
98
93
99
elif method == 'LIME' :
94
100
return {
95
- 'random_state' : st .number_input ('Random state' , value = 2 , key = key + method + 'rs ' ),
101
+ 'random_state' : st .number_input ('Random state' , value = 2 , key = f' { key } _ { method } _rs ' ),
96
102
}
97
103
98
104
else :
99
105
raise ValueError (f'No such method: { method } ' )
100
106
101
107
102
- def _get_method_params (methods : Sequence [str ], key ) -> Dict [str , Dict [str , Any ]]:
103
- method_params = {}
104
-
105
- with st .expander ('Click to modify method parameters' ):
106
- for method , col in zip (methods , st .columns (len (methods ))):
107
- with col :
108
- st .header (method )
109
- method_params [method ] = _get_params (method , key = key )
110
-
111
- return method_params
112
-
113
-
114
108
def _get_top_indices (predictions , n_top ):
115
109
indices = np .array (np .argpartition (predictions , - n_top )[- n_top :])
116
110
indices = indices [np .argsort (predictions [indices ])]
@@ -119,29 +113,35 @@ def _get_top_indices(predictions, n_top):
119
113
120
114
121
115
def _get_top_indices_and_labels (* , predictions , labels ):
122
- c1 , c2 = st .columns (2 )
116
+ cols = st .columns (4 )
123
117
124
- with c2 :
125
- n_top = st .number_input ('Number of top results to show' ,
126
- value = 2 ,
127
- min_value = 1 ,
128
- max_value = len (labels ))
118
+ if labels is not None :
119
+ with cols [- 1 ]:
120
+ n_top = st .number_input ('Number of top classes to show' ,
121
+ value = 1 ,
122
+ min_value = 1 ,
123
+ max_value = len (labels ))
129
124
130
- top_indices = _get_top_indices (predictions , n_top )
131
- top_labels = [labels [i ] for i in top_indices ]
125
+ top_indices = _get_top_indices (predictions , n_top )
126
+ top_labels = [labels [i ] for i in top_indices ]
132
127
133
- with c1 :
134
- st .metric ('Predicted class' , top_labels [0 ])
128
+ with cols [0 ]:
129
+ st .metric ('Predicted class:' , top_labels [0 ])
130
+ else :
131
+ # If not a classifier, only return the predicted value
132
+ top_indices = top_labels = " "
133
+ with cols [0 ]:
134
+ st .metric ('Predicted value:' , f"{ predictions [0 ]:.2f} " )
135
135
136
136
return top_indices , top_labels
137
137
138
138
def reset_method ():
139
139
# Clear selection
140
140
for k in st .session_state .keys ():
141
- if '_cb_' in k :
142
- st .session_state [k ] = False
143
- if 'params' in k :
141
+ if '_param' in k :
144
142
st .session_state .pop (k )
143
+ elif '_cb' in k :
144
+ st .session_state [k ] = False
145
145
146
146
def reset_example ():
147
147
# Clear selection
0 commit comments