Skip to content

Commit cde1700

Browse files
committed
Merge branch '829-no-more-torchtext-dependency' into 756-add-notebook-for-eulaw
# Conflicts: # dianna/utils/downloader.py
2 parents 6607639 + 983d287 commit cde1700

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2529
-608
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 1.5.0
2+
current_version = 1.6.0
33

44
[comment]
55
comment = The contents of this file cannot be merged with that of setup.cfg until https://github.com/c4urself/bump2version/issues/185 is resolved

.github/actions/install-python-and-package/action.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ runs:
3636
python -m pip install --upgrade pip setuptools wheel
3737
3838
# only necessary on linux to avoid bloated installs
39+
# pining the version of torch is temporary, see see #829
3940
- name: Install tensorflow/pytorch cpu version
4041
if: runner.os == 'Linux' && steps.cache-python-env.outputs.cache-hit != 'true'
4142
shell: bash {0}
4243
run: |
43-
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
44+
python -m pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
4445
python -m pip install tensorflow-cpu
4546
4647
- name: Install DIANNA

.github/workflows/build.yml

+16-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ jobs:
2626
with:
2727
python-version: '3.11'
2828

29-
- name: Run unit tests including downloader
30-
run: pytest -v --downloader
29+
- name: Run unit tests
30+
run: pytest -v
3131

3232
- name: Verify that we can build the package
3333
run: python setup.py sdist bdist_wheel
@@ -53,11 +53,24 @@ jobs:
5353
python-version: ${{ matrix.python-version }}
5454

5555
- name: Run unit tests
56-
run: python -m pytest -v --downloader
56+
run: python -m pytest -v
5757

5858
- name: Verify that we can build the package
5959
run: python setup.py sdist bdist_wheel
6060

61+
test_downloader:
62+
name: Test file downloader
63+
if: github.event.pull_request.draft == false
64+
runs-on: ubuntu-latest
65+
steps:
66+
- uses: actions/checkout@v3
67+
- uses: ./.github/actions/install-python-and-package
68+
with:
69+
python-version: '3.11'
70+
extras-require: dev
71+
- name: Run downloader test
72+
run: python -m pytest -v --downloader -k downloader
73+
6174
test_dashboard:
6275
name: Test dashboard
6376
if: github.event.pull_request.draft == false

.github/workflows/linting.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ jobs:
3333
3434
- name: Check code style
3535
run: |
36-
ruff dianna tests
36+
ruff check dianna tests

.github/workflows/notebooks.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@ on:
1717
jobs:
1818

1919
notebooks:
20-
name: Run notebooks on (3.10, ${{ matrix.os }})
20+
name: Run notebooks on (${{ matrix.python-version }}, ${{ matrix.os }})
2121
if: github.event.pull_request.draft == false
2222
runs-on: ${{ matrix.os }}
2323
strategy:
2424
fail-fast: false
2525
matrix:
2626
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
27+
python-version: ['3.10']
2728
steps:
2829
- uses: actions/checkout@v3
2930
- uses: ./.github/actions/install-python-and-package
3031
with:
32+
python-version: ${{ matrix.python-version }}
3133
extras-require: dev,notebooks
3234
- name: Run tutorial notebooks
3335
run: pytest --nbmake tutorials

CITATION.cff

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ authors:
5353
name-particle: "van der"
5454

5555
doi: 10.5281/zenodo.5801485
56-
version: "1.5.0"
56+
version: "1.6.0"
5757
repository-code: "https://github.com/dianna-ai/dianna"
5858
keywords:
5959
- XAI

README.md

+30-29
Large diffs are not rendered by default.

dianna/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
__author__ = 'DIANNA Team'
3434
__email__ = 'dianna-ai@esciencecenter.nl'
35-
__version__ = '1.5.0'
35+
__version__ = '1.6.0'
3636

3737

3838
def explain_timeseries(model_or_function: Union[Callable, str],

dianna/dashboard/.streamlit/config.toml

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ secondaryBackgroundColor="#e4f3f9"
88

99
[browser]
1010
gatherUsageStats = false
11+
12+
[client]
13+
showSidebarNavigation = false

dianna/dashboard/Home.py

+61-18
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import importlib
12
import streamlit as st
2-
from _shared import add_sidebar_logo
33
from _shared import data_directory
4+
from streamlit_option_menu import option_menu
45

56
st.set_page_config(page_title="Dianna's dashboard",
67
page_icon='📊',
7-
layout='centered',
8+
layout='wide',
89
initial_sidebar_state='auto',
910
menu_items={
1011
'Get help':
@@ -16,26 +17,68 @@
1617
'https://github.com/dianna-ai/dianna')
1718
})
1819

19-
add_sidebar_logo()
20+
# Define dictionary of dashboard pages
21+
pages = {
22+
"Home": "home",
23+
"Images": "pages.Images",
24+
"Tabular": "pages.Tabular",
25+
"Text": "pages.Text",
26+
"Time series": "pages.Time_series"
27+
}
2028

21-
st.image(str(data_directory / 'logo.png'))
29+
# Set up the top menu
30+
selected = option_menu(
31+
menu_title=None,
32+
options=list(pages.keys()),
33+
icons=["house", "camera", "table", "alphabet", "clock"],
34+
menu_icon="cast",
35+
default_index=0,
36+
orientation="horizontal"
37+
)
2238

