Skip to content

Commit 7757b17

Browse files
authored
Merge pull request #752 from dianna-ai/690-timeseries-with-onnx-file
Add test for rise timeseries on onnx file
2 parents 7d9e813 + 3587a0e commit 7757b17

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

tests/methods/test_rise_timeseries.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,36 @@ def test_rise_timeseries_correct_output_shape():
1212
input_data = np.random.random((10, 1))
1313
labels = [1]
1414

15-
heatmaps = dianna.explain_timeseries(run_model, input_data, "RISE", labels,
16-
n_masks=200, p_keep=.5)
15+
heatmaps = dianna.explain_timeseries(run_model,
16+
input_data,
17+
"RISE",
18+
labels,
19+
n_masks=200,
20+
p_keep=.5)
1721

1822
assert heatmaps.shape == (len(labels), *input_data.shape)
1923

2024

25+
def test_rise_timeseries_with_model_file():
26+
"""Test if rise runs and outputs the correct shape given some data and a model file."""
27+
filename = 'dianna/models/season_prediction_model_temp_max_binary.onnx'
28+
input_data = np.random.random((28, 1))
29+
labels = [0]
30+
31+
# this model requires float input, while numpy uses double
32+
def preprocess(data):
33+
return data.astype(np.float32)
34+
35+
heatmaps = dianna.explain_timeseries(filename,
36+
input_data,
37+
"RISE",
38+
labels,
39+
n_masks=200,
40+
p_keep=.5,
41+
preprocess_function=preprocess)
42+
print(heatmaps.shape)
43+
44+
2145
@pytest.mark.parametrize('series_length', [
2246
10,
2347
3,

0 commit comments

Comments
 (0)