-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path test_common_usage.py
32 lines (26 loc) · 1.23 KB
/
test_common_usage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import dianna
import dianna.visualization
from tests.utils import run_model
def test_common_RISE_image_pipeline():  # noqa: N802 ignore case
    """No errors thrown while creating a relevance map and visualizing it.

    Explains a random (224, 224, 3) channels-last image with RISE, then plots
    the resulting heatmap on its own and overlaid on the input image.
    """
    input_image = np.random.random((224, 224, 3))
    # The last axis of the input holds the colour channels.
    axis_labels = {-1: 'channels'}
    labels = [0, 1]

    # explain_image's result is indexed with [0]; presumably one heatmap per
    # requested label, so this keeps the heatmap for label 0 -- TODO confirm.
    heatmap = dianna.explain_image(run_model, input_image, "RISE", labels, axis_labels=axis_labels)[0]
    dianna.visualization.plot_image(heatmap, show_plot=False)
    # Overlay on the full image that was explained. Indexing input_image[0]
    # here would pass a single (224, 3) row of pixels instead of the image.
    dianna.visualization.plot_image(heatmap, original_data=input_image, show_plot=False)
def test_common_RISE_timeseries_pipeline():  # noqa: N802 ignore case
    """No errors thrown while creating a relevance map and visualizing it.

    Explains a random (31, 1) time series with RISE, converts the relevance
    values of column 0 into unit-length segments and plots them.
    """
    input_data = np.random.random((31, 1))
    labels = [0]
    heatmap = dianna.explain_timeseries(run_model, input_data, "RISE", labels)[0]
    heatmap_channel = heatmap[:, 0]

    # One segment [step, step + 1) per consecutive pair of samples, weighted by
    # the relevance of the sample it starts at; the last sample starts none.
    segments = [
        {'index': step, 'start': step, 'stop': step + 1, 'weight': weight}
        for step, weight in enumerate(heatmap_channel[:-1])
    ]
    dianna.visualization.plot_timeseries(
        range(len(heatmap_channel)), input_data[:, 0], segments, show_plot=False)