Add a tutorial notebook for emulator use case #811

Merged 33 commits, Aug 26, 2024
Commits
dca9812
add argument silent to shap_values function in kernelshap
SarahAlidoost Jun 18, 2024
6c6b9a5
add draft nb for emulator
SarahAlidoost Jun 18, 2024
12768e3
add issue number for silent keyword
SarahAlidoost Jun 25, 2024
1ab8437
Merge branch 'main' into add_emulator_nb
SarahAlidoost Jun 25, 2024
fcff468
use train_test data to calculate shap values instead of raster data
SarahAlidoost Jun 26, 2024
2bcbc7f
refactor the notebook
SarahAlidoost Jul 2, 2024
da93b4f
rename the notebook
SarahAlidoost Jul 2, 2024
134adf7
add zenodo entries of emulator to downloader
SarahAlidoost Jul 3, 2024
40aae54
add sha values to emulator data and model entries in downloader
SarahAlidoost Jul 10, 2024
d2dc8a4
refcator nb, add summary_plot from shap
SarahAlidoost Jul 10, 2024
5ca5d62
fix the plots in emulator nb
SarahAlidoost Jul 31, 2024
4e9bdee
fix background_data
SarahAlidoost Aug 6, 2024
23b3891
add some notes to the nb
SarahAlidoost Aug 6, 2024
d3b0789
add another test case to nb
SarahAlidoost Aug 7, 2024
d041df6
add some comments, remove some tests
SarahAlidoost Aug 7, 2024
acf6754
add argument description in the docstring for silent argument
SarahAlidoost Aug 7, 2024
1faadef
replace emulator with land_atmosphere in the nb name
SarahAlidoost Aug 14, 2024
a20db7a
replace model_input with dataset
SarahAlidoost Aug 14, 2024
c7efb85
fix comments
SarahAlidoost Aug 14, 2024
539d501
fix plots
SarahAlidoost Aug 16, 2024
f1d52e7
add the model to main readme
SarahAlidoost Aug 16, 2024
426299a
add info about nb to tutorials readme
SarahAlidoost Aug 16, 2024
64b8a28
Update README.md
elboyran Aug 21, 2024
84a5c4b
Update README.md
elboyran Aug 21, 2024
304e942
Update README.md
elboyran Aug 21, 2024
da56fda
Update README.md
elboyran Aug 21, 2024
45e9f00
Update tutorals README.md
elboyran Aug 21, 2024
3b6c048
Update tutorial README.md
elboyran Aug 21, 2024
8292c6b
Update tutorial README.md
elboyran Aug 21, 2024
c2755d0
Update tutorials README.md
elboyran Aug 21, 2024
995ae8a
add kernelshap hyperparameters related to emulator nb
SarahAlidoost Aug 22, 2024
33e9a2f
add explanation about physical model related to emulator nb
SarahAlidoost Aug 22, 2024
c01ad12
improve the intro of emulator nb
SarahAlidoost Aug 22, 2024
59 changes: 30 additions & 29 deletions README.md
@@ -27,16 +27,16 @@ authors:
affiliation: 1
- name: Laura Ootes^[co-first author] # note this makes a footnote saying 'co-first author'
orcid: 0000-0002-2800-8309
affiliation: 1
- name: Pranav Chandramouli^[co-first author] # note this makes a footnote saying 'co-first author'
orcid: 0000-0002-7896-2969
affiliation: 1
- name: Aron Jansen^[co-first author] # note this makes a footnote saying 'co-first author'
orcid: 0000-0002-4764-9347
affiliation: 1
- name: Stef Smeets^[co-first author] # note this makes a footnote saying 'co-first author'
orcid: 0000-0002-5413-9038
affiliation: 1
affiliations:
- name: Netherlands eScience Center, Amsterdam, the Netherlands
index: 1
@@ -121,12 +121,12 @@ You need:

```python
model_path = 'your_model.onnx' # model trained on your data modality
data_item = <data_item> # data item for which the model's prediction needs to be explained
```

2. If the task is classification: which are the *classes* your model has been trained for?

```python
labels = [class_a, class_b] # example of binary classification labels
```
*Which* of these classes do you want an explanation for?
@@ -148,24 +148,24 @@ We are interested in which words are contributing positively (red) and which - negat
```python
model_path = 'your_text_model.onnx'
# also define a model runner here (details in dedicated notebook)
review = 'The movie started great but the ending is boring and unoriginal.'
labels = ["negative", "positive"]
explained_class_index = labels.index("positive")
explanation = dianna.explain_text(model_path, review, 'LIME')
dianna.visualization.highlight_text(explanation[explained_class_index], model_runner.tokenizer.tokenize(review))
```

![image](https://user-images.githubusercontent.com/6087314/155532504-6f90f032-cbb4-4e71-9b99-aa9c0de4e86a.png)
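
To read such an explanation programmatically rather than visually, the per-token scores can be split by sign. The sketch below assumes each class's explanation is a list of `(token, position, relevance)` tuples; that shape, the helper `split_by_sign`, and the scores themselves are illustrative assumptions, not output from a real model.

```python
# Hypothetical relevance scores for the 'positive' class; real values
# would come from the explainer, these are made up for illustration.
explanation_for_class = [
    ("started", 2, 0.05),
    ("great", 3, 0.41),
    ("boring", 8, -0.38),
    ("unoriginal", 10, -0.22),
]


def split_by_sign(token_scores):
    """Separate tokens that support the class from tokens that oppose it."""
    positive = [tok for tok, _, score in token_scores if score > 0]
    negative = [tok for tok, _, score in token_scores if score < 0]
    return positive, negative


positive_words, negative_words = split_by_sign(explanation_for_class)
print("supports 'positive':", positive_words)
print("opposes 'positive':", negative_words)
```

Tokens with a score of exactly zero are left out of both lists, which mirrors the neutral (uncoloured) words in the highlighted text.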

Here is another illustration of how to use dianna to explain which parts of a bee *image* contributed positively (red) or negatively (blue) towards classifying the image as a *'bee'* using *RISE*.
The ImageNet model has been trained to distinguish between 1000 classes (specified in ```labels```).
For images, which are higher-dimensional data than text, there are also some specifics to consider:

```python
model_path = 'your_image_model.onnx'
image = PIL.Image.open('your_bee_image.jpeg')
axis_labels = {2: 'channels'}
explained_class_index = labels.index('bee')
explanation = dianna.explain_image(model_path, image, 'RISE', axis_labels=axis_labels, labels=labels)
dianna.visualization.plot_image(explanation[explained_class_index], utils.img_to_array(image)/255., heatmap_cmap='bwr')
plt.show()
@@ -185,7 +185,7 @@ plt.show()
### Overview tutorial
There are **full working examples** on how to use the supported explainers and how to use dianna for **all supported data modalities** in our [overview tutorial](./tutorials/overview.ipynb).

#### Demo movie (update planned):
[![Watch the video on YouTube](https://img.youtube.com/vi/u9_c5DJewLU/default.jpg)](https://youtu.be/u9_c5DJewLU)

### IMPORTANT: Sensitivity to hyperparameters
@@ -194,7 +194,7 @@ The default hyperparameters used in DIANNA for each explainer as well as the val

## Dashboard

Explore the explanations of your trained model using the DIANNA dashboard (for now, image, text and time series classification are supported).
[Click here](https://github.com/dianna-ai/dianna/tree/main/dianna/dashboard) for more information.

<a href="https://github.com/dianna-ai/dianna/tree/main/dianna/dashboard" target="_blank">
@@ -223,17 +223,17 @@ DIANNA comes with simple datasets. Their main goal is to provide intuitive insig

| Dataset | Description | Examples | Generation |
| :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------ |
| [Coffee dataset](https://www.timeseriesclassification.com/description.php?Dataset=Coffee) <img width="25" alt="Coffee Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/9ab50a0f-5da3-41d2-80e9-70d2c8769162"> | Food spectrograph time series dataset for a two-class problem: distinguishing between Robusta and Arabica coffee beans. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/763002c5-40ad-48cc-9de0-ea43d7fa8a75"> | [data source](https://github.com/QIBChemometrics/Benchtop-NMR-Coffee-Survey) |
| [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations through the years 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |

### Tabular

| Dataset | Description | Examples | Generation |
| :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------ |
| [Penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris) <img width="75" alt="Penguins Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/c7716ad3-f992-4557-80d9-1d8178c7ed57"> | Palmer Archipelago (Antarctica) penguin dataset is a great intro dataset for data exploration & visualization similar to the famous Iris dataset. | <img width="500" alt="example image" src="https://github.com/allisonhorst/palmerpenguins/blob/main/man/figures/README-mass-flipper-1.png"> | [data source](https://github.com/allisonhorst/palmerpenguins) |
| [Weather dataset](https://zenodo.org/record/7525955) <img width="25" alt="Weather Logo" src="https://github.com/dianna-ai/dianna/assets/3244249/3ff3d639-ed2f-4a38-b7ac-957c984bce9f"> | The light version of the weather prediction dataset, which contains daily observations (89 features) for 11 European locations through the years 2000 to 2010. | <img width="500" alt="example image" src="https://github.com/dianna-ai/dianna/assets/3244249/b0a505ac-8a6c-4e1c-b6ad-35e31e52f46d"> | [data source](https://github.com/florian-huber/weather_prediction_dataset) |
| [Land atmosphere dataset](https://zenodo.org/records/12623257) <img width="25" alt="Atmosphere Logo" src="https://github.com/user-attachments/assets/bee353dd-c19a-4aec-a778-4ca3574765f0"> | Land-atmosphere variables and latent heat flux (LEtot) simulated by STEMMUS-SCOPE (a soil-plant model), version 1.5.0, over 19 Fluxnet sites for the year 2014 at hourly intervals. | <img width="500" alt="example image" src="https://github.com/user-attachments/assets/a6e10b08-08d8-4e57-887a-cd4fca9f2ff0"> | [data source](https://zenodo.org/records/12623257) |

## ONNX models

<!-- TODO: Add all links, see issue https://github.com/dianna-ai/dianna/issues/135 -->

@@ -276,6 +276,7 @@ And here are links to notebooks showing how we created our models on the benchma
| :-------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [Penguin model (classification)](https://zenodo.org/records/10580743) | [Penguin model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/penguin_species/generate_model.ipynb) |
| [Sunshine hours prediction model (regression)](https://zenodo.org/records/10580833) | [Sunshine hours prediction model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb) |
| [Latent heat flux prediction model (regression)](https://zenodo.org/records/12623257) | [Latent heat flux prediction model](doi:10.5281/zenodo.12623256/stemmus_scope_emulator_model_LEtot.onnx) |


**_We envision the birth of the ONNX Scientific models zoo soon..._**
@@ -294,22 +295,22 @@ DIANNA supports different data modalities and XAI methods (explainers). We have
| Tabular | planned | ✅ | ✅ |
| Embedding | work in progress | | |
| Graphs* | next steps | ... | ... |

[LRP](https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0130140&type=printable) and [PatternAttribution](https://arxiv.org/pdf/1705.05598.pdf) also feature in the top 5 of our thoroughly evaluated explainers.
Also, [GradCAM](https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf) has recently been found to be *semantically continuous*! **Contributions adding these and more (new) post-hoc explainability methods for ONNX models are very welcome!**


### Scientific use-cases
Our goal is that the scientific community embraces XAI as a source of novel and unexplored perspectives on scientific problems.
Here, we offer [tutorials](./tutorials) on specific scientific use-cases of using XAI:

| Use-case (data) \ XAI | [RISE](http://bmvc2018.org/contents/papers/1064.pdf) | [LIME](https://www.kdd.org/kdd2016/papers/files/rfp0573-ribeiroA.pdf) | [KernelSHAP](https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf) |
| :----------------------------------------------------------------| :----------------------------------------------------| :---------------------------------------------------------------------| :-------------------------------------------------------------------------------------------------------|
| Biology (Phytomorphology): Tree Leaves classification (images) | | ✅ | |
| Astronomy: Fast Radio Burst detection (timeseries) | ✅ | | |
| Geo-science (raster data) | planned | ... | ... |
| Land-atmosphere modeling: Latent heat flux prediction (tabular) | | | |
| Social sciences (text) | work in progress | ... |... |
| Climate | planned | ... | ... |

## Reference documentation

6 changes: 4 additions & 2 deletions dianna/methods/kernelshap_tabular.py
@@ -16,6 +16,7 @@ def __init__(
mode: str = "classification",
feature_names: List[int] = None,
training_data_kmeans: Optional[int] = None,
silent: bool = False,
) -> None:
"""Initializer of KERNELSHAPTabular.

@@ -32,6 +33,7 @@ def __init__(
in the training data.
training_data_kmeans(int, optional): summarize the whole training set with
weighted kmeans
silent (bool, optional): whether to suppress progress messages
"""
if training_data_kmeans:
self.training_data = shap.kmeans(training_data,
@@ -41,6 +43,7 @@ def __init__(
self.feature_names = feature_names
self.mode = mode
self.explainer: KernelExplainer
self.silent = silent

def explain(
self,
@@ -73,8 +76,7 @@ def explain(
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
self.explainer.shap_values, kwargs)

-saliency = self.explainer.shap_values(input_tabular,
-**explain_instance_kwargs)
+saliency = self.explainer.shap_values(input_tabular, silent=self.silent, **explain_instance_kwargs)

if self.mode == 'regression':
saliency = saliency[0]
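
The change above stores `silent` on the explainer and forwards it on every `shap_values` call. A minimal, dependency-free sketch of that pattern follows. `ToyExplainer`, `fake_shap_values` and `captured_output` are hypothetical stand-ins, not DIANNA or shap code. The sketch also drops any caller-supplied `silent` from the keyword arguments to avoid a duplicate-keyword `TypeError`; that guard is an assumption added here, and whether DIANNA needs it depends on how its kwargs are filtered upstream.

```python
import contextlib
import io


def fake_shap_values(rows, silent=False, nsamples=10):
    """Hypothetical stand-in for an explainer's shap_values call."""
    results = []
    for i, row in enumerate(rows):
        if not silent:
            print(f"explaining row {i + 1}/{len(rows)}")
        results.append([0.0 for _ in row])  # placeholder attributions
    return results


class ToyExplainer:
    """Stores `silent` once and forwards it on every explain call."""

    def __init__(self, silent=False):
        self.silent = silent

    def explain(self, rows, **kwargs):
        # Assumption: drop a caller-supplied 'silent' so it cannot
        # collide with the value stored on the instance.
        kwargs.pop("silent", None)
        return fake_shap_values(rows, silent=self.silent, **kwargs)


def captured_output(explainer, rows):
    """Run explain() and return whatever it printed."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        explainer.explain(rows)
    return buffer.getvalue()


if __name__ == "__main__":
    ToyExplainer(silent=False).explain([[1, 2], [3, 4]])  # prints progress
    ToyExplainer(silent=True).explain([[1, 2], [3, 4]])   # prints nothing
```

Keeping the flag on the instance means notebook users set it once at construction instead of passing it on every call.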