
Commit 20e6127

Merge pull request #776 from dianna-ai/overview_and_tutorials_change
Overview tutorial and other tutorial changes
2 parents: 88eeb02 + 4369b49

26 files changed: +3074 -936 lines

README.md

+46 -75
@@ -109,116 +109,87 @@ If you get an error related to OpenMP when importing dianna, have a look at [thi
 You need:
 
 - your trained ONNX model ([convert my pytorch/tensorflow/keras/scikit-learn model to ONNX](https://github.com/dianna-ai/dianna#onnx-models))
-- 1 data item to be explained
+- a data item to be explained
 
 You get:
 
 - a relevance map overlayed over the data item
 
-In the library's documentation, the general usage is explained in [How to use DIANNA](https://dianna.readthedocs.io/en/latest/usage.html)
+### Template example for any data modality and explainer
 
-### Demo movie
-
-[![Watch the video on YouTube](https://img.youtube.com/vi/u9_c5DJewLU/default.jpg)](https://youtu.be/u9_c5DJewLU)
-
-### Text example:
+1. Provide your *trained model* and *data item* (*text, image, time series or tabular*):
 
 ```python
-model_path = 'your_model.onnx' # model trained on text
-text = 'The movie started great but the ending is boring and unoriginal.'
+model_path = 'your_model.onnx' # model trained on your data modality
+data_item = <data_item> # data item for which the model's prediction needs to be explained
 ```
 
-Which of your model's classes do you want an explanation for?
+2. If the task is classification: which *classes* has your model been trained for?
 
-```python
-labels = [positive_class, negative_class]
+```python
+labels = [class_a, class_b] # example of binary classification labels
 ```
-
-Run using the XAI method of your choice, for example LIME:
-
+*Which* of these classes do you want an explanation for?
 ```python
-explanation = dianna.explain_text(model_path, text, 'LIME')
-dianna.visualization.highlight_text(explanation[labels.index(positive_class)], text)
+explained_class_index = labels.index(<explained_class>) # explained_class can be any of the labels
 ```
 
-![image](https://user-images.githubusercontent.com/6087314/155532504-6f90f032-cbb4-4e71-9b99-aa9c0de4e86a.png)
-
-### Image example:
+3. Run dianna with the *explainer* of your choice (*'LIME', 'RISE' or 'KernelSHAP'*) and visualize the output:
 
 ```python
-model_path = 'your_model.onnx' # model trained on images
-image = PIL.Image.open('your_image.jpeg')
+explanation = dianna.<explanation_function>(model_path, data_item, explainer)
+dianna.visualization.<visualization_function>(explanation[explained_class_index], data_item)
 ```
 
-Tell us what label refers to the channels, or colors, in the image.
+### Text and image usage examples
+Let's illustrate the template above with *textual* data. The data item of interest is a sentence that is (part of) a movie review, and the model has been trained to classify reviews into positive and negative sentiment classes.
+We are interested in which words contribute positively (red) and which negatively (blue) towards the model's decision to classify the review as positive, and we would like to use the *LIME* explainer:
 
 ```python
-axis_labels = {0: 'channels'}
+model_path = 'your_text_model.onnx'
+# also define a model runner here (details in dedicated notebook)
+review = 'The movie started great but the ending is boring and unoriginal.'
+labels = ["negative", "positive"]
+explained_class_index = labels.index("positive")
+explanation = dianna.explain_text(model_path, review, 'LIME')
+dianna.visualization.highlight_text(explanation[explained_class_index], model_runner.tokenizer.tokenize(review))
 ```
 
-Which of your model's classes do you want an explanation for?
-
-```python
-labels = [class_a, class_b]
-```
+![image](https://user-images.githubusercontent.com/6087314/155532504-6f90f032-cbb4-4e71-9b99-aa9c0de4e86a.png)
 
-Run using the XAI method of your choice, for example RISE:
+Here is another illustration of how to use dianna to explain which parts of a bee *image* contribute positively (red) or negatively (blue) towards classifying the image as a *'bee'*, using *RISE*.
+The ImageNet model has been trained to distinguish between 1000 classes (specified in `labels`).
+For images, which are data of higher dimension compared to text, there are also some specifics to consider:
 
 ```python
+model_path = 'your_image_model.onnx'
+image = PIL.Image.open('your_bee_image.jpeg')
+axis_labels = {2: 'channels'}
+explained_class_index = labels.index('bee')
 explanation = dianna.explain_image(model_path, image, 'RISE', axis_labels=axis_labels, labels=labels)
-dianna.visualization.plot_image(explanation[labels.index(class_a)], original_data=image)
+dianna.visualization.plot_image(explanation[explained_class_index], utils.img_to_array(image)/255., heatmap_cmap='bwr')
+plt.show()
 ```
+<img src="https://github.com/dianna-ai/dianna/assets/3244249/b03e4d4e-e3e8-4248-bf62-e3602b7f6d71" width="215" height="215">
 
-![image](https://user-images.githubusercontent.com/6087314/155557077-e2052094-d8ac-49d3-a840-0160256d53a6.png)
-
-### Time-series example:
-
+And why would the ImageNet model think the same image is a *garden spider*?
 ```python
-model_path = 'your_model.onnx' # model trained on images
-timeseries_instance = pd.read_csv('your_data_instance.csv').astype(float)
-
-num_features = len(timeseries_instance) # The number of features to include in the explanation.
-num_samples = 500 # The number of samples to generate for the LIME explainer.
-```
-
-Which of your model's classes do you want an explanation for?
-
-```python
-class_names= [class_a, class_b] # String representation of the different classes of interest
-labels = np.argsort(class_names) # Numerical representation of the different classes of interest for the model
-```
-
-Run using the XAI method of your choice, for example LIME with the following additional arguments:
-
-```python
-explanation = dianna.explain_timeseries(model_path, timeseries_data=timeseries_instance , method='LIME',
-                                        labels=labels, class_names=class_names, num_features=num_features,
-                                        num_samples=num_samples, distance_method='cosine')
-
-```
-
-For visualization of the heatmap please refer to the [tutorial](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_timeseries_coffee.ipynb)
-
-### Tabular example:
-
-```python
-model_path = 'your_model.onnx' # model trained on tabular data
-tabular_instance = pd.read_csv('your_data_instance.csv')
+explained_class_index = labels.index('garden_spider') # interested in the image being classified as a garden spider
+explanation = dianna.explain_image(model_path, image, 'RISE', axis_labels=axis_labels, labels=labels)
+dianna.visualization.plot_image(explanation[explained_class_index], utils.img_to_array(image)/255., heatmap_cmap='bwr')
+plt.show()
 ```
 
-Run using the XAI method of your choice. Note that you need to specify the mode, either regression or classification. This case, for instance a regression task using KernelSHAP with the following additional arguments:
+<img src="https://github.com/dianna-ai/dianna/assets/3244249/e7623803-2369-40ad-b4ef-4a6ae4e902f1" width="215" height="215">
 
-```python
-explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='kernelshap',
-                                     mode ='regression', training_data = X_train,
-                                     training_data_kmeans = 5, feature_names=input_features.columns)
-plot_tabular(explanation, X_test.columns, num_features=10) # display 10 most salient features
-```
+### Overview tutorial
+There are **full working examples** on how to use the supported explainers and how to use dianna for **all supported data modalities** in our [overview tutorial](./tutorials/overview.ipynb).
 
-![image](https://github.com/dianna-ai/dianna/assets/25911757/ce0b76b8-f00c-468a-9732-c21704e289f6)
+#### Demo movie (update planned):
+[![Watch the video on YouTube](https://img.youtube.com/vi/u9_c5DJewLU/default.jpg)](https://youtu.be/u9_c5DJewLU)
 
 ### IMPORTANT: Sensitivity to hyperparameters
-The XAI methods (explainers) are sensitive to the choice of their hyperparameters! In this [work](https://staff.fnwi.uva.nl/a.s.z.belloum/MSctheses/MScthesis_Willem_van_der_Spec.pdf), this sensitivity to hyperparameters is researched and useful conclusions are drawn.
+The explainers are sensitive to the choice of their hyperparameters! In this [work](https://staff.fnwi.uva.nl/a.s.z.belloum/MSctheses/MScthesis_Willem_van_der_Spec.pdf), this sensitivity to hyperparameters is researched and useful conclusions are drawn.
 The default hyperparameters used in DIANNA for each explainer as well as the values for our tutorial examples are given in the Tutorials [README](./tutorials/README.md#important-hyperparameters).
 
 ## Dashboard
@@ -252,7 +223,7 @@ DIANNA comes with simple datasets. Their main goal is to provide intuitive insig
 
 | Dataset | Description | Examples | Generation |
 | :----- | :----- | :----- | :----- |
-| Coffee dataset <img width="25" alt="Coffe Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectographs time series dataset for a two class problem to distinguish between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75)"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
+| [Coffee dataset](https://www.timeseriesclassification.com/description.php?Dataset=Coffee) <img width="25" alt="Coffee Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectrographs time series dataset for a two-class problem to distinguish between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75)"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
 | [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations through the years 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d)"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |
 
 ### Tabular
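
The README hunks above both remove the stand-alone time-series example and keep the note that explainers are sensitive to their hyperparameters. As a hedged, self-contained sketch of how such hyperparameters are passed as keyword arguments, here is the removed LIME time-series call again; the model path, data file and values are placeholders carried over from the removed example, not recommendations.

```python
import numpy as np
import pandas as pd
import dianna

# Placeholders: an ONNX model trained on time series and one data instance to explain.
model_path = 'your_model.onnx'
timeseries_instance = pd.read_csv('your_data_instance.csv').astype(float)

# String and numerical representations of the classes of interest.
class_names = ['class_a', 'class_b']
labels = np.argsort(class_names)

# Explainer hyperparameters (illustrative values from the removed example, not recommendations).
num_features = len(timeseries_instance)  # number of features to include in the explanation
num_samples = 500                        # number of perturbed samples LIME generates

explanation = dianna.explain_timeseries(model_path, timeseries_data=timeseries_instance, method='LIME',
                                        labels=labels, class_names=class_names, num_features=num_features,
                                        num_samples=num_samples, distance_method='cosine')
```

The same pattern applies to the other explainers; their defaults and the values chosen for the tutorials are listed in the Tutorials README (tutorials/README.md#important-hyperparameters).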

dianna/visualization/text.py

+1 -1
@@ -6,7 +6,7 @@ def highlight_text(explanation,
                    input_tokens=None,
                    show_plot=True,
                    output_filename=None,
-                   colormap="RdBu",
+                   colormap="bwr",
                    alpha=1.0,
                    heatmap_range=(-1, 1)):
     """Highlights a given text based on values in a given explanation object.

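The only change in this file is the default `colormap`, from "RdBu" to "bwr". As a hedged usage sketch (not part of the commit), the call below spells out the keyword arguments visible in this hunk, reusing the `explanation`, `explained_class_index`, `model_runner` and `review` names from the README text example above; passing `colormap="RdBu"` explicitly restores the previous rendering.

```python
import dianna.visualization

# `explanation`, `explained_class_index`, `model_runner` and `review` are placeholders
# defined as in the README text example above.
tokens = model_runner.tokenizer.tokenize(review)

# The default colormap is now "bwr"; override it to keep the previous "RdBu" look.
dianna.visualization.highlight_text(explanation[explained_class_index],
                                    input_tokens=tokens,
                                    colormap="RdBu",
                                    alpha=1.0,
                                    heatmap_range=(-1, 1),
                                    show_plot=True)
```
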
docs/tutorials/0-overview.nblink

+3
@@ -0,0 +1,3 @@
+{
+    "path": "../../tutorials/overview.ipynb"
+}
4 files renamed without changes.

docs/tutorials/demo.nblink

-3
This file was deleted.

tutorials/README.md

+2 -2
@@ -1,7 +1,7 @@
 <img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">
 
 ## Tutorials
-This folder contains DIANNA tutorial notebooks. To install the dependencies for the tutorials, run
+This folder contains DIANNA tutorial notebooks. To install the dependencies for the tutorials, run (in the main dianna folder):
 ```
 pip install .[notebooks]
 ```
@@ -24,7 +24,7 @@ pip install .[notebooks]
 ||[Simple Geometric (circles and triangles)](https://doi.org/10.5281/zenodo.5012824)| Binary shape *classification* |<img width="20" alt="SimpleGeometric Logo" src="https://user-images.githubusercontent.com/3244249/151539027-f2fc3fc0-282a-4993-9680-74ee28bcd360.png">|
 ||[Imagenet](https://image-net.org/download.php) |$1000$ classes natural images *classification* | <img width="94" alt="ImageNet_autocrop" src="https://user-images.githubusercontent.com/3244249/152542090-fd78fde1-6dec-43b6-a7ae-eea964b8ae28.png">|
 |*Text*| [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) |Positive or negative movie reviews sentiment *classification* | <img width="25" alt="nlp-logo_half_size" src="https://user-images.githubusercontent.com/3244249/152540890-c8e1e37d-f0cc-4f84-80a4-2c59176cbf4c.png">|
-|*Timeseries* | Coffee dataset | Binary *classificaiton* of Robusta and Aribica coffee beans| <img width="25" alt="Coffe Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162">|
+|*Timeseries* | [Coffee dataset](https://www.timeseriesclassification.com/description.php?Dataset=Coffee) | Binary *classification* of Robusta and Arabica coffee beans| <img width="25" alt="Coffee Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162">|
 | | [Weather dataset](https://zenodo.org/record/7525955) |Binary *classification* (summer/winter) of temperature time-series |<img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f">|
 |*Tabular*| [Penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris)| $3$ penguin species (Adelie, Chinstrap, Gentoo) *classification* | <img width="75" alt="Penguin Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/c7716ad3-f992-4557-80d9-1d8178c7ed57"> |
 | | [Weather dataset](https://zenodo.org/record/7525955) | Next day sunshine hours prediction (*regression*) | <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f">|
