Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

make label a required argument (Fixes #131) #426

Merged
merged 4 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dashboard/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def global_store_i(method_sel, model_path, image_test, labels=list(range(2)),
elif method_sel == "KernelSHAP":
relevances = dianna.explain_image(
model_path, image_test,
labels=labels,
method=method_sel, nsamples=n_samples,
background=background, n_segments=n_segments, sigma=sigma,
axis_labels=axis_labels)

else:
relevances = dianna.explain_image(
model_path, image_test * 256, 'LIME',
Expand Down
8 changes: 4 additions & 4 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
__version__ = "0.6.0"


def explain_image(model_or_function, input_data, method, labels=(1,), **kwargs):
def explain_image(model_or_function, input_data, method, labels, **kwargs):
"""
Explain an image (input_data) given a model and a chosen method.

Expand All @@ -42,7 +42,7 @@ def explain_image(model_or_function, input_data, method, labels=(1,), **kwargs):
the path to a ONNX model on disk.
input_data (np.ndarray): Image data to be explained
method (string): One of the supported methods: RISE, LIME or KernelSHAP
labels (tuple): Labels to be explained
labels (Iterable(int)): Labels to be explained

Returns:
One heatmap (2D array) per class.
Expand All @@ -56,7 +56,7 @@ def explain_image(model_or_function, input_data, method, labels=(1,), **kwargs):
return explainer.explain(model_or_function, input_data, labels, **explain_image_kwargs)


def explain_text(model_or_function, input_text, tokenizer, method, labels=(1,), **kwargs):
def explain_text(model_or_function, input_text, tokenizer, method, labels, **kwargs):
"""
Explain text (input_text) given a model and a chosen method.

Expand All @@ -66,7 +66,7 @@ def explain_text(model_or_function, input_text, tokenizer, method, labels=(1,),
input_text (string): Text to be explained
tokenizer : Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute
method (string): One of the supported methods: RISE or LIME
labels (tuple): Labels to be explained
labels (Iterable(int)): Labels to be explained

Returns:
List of (word, index of word in raw text, importance for target class) tuples.
Expand Down
4 changes: 2 additions & 2 deletions dianna/methods/kernelshap.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def explain(
self,
model,
input_data,
labels=(0,),
labels,
nsamples="auto",
background=None,
n_segments=100,
Expand All @@ -84,7 +84,7 @@ def explain(
example. The input dimension must be
[batch, height, width, color_channels] or
[batch, color_channels, height, width] (see axis_labels)
labels (tuple): Indices of classes to be explained
labels (Iterable(int)): Indices of classes to be explained
nsamples ("auto" or int): Number of times to re-evaluate the model when
explaining each prediction. More samples lead
to lower variance estimates of the SHAP values.
Expand Down
8 changes: 4 additions & 4 deletions dianna/methods/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self,
def explain(self,
model_or_function,
input_text,
labels=(0,),
labels,
tokenizer=None,
top_labels=None,
num_features=10,
Expand All @@ -67,7 +67,7 @@ def explain(self,
the path to a ONNX model on disk.
tokenizer : Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute
input_text (np.ndarray): Data to be explained
labels ([int], optional): Iterable of indices of class to be explained
labels (Iterable(int)): Iterable of indices of class to be explained

Other keyword arguments: see the LIME documentation for LimeTextExplainer.explain_instance:
https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_text.LimeTextExplainer.explain_instance.
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(self,
def explain(self,
model_or_function,
input_data,
labels=(1,),
labels,
top_labels=None,
num_features=10,
num_samples=5000,
Expand All @@ -161,7 +161,7 @@ def explain(self,
the path to a ONNX model on disk.
input_data (np.ndarray): Data to be explained. Must be an "RGB image", i.e. with values in
the [0,255] range.
labels (tuple): Indices of classes to be explained
labels (Iterable(int)): Indices of classes to be explained
Other keyword arguments: see the LIME documentation for LimeImageExplainer.explain_instance and
ImageExplanation.get_image_and_mask:

Expand Down
8 changes: 4 additions & 4 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=None,
self.masks = None
self.predictions = None

def explain(self, model_or_function, input_text, labels=(0,), tokenizer=None, batch_size=100):
def explain(self, model_or_function, input_text, labels, tokenizer=None, batch_size=100):
"""Runs the RISE explainer on text.

The model will be called with masked versions of the input text.
Expand All @@ -54,7 +54,7 @@ def explain(self, model_or_function, input_text, labels=(0,), tokenizer=None, ba
the path to a ONNX model on disk.
input_text (np.ndarray): Text to be explained
tokenizer: Tokenizer class with tokenize and convert_tokens_to_string methods, and mask_token attribute
labels (list(int)): Labels to be explained
labels (Iterable(int)): Labels to be explained
batch_size (int): Batch size to use for running the model.

Returns:
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=None,
self.predictions = None
self.axis_labels = axis_labels if axis_labels is not None else []

def explain(self, model_or_function, input_data, labels=None, batch_size=100):
def explain(self, model_or_function, input_data, labels, batch_size=100):
"""Runs the RISE explainer on images.

The model will be called with masked images,
Expand All @@ -158,7 +158,7 @@ def explain(self, model_or_function, input_data, labels=None, batch_size=100):
the path to a ONNX model on disk.
input_data (np.ndarray): Image to be explained
batch_size (int): Batch size to use for running the model.
labels (tuple): Labels to be explained
labels (Iterable(int)): Labels to be explained

Returns:
Explanation heatmap for each class (np.ndarray).
Expand Down
5 changes: 3 additions & 2 deletions tests/test_common_usage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import numpy as np

import dianna
import dianna.visualization
from tests.utils import run_model


input_data = np.random.random((224, 224, 3))
axis_labels = {-1: 'channels'}
labels = [0, 1]


def test_common_RISE_pipeline(): # noqa: N802 ignore case
heatmap = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels)[0]
heatmap = dianna.explain_image(run_model, input_data, "RISE", labels, axis_labels=axis_labels)[0]
dianna.visualization.plot_image(heatmap, show_plot=False)
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)
2 changes: 2 additions & 0 deletions tests/test_kernelshap.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def test_shap_explain_image(self):
onnx_model_path = "./tests/test_data/mnist_model.onnx"
n_segments = 50
explainer = KERNELSHAPImage()
labels = [0]
shap_values, _ = explainer.explain(
onnx_model_path,
input_data,
labels,
nsamples=1000,
background=0,
n_segments=n_segments,
Expand Down
10 changes: 6 additions & 4 deletions tests/test_lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def test_lime_function(self):
np.random.seed(42)
input_data = np.random.random((224, 224, 3))
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy')
labels = [1]

explainer = LIMEImage(random_state=42)
heatmap = explainer.explain(run_model, input_data, num_samples=100)
heatmap = explainer.explain(run_model, input_data, labels, num_samples=100)

assert heatmap[0].shape == input_data.shape[:2]
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)
Expand All @@ -27,10 +28,11 @@ def test_lime_filename(self):
np.random.seed(42)
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1)[0].astype(np.float32)
labels = ('channels', 'y', 'x')
axis_labels = ('channels', 'y', 'x')
labels = [1]

heatmap = dianna.explain_image(model_filename, input_data, method="LIME", random_state=42,
axis_labels=labels)
heatmap = dianna.explain_image(model_filename, input_data, method="LIME", labels=labels, random_state=42,
axis_labels=axis_labels)

heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy')
assert heatmap[0].shape == input_data[0].shape
Expand Down
11 changes: 7 additions & 4 deletions tests/test_rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,25 @@ def test_rise_function(self):
"""Test if rise runs and outputs the correct shape given some data and a model function."""
input_data = np.random.random((224, 224, 3))
axis_labels = ['y', 'x', 'channels']

heatmaps = dianna.explain_image(run_model, input_data, method="RISE", axis_labels=axis_labels, n_masks=200, p_keep=.5)
labels = [1]
heatmaps_expected = np.load('tests/test_data/heatmap_rise_function.npy')

heatmaps = dianna.explain_image(run_model, input_data, "RISE", labels, axis_labels=axis_labels, n_masks=200, p_keep=.5)

assert heatmaps[0].shape == input_data.shape[:2]
assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5)

def test_rise_filename(self):
"""Test if rise runs and outputs the correct shape given some data and a model file."""
model_filename = 'tests/test_data/mnist_model.onnx'
input_data = generate_data(batch_size=1).astype(np.float32)[0]

heatmaps = dianna.explain_image(model_filename, input_data, method="RISE", n_masks=200, p_keep=.5)
heatmaps_expected = np.load('tests/test_data/heatmap_rise_filename.npy')
labels = [1]

heatmaps = dianna.explain_image(model_filename, input_data, "RISE", labels, n_masks=200, p_keep=.5)

assert heatmaps[0].shape == input_data.shape[1:]
print(heatmaps_expected.shape)
assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5)

def test_rise_determine_p_keep_for_images(self):
Expand Down
Loading