diff --git a/dianna/methods/lime.py b/dianna/methods/lime.py index 6d7489fa..46a9761a 100644 --- a/dianna/methods/lime.py +++ b/dianna/methods/lime.py @@ -20,7 +20,7 @@ def __init__(self, mask_string=None, random_state=None, char_level=False, - axes_labels=None, + axis_labels=None, preprocess_function=None, ): # pylint: disable=too-many-arguments """ @@ -37,7 +37,7 @@ def __init__(self, mask_string (str, optional): mask string random_state (int or np.RandomState, optional): seed or random state char_level (bool, optional): char level - axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name. + axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name. If a list, the name of each axis where the index in the list is the axis index preprocess_function (callable, optional): Function to preprocess input data with @@ -62,7 +62,7 @@ def __init__(self, ) self.preprocess_function = preprocess_function - self.axes_labels = axes_labels if axes_labels is not None else [] + self.axis_labels = axis_labels if axis_labels is not None else [] def explain_text(self, model_or_function, @@ -165,7 +165,7 @@ def _prepare_image_data(self, input_data): Returns: transformed input data, preprocessing function to use with utils.get_function() """ - input_data = utils.to_xarray(input_data, self.axes_labels, LIME.required_labels) + input_data = utils.to_xarray(input_data, self.axis_labels, LIME.required_labels) # remove batch axis from input data; this is only here for a consistent API # but LIME wants data without batch axis if not len(input_data['batch']) == 1: diff --git a/dianna/methods/rise.py b/dianna/methods/rise.py index e53e27be..3e45f44c 100644 --- a/dianna/methods/rise.py +++ b/dianna/methods/rise.py @@ -19,14 +19,14 @@ class RISE: required_labels = ('batch', 'channels') def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=too-many-arguments - axes_labels=None, preprocess_function=None): + axis_labels=None, preprocess_function=None): """RISE initializer. Args: n_masks (int): Number of masks to generate. feature_res (int): Resolution of features in masks. p_keep (float): Fraction of image to keep in each mask - axes_labels (dict/list, optional): If a dict, key,value pairs of axis index, name. + axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name. If a list, the name of each axis where the index in the list is the axis index preprocess_function (callable, optional): Function to preprocess input data with @@ -37,7 +37,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5, # pylint: disable=t self.preprocess_function = preprocess_function self.masks = None self.predictions = None - self.axes_labels = axes_labels if axes_labels is not None else [] + self.axis_labels = axis_labels if axis_labels is not None else [] def explain_text(self, model_or_function, input_text, labels=(0,), batch_size=100): """Runs the RISE explainer on text. @@ -135,7 +135,7 @@ def explain_image(self, model_or_function, input_data, batch_size=100): Explanation heatmap for each class (np.ndarray). """ # convert data to xarray - input_data = utils.to_xarray(input_data, self.axes_labels, RISE.required_labels) + input_data = utils.to_xarray(input_data, self.axis_labels, RISE.required_labels) # batch axis should always be first input_data = utils.move_axis(input_data, 'batch', 0) input_data, full_preprocess_function = self._prepare_image_data(input_data) diff --git a/dianna/utils/misc.py b/dianna/utils/misc.py index 25626d10..97fa56db 100644 --- a/dianna/utils/misc.py +++ b/dianna/utils/misc.py @@ -37,20 +37,20 @@ def get_kwargs_applicable_to_function(function, kwargs): if key in function.__code__.co_varnames} -def to_xarray(data, axes_labels, required_labels=None): +def to_xarray(data, axis_labels, required_labels=None): """Converts numpy data and axes labels to an xarray object.""" - if isinstance(axes_labels, dict): + if isinstance(axis_labels, dict): # key = axis index, value = label # not all axes have to be present in the input, but we need to provide # a name for each axis # first ensure negative indices are converted to positive ones - indices = list(axes_labels.keys()) + indices = list(axis_labels.keys()) for index in indices: if index < 0: - axes_labels[data.ndim + index] = axes_labels.pop(index) - labels = [axes_labels[index] if index in axes_labels else f'dim_{index}' for index in range(data.ndim)] + axis_labels[data.ndim + index] = axis_labels.pop(index) + labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(data.ndim)] else: - labels = list(axes_labels) + labels = list(axis_labels) # check if the required labels are present if required_labels is not None: diff --git a/tests/test_common_usage.py b/tests/test_common_usage.py index 5b3591fa..0bc07984 100644 --- a/tests/test_common_usage.py +++ b/tests/test_common_usage.py @@ -5,10 +5,10 @@ input_data = np.random.random((1, 224, 224, 3)) -axes_labels = {0: 'batch', -1: 'channels'} +axis_labels = {0: 'batch', -1: 'channels'} def test_common_RISE_pipeline(): # noqa: N802 ignore case - heatmap = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels)[0] + heatmap = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels)[0] dianna.visualization.plot_image(heatmap, show_plot=False) dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False) diff --git a/tests/test_lime.py b/tests/test_lime.py index 5a8706df..78676b3d 100644 --- a/tests/test_lime.py +++ b/tests/test_lime.py @@ -17,7 +17,7 @@ def test_lime_function(self): labels = ('batch', 'y', 'x', 'channels') heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy') - explainer = LIME(random_state=42, axes_labels=labels) + explainer = LIME(random_state=42, axis_labels=labels) heatmap = explainer.explain_image(run_model, input_data, num_samples=100) assert heatmap.shape == input_data[0].shape[:2] @@ -38,7 +38,7 @@ def preprocess(data): labels = ('batch', 'channels', 'y', 'x') heatmap = dianna.explain_image(model_filename, input_data, method="LIME", preprocess_function=preprocess, random_state=42, - axes_labels=labels) + axis_labels=labels) heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy') assert heatmap.shape == input_data[0, 0].shape diff --git a/tests/test_rise.py b/tests/test_rise.py index 71df5810..7bc3cad0 100644 --- a/tests/test_rise.py +++ b/tests/test_rise.py @@ -16,9 +16,9 @@ def test_rise_function(self): """Test if rise runs and outputs the correct shape given some data and a model function.""" input_data = np.random.random((1, 224, 224, 3)) # y and x axis labels are not actually mandatory for this test - axes_labels = ['batch', 'y', 'x', 'channels'] + axis_labels = ['batch', 'y', 'x', 'channels'] - heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axes_labels=axes_labels, n_masks=200) + heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels, n_masks=200) assert heatmaps[0].shape == input_data[0].shape[:2] @@ -27,9 +27,9 @@ def test_rise_filename(self): model_filename = 'tests/test_data/mnist_model.onnx' input_data = generate_data(batch_size=1).astype(np.float32) # y and x axis labels are not actually mandatory for this test - axes_labels = ['batch', 'channels', 'y', 'x'] + axis_labels = ['batch', 'channels', 'y', 'x'] - heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axes_labels=axes_labels, n_masks=200) + heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", axis_labels=axis_labels, n_masks=200) assert heatmaps[0].shape == input_data[0].shape[1:] diff --git a/tutorials/rise_mnist.ipynb b/tutorials/rise_mnist.ipynb index 2f45bffe..6ede5f90 100644 --- a/tutorials/rise_mnist.ipynb +++ b/tutorials/rise_mnist.ipynb @@ -155,7 +155,7 @@ } ], "source": [ - "explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axes_labels=axis_labels)\n", + "explainer = RISE(n_masks=2000, feature_res=8, p_keep=.8, axis_labels=axis_labels)\n", "heatmaps = explainer.explain_image(run_model, X_test[[i_instance]])" ] },