Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename axes_labels parameter to axis_labels #109

Merged
merged 1 commit into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dianna/methods/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self,
mask_string=None,
random_state=None,
char_level=False,
axes_labels=None,
axis_labels=None,
preprocess_function=None,
): # pylint: disable=too-many-arguments
"""
Expand All @@ -37,7 +37,7 @@ def __init__(self,
mask_string (str, optional): mask string
random_state (int or np.RandomState, optional): seed or random state
char_level (bool, optional): char level
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
If a list, the name of each axis where the index
in the list is the axis index
preprocess_function (callable, optional): Function to preprocess input data with
Expand All @@ -62,7 +62,7 @@ def __init__(self,
)

self.preprocess_function = preprocess_function
self.axes_labels = axes_labels if axes_labels is not None else []
self.axis_labels = axis_labels if axis_labels is not None else []

def explain_text(self,
model_or_function,
Expand Down Expand Up @@ -165,7 +165,7 @@ def _prepare_image_data(self, input_data):
Returns:
transformed input data, preprocessing function to use with utils.get_function()
"""
input_data = utils.to_xarray(input_data, self.axes_labels, LIME.required_labels)
input_data = utils.to_xarray(input_data, self.axis_labels, LIME.required_labels)
# remove batch axis from input data; this is only here for a consistent API
# but LIME wants data without batch axis
if not len(input_data['batch']) == 1:
Expand Down
8 changes: 4 additions & 4 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ class RISE:
required_labels = ('batch', 'channels')

def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=too-many-arguments
axes_labels=None, preprocess_function=None):
axis_labels=None, preprocess_function=None):
"""RISE initializer.

Args:
n_masks (int): Number of masks to generate.
feature_res (int): Resolution of features in masks.
p_keep (float): Fraction of image to keep in each mask
axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
If a list, the name of each axis where the index
in the list is the axis index
preprocess_function (callable, optional): Function to preprocess input data with
Expand All @@ -37,7 +37,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=t
self.preprocess_function = preprocess_function
self.masks = None
self.predictions = None
self.axes_labels = axes_labels if axes_labels is not None else []
self.axis_labels = axis_labels if axis_labels is not None else []

def explain_text(self, model_or_function, input_text, labels=(0,), batch_size=100):
"""Runs the RISE explainer on text.
Expand Down Expand Up @@ -135,7 +135,7 @@ def explain_image(self, model_or_function, input_data, batch_size=100):
Explanation heatmap for each class (np.ndarray).
"""
# convert data to xarray
input_data = utils.to_xarray(input_data, self.axes_labels, RISE.required_labels)
input_data = utils.to_xarray(input_data, self.axis_labels, RISE.required_labels)
# batch axis should always be first
input_data = utils.move_axis(input_data, 'batch', 0)
input_data, full_preprocess_function = self._prepare_image_data(input_data)
Expand Down
12 changes: 6 additions & 6 deletions dianna/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@ def get_kwargs_applicable_to_function(function, kwargs):
if key in function.__code__.co_varnames}


def to_xarray(data, axes_labels, required_labels=None):
def to_xarray(data, axis_labels, required_labels=None):
"""Converts numpy data and axes labels to an xarray object."""
if isinstance(axes_labels, dict):
if isinstance(axis_labels, dict):
# key = axis index, value = label
# not all axes have to be present in the input, but we need to provide
# a name for each axis
# first ensure negative indices are converted to positive ones
indices = list(axes_labels.keys())
indices = list(axis_labels.keys())
for index in indices:
if index < 0:
axes_labels[data.ndim + index] = axes_labels.pop(index)
labels = [axes_labels[index] if index in axes_labels else f'dim_{index}' for index in range(data.ndim)]
axis_labels[data.ndim + index] = axis_labels.pop(index)
labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(data.ndim)]
else:
labels = list(axes_labels)
labels = list(axis_labels)

# check if the required labels are present
if required_labels is not None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@


input_data = np.random.random((1, 224, 224, 3))
axes_labels = {0: 'batch', -1: 'channels'}
axis_labels = {0: 'batch', -1: 'channels'}


def test_common_RISE_pipeline(): # noqa: N802 ignore case
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels)[0]
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels)[0]
dianna.visualization.plot_image(heatmap, show_plot=False)
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)
4 changes: 2 additions & 2 deletions tests/test_lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_lime_function(self):
labels = ('batch', 'y', 'x', 'channels')
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')

explainer = LIME(random_state=42, axes_labels=labels)
explainer = LIME(random_state=42, axis_labels=labels)
heatmap = explainer.explain_image(run_model, input_data, num_samples=100)

assert heatmap.shape == input_data[0].shape[:2]
Expand All @@ -38,7 +38,7 @@ def preprocess(data):
labels = ('batch', 'channels', 'y', 'x')

heatmap = dianna.explain_image(model_filename, input_data, method="LIME", preprocess_function=preprocess, random_state=42,
axes_labels=labels)
axis_labels=labels)

heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
assert heatmap.shape == input_data[0, 0].shape
Expand Down
8 changes: 4 additions & 4 deletions tests/test_rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def test_rise_function(self):
"""Test if rise runs and outputs the correct shape given some data and a model function."""
input_data = np.random.random((1, 224, 224, 3))
# y and x axis labels are not actually mandatory for this test
axes_labels = ['batch', 'y', 'x', 'channels']
axis_labels = ['batch', 'y', 'x', 'channels']

heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)

assert heatmaps[0].shape == input_data[0].shape[:2]

Expand All @@ -27,9 +27,9 @@ def test_rise_filename(self):
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1).astype(np.float32)
# y and x axis labels are not actually mandatory for this test
axes_labels = ['batch', 'channels', 'y', 'x']
axis_labels = ['batch', 'channels', 'y', 'x']

heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axes_labels=axes_labels, n_masks=200)
heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axis_labels=axis_labels, n_masks=200)

assert heatmaps[0].shape == input_data[0].shape[1:]

Expand Down
2 changes: 1 addition & 1 deletion tutorials/rise_mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
}
],
"source": [
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axes_labels=axis_labels)\n",
"explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axis_labels=axis_labels)\n",
"heatmaps = explainer.explain_image(run_model, X_test[[i_instance]])"
]
},
Expand Down