Skip to content

Commit 4a3d385

Browse files
committed
Rename axes_labels parameter to axis_labels
1 parent 7772f73 commit 4a3d385

File tree

7 files changed

+23
-23
lines changed

7 files changed

+23
-23
lines changed

dianna/methods/lime.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self,
2020
mask_string=None,
2121
random_state=None,
2222
char_level=False,
23-
axes_labels=None,
23+
axis_labels=None,
2424
preprocess_function=None,
2525
): # pylint: disable=too-many-arguments
2626
"""
@@ -37,7 +37,7 @@ def __init__(self,
3737
mask_string (str, optional): mask string
3838
random_state (int or np.RandomState, optional): seed or random state
3939
char_level (bool, optional): char level
40-
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
40+
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
4141
If a list, the name of each axis where the index
4242
in the list is the axis index
4343
preprocess_function (callable, optional): Function to preprocess input data with
@@ -62,7 +62,7 @@ def __init__(self,
6262
)
6363

6464
self.preprocess_function = preprocess_function
65-
self.axes_labels = axes_labels if axes_labels is not None else []
65+
self.axis_labels = axis_labels if axis_labels is not None else []
6666

6767
def explain_text(self,
6868
model_or_function,
@@ -165,7 +165,7 @@ def _prepare_image_data(self, input_data):
165165
Returns:
166166
transformed input data, preprocessing function to use with utils.get_function()
167167
"""
168-
input_data = utils.to_xarray(input_data, self.axes_labels, LIME.required_labels)
168+
input_data = utils.to_xarray(input_data, self.axis_labels, LIME.required_labels)
169169
# remove batch axis from input data; this is only here for a consistent API
170170
# but LIME wants data without batch axis
171171
if not len(input_data['batch']) == 1:

dianna/methods/rise.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class RISE:
1919
required_labels = ('batch', 'channels')
2020

2121
def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=too-many-arguments
22-
axes_labels=None, preprocess_function=None):
22+
axis_labels=None, preprocess_function=None):
2323
"""RISE initializer.
2424
2525
Args:
2626
n_masks (int): Number of masks to generate.
2727
feature_res (int): Resolution of features in masks.
2828
p_keep (float): Fraction of image to keep in each mask
29-
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
29+
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
3030
If a list, the name of each axis where the index
3131
in the list is the axis index
3232
preprocess_function (callable, optional): Function to preprocess input data with
@@ -37,7 +37,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=t
3737
self.preprocess_function = preprocess_function
3838
self.masks = None
3939
self.predictions = None
40-
self.axes_labels = axes_labels if axes_labels is not None else []
40+
self.axis_labels = axis_labels if axis_labels is not None else []
4141

4242
def explain_text(self, model_or_function, input_text, labels=(0,), batch_size=100):
4343
"""Runs the RISE explainer on text.
@@ -135,7 +135,7 @@ def explain_image(self, model_or_function, input_data, batch_size=100):
135135
Explanation heatmap for each class (np.ndarray).
136136
"""
137137
# convert data to xarray
138-
input_data = utils.to_xarray(input_data, self.axes_labels, RISE.required_labels)
138+
input_data = utils.to_xarray(input_data, self.axis_labels, RISE.required_labels)
139139
# batch axis should always be first
140140
input_data = utils.move_axis(input_data, 'batch', 0)
141141
input_data, full_preprocess_function = self._prepare_image_data(input_data)

dianna/utils/misc.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,20 @@ def get_kwargs_applicable_to_function(function, kwargs):
3737
if key in function.__code__.co_varnames}
3838

3939

40-
def to_xarray(data, axes_labels, required_labels=None):
40+
def to_xarray(data, axis_labels, required_labels=None):
4141
"""Converts numpy data and axes labels to an xarray object."""
42-
if isinstance(axes_labels, dict):
42+
if isinstance(axis_labels, dict):
4343
# key = axis index, value = label
4444
# not all axes have to be present in the input, but we need to provide
4545
# a name for each axis
4646
# first ensure negative indices are converted to positive ones
47-
indices = list(axes_labels.keys())
47+
indices = list(axis_labels.keys())
4848
for index in indices:
4949
if index < 0:
50-
axes_labels[data.ndim + index] = axes_labels.pop(index)
51-
labels = [axes_labels[index] if index in axes_labels else f'dim_{index}' for index in range(data.ndim)]
50+
axis_labels[data.ndim + index] = axis_labels.pop(index)
51+
labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(data.ndim)]
5252
else:
53-
labels = list(axes_labels)
53+
labels = list(axis_labels)
5454

5555
# check if the required labels are present
5656
if required_labels is not None:

tests/test_common_usage.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66

77
input_data = np.random.random((1, 224, 224, 3))
8-
axes_labels = {0: 'batch', -1: 'channels'}
8+
axis_labels = {0: 'batch', -1: 'channels'}
99

1010

1111
def test_common_RISE_pipeline(): # noqa: N802 ignore case
12-
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels)[0]
12+
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels)[0]
1313
dianna.visualization.plot_image(heatmap, show_plot=False)
1414
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)

tests/test_lime.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_lime_function(self):
1717
labels = ('batch', 'y', 'x', 'channels')
1818
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')
1919

20-
explainer = LIME(random_state=42, axes_labels=labels)
20+
explainer = LIME(random_state=42, axis_labels=labels)
2121
heatmap = explainer.explain_image(run_model, input_data, num_samples=100)
2222

2323
assert heatmap.shape == input_data[0].shape[:2]
@@ -38,7 +38,7 @@ def preprocess(data):
3838
labels = ('batch', 'channels', 'y', 'x')
3939

4040
heatmap = dianna.explain_image(model_filename, input_data, method="LIME", preprocess_function=preprocess, random_state=42,
41-
axes_labels=labels)
41+
axis_labels=labels)
4242

4343
heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
4444
assert heatmap.shape == input_data[0, 0].shape

tests/test_rise.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def test_rise_function(self):
1616
"""Test if rise runs and outputs the correct shape given some data and a model function."""
1717
input_data = np.random.random((1, 224, 224, 3))
1818
# y and x axis labels are not actually mandatory for this test
19-
axes_labels = ['batch', 'y', 'x', 'channels']
19+
axis_labels = ['batch', 'y', 'x', 'channels']
2020

21-
heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
21+
heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)
2222

2323
assert heatmaps[0].shape == input_data[0].shape[:2]
2424

@@ -27,9 +27,9 @@ def test_rise_filename(self):
2727
model_filename = 'tests/test_data/mnist_model.onnx'
2828
input_data = generate_data(batch_size=1).astype(np.float32)
2929
# y and x axis labels are not actually mandatory for this test
30-
axes_labels = ['batch', 'channels', 'y', 'x']
30+
axis_labels = ['batch', 'channels', 'y', 'x']
3131

32-
heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
32+
heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)
3333

3434
assert heatmaps[0].shape == input_data[0].shape[1:]
3535

tutorials/rise_mnist.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@
155155
}
156156
],
157157
"source": [
158-
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axes_labels=axis_labels)\n",
158+
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axis_labels=axis_labels)\n",
159159
"heatmaps = explainer.explain_image(run_model, X_test[[i_instance]])"
160160
]
161161
},

0 commit comments

Comments
 (0)