23-
st.markdown("""
24-
DIANNA is a Python package that brings explainable AI (XAI) to your research project.
25-
It wraps carefully selected XAI methods in a simple, uniform interface. It's built by,
26-
with and for (academic) researchers and research software engineers working on machine
27-
learning projects.
39+
# Display the content of the selected page
40+
if selected == "Home":
41+
st.image(str(data_directory / 'logo.png'))
2842

29-
### Pages
43+
st.markdown("""
44+
DIANNA is a Python package that brings explainable AI (XAI) to your research project.
45+
It wraps carefully selected XAI methods in a simple, uniform interface. It's built by,
46+
with and for (academic) researchers and research software engineers working on machine
47+
learning projects.
3048
31-
- <a href="/Images" target="_parent">Images</a>
32-
- <a href="/Text" target="_parent">Text</a>
33-
- <a href="/Time_series" target="_parent">Time series</a>
49+
### Pages
3450
51+
- <a href="/Images" target="_parent">Image data</a>
52+
- <a href="/Tabular" target="_parent">Tabular data</a>
53+
- <a href="/Text" target="_parent">Text data</a>
54+
- <a href="/Time_series" target="_parent">Time series data</a>
3555
36-
### More information
3756
38-
- [Source code](https://github.com/dianna-ai/dianna)
39-
- [Documentation](https://dianna.readthedocs.io/)
40-
""",
41-
unsafe_allow_html=True)
57+
### More information
58+
59+
- [Source code](https://github.com/dianna-ai/dianna)
60+
- [Documentation](https://dianna.readthedocs.io/)
61+
""",
62+
unsafe_allow_html=True)
63+
64+
else:
65+
# Dynamically import and execute the page
66+
page_module = pages[selected]
67+
# Make sure that all variables are reset when switching page
68+
if selected != 'Images':
69+
for k in st.session_state.keys():
70+
if 'Image' in k:
71+
st.session_state.pop(k, None)
72+
if selected != 'Tabular':
73+
for k in st.session_state.keys():
74+
if 'Tabular' in k:
75+
st.session_state.pop(k, None)
76+
if selected != 'Text':
77+
for k in st.session_state.keys():
78+
if 'Text' in k:
79+
st.session_state.pop(k, None)
80+
if selected != 'Time series':
81+
for k in st.session_state.keys():
82+
if 'TS' in k:
83+
st.session_state.pop(k, None)
84+
page = importlib.import_module(page_module)

dianna/dashboard/_model_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from pathlib import Path
22
import numpy as np
33
import onnx
4+
import pandas as pd
5+
6+
7+
def load_data(file):
8+
"""Open data from a file and returns it as pandas DataFrame."""
9+
df = pd.read_csv(file, parse_dates=True)
10+
# Add index column
11+
df.insert(0, 'Index', df.index)
12+
return df
413

514

615
def preprocess_function(image):
@@ -29,3 +38,7 @@ def load_labels(file):
2938
if labels is None or labels == ['']:
3039
raise ValueError(labels)
3140
return labels
41+
42+
43+
def load_training_data(file):
44+
return np.float32(np.load(file, allow_pickle=False))

dianna/dashboard/_models_tabular.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import tempfile
2+
import numpy as np
3+
import streamlit as st
4+
from dianna import explain_tabular
5+
from dianna.utils.onnx_runner import SimpleModelRunner
6+
7+
8+
@st.cache_data
9+
def predict(*, model, tabular_input):
10+
model_runner = SimpleModelRunner(model)
11+
predictions = model_runner(tabular_input.reshape(1,-1).astype(np.float32))
12+
return predictions
13+
14+
15+
@st.cache_data
16+
def _run_rise_tabular(_model, table, training_data, **kwargs):
17+
relevances = explain_tabular(
18+
_model,
19+
table,
20+
method='RISE',
21+
training_data=training_data,
22+
**kwargs,
23+
)
24+
return relevances
25+
26+
27+
@st.cache_data
28+
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
29+
relevances = explain_tabular(
30+
_model,
31+
table,
32+
method='LIME',
33+
training_data=training_data,
34+
feature_names=_feature_names,
35+
**kwargs,
36+
)
37+
return relevances
38+
39+
@st.cache_data
40+
def _run_kernelshap_tabular(model, table, training_data, **kwargs):
41+
# Kernelshap interface is different. Write model to temporary file.
42+
with tempfile.NamedTemporaryFile() as f:
43+
f.write(model)
44+
f.flush()
45+
relevances = explain_tabular(f.name,
46+
table,
47+
method='KernelSHAP',
48+
training_data=training_data,
49+
**kwargs)
50+
return relevances[0]
51+
52+
53+
explain_tabular_dispatcher = {
54+
'RISE': _run_rise_tabular,
55+
'LIME': _run_lime_tabular,
56+
'KernelSHAP': _run_kernelshap_tabular
57+
}

dianna/dashboard/_models_ts.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def predict(*, model, ts_data):
2121

2222
@st.cache_data
2323
def _run_rise_timeseries(_model, ts_data, **kwargs):
24+
# convert streamlit kwarg requirement back to dianna kwarg requirement
25+
if "_preprocess_function" in kwargs:
26+
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
27+
del kwargs["_preprocess_function"]
2428

2529
def run_model(ts_data):
2630
return predict(model=_model, ts_data=ts_data)
@@ -37,6 +41,10 @@ def run_model(ts_data):
3741

3842
@st.cache_data
3943
def _run_lime_timeseries(_model, ts_data, **kwargs):
44+
# convert streamlit kwarg requirement back to dianna kwarg requirement
45+
if "_preprocess_function" in kwargs:
46+
kwargs["preprocess_function"] = kwargs["_preprocess_function"]
47+
del kwargs["_preprocess_function"]
4048

4149
def run_model(ts_data):
4250
return predict(model=_model, ts_data=ts_data)

0 commit comments

Comments
 (0)