Skip to content

Commit 5424e7f

Browse files
authored
Merge pull request #109 from dianna-ai/fix_axis_labels_name
Rename axes_labels parameter to axis_labels
2 parents 83cd079 + 4a3d385 commit 5424e7f

File tree

7 files changed

+23
-23
lines changed

7 files changed

+23
-23
lines changed

dianna/methods/lime.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self,
2020
mask_string=None,
2121
random_state=None,
2222
char_level=False,
23-
axes_labels=None,
23+
axis_labels=None,
2424
preprocess_function=None,
2525
): # pylint: disable=too-many-arguments
2626
"""
@@ -37,7 +37,7 @@ def __init__(self,
3737
mask_string (str, optional): mask string
3838
random_state (int or np.RandomState, optional): seed or random state
3939
char_level (bool, optional): char level
40-
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
40+
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
4141
If a list, the name of each axis where the index
4242
in the list is the axis index
4343
preprocess_function (callable, optional): Function to preprocess input data with
@@ -62,7 +62,7 @@ def __init__(self,
6262
)
6363

6464
self.preprocess_function = preprocess_function
65-
self.axes_labels = axes_labels if axes_labels is not None else []
65+
self.axis_labels = axis_labels if axis_labels is not None else []
6666

6767
def explain_text(self,
6868
model_or_function,
@@ -166,7 +166,7 @@ def _prepare_image_data(self, input_data):
166166
Returns:
167167
transformed input data, preprocessing function to use with utils.get_function()
168168
"""
169-
input_data = utils.to_xarray(input_data, self.axes_labels, LIME.required_labels)
169+
input_data = utils.to_xarray(input_data, self.axis_labels, LIME.required_labels)
170170
# remove batch axis from input data; this is only here for a consistent API
171171
# but LIME wants data without batch axis
172172
if not len(input_data['batch']) == 1:

dianna/methods/rise.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class RISE:
1919
required_labels = ('batch', 'channels')
2020

2121
def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=too-many-arguments
22-
axes_labels=None, preprocess_function=None):
22+
axis_labels=None, preprocess_function=None):
2323
"""RISE initializer.
2424
2525
Args:
2626
n_masks (int): Number of masks to generate.
2727
feature_res (int): Resolution of features in masks.
2828
p_keep (float): Fraction of image to keep in each mask
29-
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
29+
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
3030
If a list, the name of each axis where the index
3131
in the list is the axis index
3232
preprocess_function (callable, optional): Function to preprocess input data with
@@ -37,7 +37,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=t
3737
self.preprocess_function = preprocess_function
3838
self.masks = None
3939
self.predictions = None
40-
self.axes_labels = axes_labels if axes_labels is not None else []
40+
self.axis_labels = axis_labels if axis_labels is not None else []
4141

4242
def explain_text(self, model_or_function, input_text, labels=(0,), batch_size=100):
4343
"""Runs the RISE explainer on text.
@@ -136,7 +136,7 @@ def explain_image(self, model_or_function, input_data, labels=None, batch_size=1
136136
Explanation heatmap for each class (np.ndarray).
137137
"""
138138
# convert data to xarray
139-
input_data = utils.to_xarray(input_data, self.axes_labels, RISE.required_labels)
139+
input_data = utils.to_xarray(input_data, self.axis_labels, RISE.required_labels)
140140
# batch axis should always be first
141141
input_data = utils.move_axis(input_data, 'batch', 0)
142142
input_data, full_preprocess_function = self._prepare_image_data(input_data)

dianna/utils/misc.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,20 @@ def get_kwargs_applicable_to_function(function, kwargs):
3838
if key in function.__code__.co_varnames}
3939

4040

41-
def to_xarray(data, axes_labels, required_labels=None):
41+
def to_xarray(data, axis_labels, required_labels=None):
4242
"""Converts numpy data and axes labels to an xarray object."""
43-
if isinstance(axes_labels, dict):
43+
if isinstance(axis_labels, dict):
4444
# key = axis index, value = label
4545
# not all axes have to be present in the input, but we need to provide
4646
# a name for each axis
4747
# first ensure negative indices are converted to positive ones
48-
indices = list(axes_labels.keys())
48+
indices = list(axis_labels.keys())
4949
for index in indices:
5050
if index < 0:
51-
axes_labels[data.ndim + index] = axes_labels.pop(index)
52-
labels = [axes_labels[index] if index in axes_labels else f'dim_{index}' for index in range(data.ndim)]
51+
axis_labels[data.ndim + index] = axis_labels.pop(index)
52+
labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(data.ndim)]
5353
else:
54-
labels = list(axes_labels)
54+
labels = list(axis_labels)
5555

5656
# check if the required labels are present
5757
if required_labels is not None:

tests/test_common_usage.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66

77
input_data = np.random.random((1, 224, 224, 3))
8-
axes_labels = {0: 'batch', -1: 'channels'}
8+
axis_labels = {0: 'batch', -1: 'channels'}
99

1010

1111
def test_common_RISE_pipeline(): # noqa: N802 ignore case
12-
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels)[0]
12+
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels)[0]
1313
dianna.visualization.plot_image(heatmap, show_plot=False)
1414
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)

tests/test_lime.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_lime_function(self):
1717
labels = ('batch', 'y', 'x', 'channels')
1818
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')
1919

20-
explainer = LIME(random_state=42, axes_labels=labels)
20+
explainer = LIME(random_state=42, axis_labels=labels)
2121
heatmap = explainer.explain_image(run_model, input_data, num_samples=100)
2222

2323
assert heatmap[0].shape == input_data[0].shape[:2]
@@ -38,7 +38,7 @@ def preprocess(data):
3838
labels = ('batch', 'channels', 'y', 'x')
3939

4040
heatmap = dianna.explain_image(model_filename, input_data, method="LIME", preprocess_function=preprocess, random_state=42,
41-
axes_labels=labels)
41+
axis_labels=labels)
4242

4343
heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
4444
assert heatmap[0].shape == input_data[0, 0].shape

tests/test_rise.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def test_rise_function(self):
1616
"""Test if rise runs and outputs the correct shape given some data and a model function."""
1717
input_data = np.random.random((1, 224, 224, 3))
1818
# y and x axis labels are not actually mandatory for this test
19-
axes_labels = ['batch', 'y', 'x', 'channels']
19+
axis_labels = ['batch', 'y', 'x', 'channels']
2020

21-
heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
21+
heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)
2222

2323
assert heatmaps[0].shape == input_data[0].shape[:2]
2424

@@ -27,9 +27,9 @@ def test_rise_filename(self):
2727
model_filename = 'tests/test_data/mnist_model.onnx'
2828
input_data = generate_data(batch_size=1).astype(np.float32)
2929
# y and x axis labels are not actually mandatory for this test
30-
axes_labels = ['batch', 'channels', 'y', 'x']
30+
axis_labels = ['batch', 'channels', 'y', 'x']
3131

32-
heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
32+
heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)
3333

3434
assert heatmaps[0].shape == input_data[0].shape[1:]
3535

tutorials/rise_mnist.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@
155155
}
156156
],
157157
"source": [
158-
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axes_labels=axis_labels)\n",
158+
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axis_labels=axis_labels)\n",
159159
"heatmaps = explainer.explain_image(run_model, X_test[[i_instance]])"
160160
]
161161
},

0 commit comments

Comments
 (0)