diff --git a/dianna/__init__.py b/dianna/__init__.py
index 8070944d..5e9d0c42 100644
--- a/dianna/__init__.py
+++ b/dianna/__init__.py
@@ -131,7 +131,7 @@ def explain_text(model_or_function: Union[Callable,
def explain_tabular(model_or_function: Union[Callable, str],
input_tabular: np.ndarray,
method: str,
- labels=(1, ),
+ labels=None,
**kwargs) -> np.ndarray:
"""Explain tabular (input_text) given a model and a chosen method.
diff --git a/dianna/methods/kernelshap_tabular.py b/dianna/methods/kernelshap_tabular.py
index 336fadd8..19b8125c 100644
--- a/dianna/methods/kernelshap_tabular.py
+++ b/dianna/methods/kernelshap_tabular.py
@@ -34,7 +34,8 @@ def __init__(
weighted kmeans
"""
if training_data_kmeans:
- self.training_data = shap.kmeans(training_data, training_data_kmeans)
+ self.training_data = shap.kmeans(training_data,
+ training_data_kmeans)
else:
self.training_data = training_data
self.feature_names = feature_names
@@ -65,17 +66,15 @@ def explain(
An array (np.ndarray) containing the KernelExplainer explanations for each class.
"""
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
- KernelExplainer, kwargs
- )
- self.explainer = KernelExplainer(
- model_or_function, self.training_data, link, **init_instance_kwargs
- )
+ KernelExplainer, kwargs)
+ self.explainer = KernelExplainer(model_or_function, self.training_data,
+ link, **init_instance_kwargs)
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
- self.explainer.shap_values, kwargs
- )
+ self.explainer.shap_values, kwargs)
- saliency = self.explainer.shap_values(input_tabular, **explain_instance_kwargs)
+ saliency = self.explainer.shap_values(input_tabular,
+ **explain_instance_kwargs)
if self.mode == 'regression':
saliency = saliency[0]
diff --git a/dianna/methods/lime_tabular.py b/dianna/methods/lime_tabular.py
index 23db971a..f6b4b1fd 100644
--- a/dianna/methods/lime_tabular.py
+++ b/dianna/methods/lime_tabular.py
@@ -1,6 +1,8 @@
"""LIME tabular explainer."""
+import sys
from typing import Iterable
from typing import List
+from typing import Optional
from typing import Union
import numpy as np
from lime.lime_tabular import LimeTabularExplainer
@@ -58,12 +60,10 @@ def __init__(
"""
self.mode = mode
init_instance_kwargs = utils.get_kwargs_applicable_to_function(
- LimeTabularExplainer, kwargs
- )
+ LimeTabularExplainer, kwargs)
# temporary solution for setting num_features and top_labels
self.num_features = len(feature_names)
- self.top_labels = len(class_names)
self.explainer = LimeTabularExplainer(
training_data,
@@ -83,7 +83,7 @@ def explain(
self,
model_or_function: Union[str, callable],
input_tabular: np.array,
- labels: Iterable[int] = (1,),
+ labels: Optional[Iterable[int]] = None,
num_samples: int = 5000,
**kwargs,
) -> np.array:
@@ -93,7 +93,7 @@ def explain(
model_or_function (callable or str): The function that runs the model to be explained
or the path to a ONNX model on disk.
input_tabular (np.ndarray): Data to be explained.
- labels (Iterable(int), optional): Indices of classes to be explained.
+ labels (Iterable(int), optional): Indices of classes to be explained. Defaults to None.
num_samples (int, optional): Number of samples
kwargs: These parameters are passed on
@@ -105,15 +105,14 @@ def explain(
"""
# run the explanation.
explain_instance_kwargs = utils.get_kwargs_applicable_to_function(
- self.explainer.explain_instance, kwargs
- )
+ self.explainer.explain_instance, kwargs)
runner = utils.get_function(model_or_function)
explanation = self.explainer.explain_instance(
input_tabular,
runner,
labels=labels,
- top_labels=self.top_labels,
+ top_labels=sys.maxsize,
num_features=self.num_features,
num_samples=num_samples,
**explain_instance_kwargs,
@@ -126,10 +125,13 @@ def explain(
elif self.mode == 'classification':
# extract scores from lime explainer
saliency = []
- for i in range(self.top_labels):
+ for i in range(len(explanation.local_exp.items())):
local_exp = sorted(explanation.local_exp[i])
# shape of local_exp [(index, saliency)]
selected_saliency = [x[1] for x in local_exp]
saliency.append(selected_saliency[:])
+ else:
+ raise ValueError(f'Unsupported mode "{self.mode}"')
+
return np.array(saliency)
diff --git a/dianna/methods/lime_timeseries.py b/dianna/methods/lime_timeseries.py
index d1851e2c..d6b590d7 100644
--- a/dianna/methods/lime_timeseries.py
+++ b/dianna/methods/lime_timeseries.py
@@ -81,7 +81,7 @@ def explain(
# wrap up the input model or function using the runner
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)
- masks = generate_time_series_masks(input_timeseries,
+ masks = generate_time_series_masks(input_timeseries.shape,
num_samples,
p_keep=0.1)
# NOTE: Required by `lime_base` explainer since the first instance must be the original data
diff --git a/dianna/methods/rise_tabular.py b/dianna/methods/rise_tabular.py
new file mode 100644
index 00000000..37ddac15
--- /dev/null
+++ b/dianna/methods/rise_tabular.py
@@ -0,0 +1,111 @@
+"""RISE tabular explainer."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+import numpy as np
+from dianna import utils
+from dianna.utils.maskers import generate_tabular_masks
+from dianna.utils.maskers import mask_data_tabular
+from dianna.utils.predict import make_predictions
+from dianna.utils.rise_utils import normalize
+
+
+class RISETabular:
+ """RISE explainer for tabular data."""
+
+ def __init__(
+ self,
+ training_data: np.array,
+ mode: str = "classification",
+ feature_names: List[str] = None,
+ categorical_features: List[int] = None,
+ n_masks: int = 1000,
+ feature_res: int = 8,
+ p_keep: float = 0.5,
+ preprocess_function: Optional[callable] = None,
+ class_names=None,
+ keep_masks: bool = False,
+ keep_masked: bool = False,
+ keep_predictions: bool = False,
+ ) -> np.ndarray:
+ """RISE initializer.
+
+ Args:
+ n_masks: Number of masks to generate.
+ feature_res: Resolution of features in masks.
+ p_keep: Fraction of input data to keep in each mask (Default: 0.5).
+ preprocess_function: Function to preprocess input data with
+ categorical_features: list of categorical features
+ class_names: Names of the classes
+ feature_names: Names of the features
+ mode: Either classification or regression
+ training_data: Training data used for imputation of masked features
+ keep_masks: keep masks in memory for the user to inspect
+ keep_masked: keep masked data in memory for the user to inspect
+ keep_predictions: keep model predictions in memory for the user to inspect
+ """
+ self.training_data = training_data
+ self.n_masks = n_masks
+ self.feature_res = feature_res
+ self.p_keep = p_keep
+ self.preprocess_function = preprocess_function
+ self.masks = None
+ self.masked = None
+ self.predictions = None
+ self.keep_masks = keep_masks
+ self.keep_masked = keep_masked
+ self.keep_predictions = keep_predictions
+ self.mode = mode
+
+ def explain(
+ self,
+ model_or_function: Union[str, callable],
+ input_tabular: np.array,
+ labels: Optional[Iterable[int]] = None,
+ mask_type: Optional[Union[str, callable]] = 'most_frequent',
+ batch_size: Optional[int] = 100,
+ ) -> np.array:
+ """Run the RISE explainer.
+
+ Args:
+ model_or_function: The function that runs the model to be explained
+ or the path to a ONNX model on disk.
+ input_tabular: Data to be explained.
+ labels: Indices of classes to be explained. If None, all classes are
+ explained. Not used in regression mode.
+ mask_type: Imputation strategy for masked features
+ batch_size: Number of samples to process by the model per batch
+
+ Returns:
+ saliency: An array containing the RISE saliency scores for each selected class (first output only in regression mode).
+ """
+ # run the explanation.
+ runner = utils.get_function(model_or_function)
+
+ masks = np.stack(
+ list(
+ generate_tabular_masks(input_tabular.shape,
+ number_of_masks=self.n_masks,
+ p_keep=self.p_keep)))
+ self.masks = masks if self.keep_masks else None
+
+ masked = mask_data_tabular(input_tabular,
+ masks,
+ self.training_data,
+ mask_type=mask_type)
+ self.masked = masked if self.keep_masked else None
+ predictions = make_predictions(masked, runner, batch_size)
+ self.predictions = predictions if self.keep_predictions else None
+ n_labels = predictions.shape[1]
+
+ masks_reshaped = masks.reshape(self.n_masks, -1)
+
+ saliency = predictions.T.dot(masks_reshaped).reshape(
+ n_labels, *input_tabular.shape)
+
+ if self.mode == 'regression':
+ return saliency[0]
+
+ selected_saliency = saliency if labels is None else saliency[labels]
+ return normalize(selected_saliency, self.n_masks, self.p_keep)
diff --git a/dianna/methods/rise_timeseries.py b/dianna/methods/rise_timeseries.py
index 059d54d9..7479a8f0 100644
--- a/dianna/methods/rise_timeseries.py
+++ b/dianna/methods/rise_timeseries.py
@@ -68,7 +68,7 @@ def explain(self,
runner = utils.get_function(
model_or_function, preprocess_function=self.preprocess_function)
- masks = generate_time_series_masks(input_timeseries,
+ masks = generate_time_series_masks(input_timeseries.shape,
number_of_masks=self.n_masks,
feature_res=self.feature_res,
p_keep=self.p_keep)
diff --git a/dianna/utils/maskers.py b/dianna/utils/maskers.py
index 4c1bf916..5cb59a1a 100644
--- a/dianna/utils/maskers.py
+++ b/dianna/utils/maskers.py
@@ -5,10 +5,40 @@
import numpy as np
from numpy import ndarray
from skimage.transform import resize
+from sklearn.impute import SimpleImputer
+
+
+def generate_tabular_masks(
+ input_data_shape: tuple[int],
+ number_of_masks: int,
+ p_keep: float = 0.5,
+):
+ """Generator function to create masks for tabular data.
+
+ Args:
+ input_data_shape: Shape of the tabular data to be masked.
+ number_of_masks: Number of masks to generate.
+ p_keep: probability that any value should remain unmasked.
+
+ Yields:
+ A boolean mask with shape input_data_shape; number_of_masks masks are yielded in total.
+ """
+ instance_length = np.product(input_data_shape)
+
+ for i in range(number_of_masks):
+ n_masked = _determine_number_masked(p_keep, instance_length)
+ trues = n_masked * [False]
+ falses = (instance_length - n_masked) * [True]
+ options = trues + falses
+ yield np.random.choice(
+ a=options,
+ size=input_data_shape,
+ replace=False,
+ )
def generate_time_series_masks(
- input_data: np.ndarray,
+ input_data_shape: tuple[int],
number_of_masks: int,
feature_res: int = 8,
p_keep: float = 0.5,
@@ -26,7 +56,7 @@ def generate_time_series_masks(
For univariate data, only time step masks are returned.
Args:
- input_data: Timeseries data to be masked.
+ input_data_shape: Shape of the time series data to be masked.
number_of_masks: Number of masks to generate.
p_keep: the probability that any value remains unmasked.
feature_res: Resolution of features in masks.
@@ -34,8 +64,8 @@ def generate_time_series_masks(
Returns:
Single array containing all masks where the first dimension represents the batch.
"""
- if input_data.shape[-1] == 1: # univariate data
- return generate_time_step_masks(input_data,
+ if input_data_shape[-1] == 1: # univariate data
+ return generate_time_step_masks(input_data_shape,
number_of_masks,
p_keep,
number_of_features=feature_res)
@@ -45,30 +75,30 @@ def generate_time_series_masks(
number_of_time_step_masks = number_of_channel_masks
number_of_combined_masks = number_of_masks - number_of_time_step_masks - number_of_channel_masks
- time_step_masks = generate_time_step_masks(input_data,
+ time_step_masks = generate_time_step_masks(input_data_shape,
number_of_time_step_masks,
p_keep, feature_res)
- channel_masks = generate_channel_masks(input_data, number_of_channel_masks,
- p_keep)
+ channel_masks = generate_channel_masks(input_data_shape,
+ number_of_channel_masks, p_keep)
# Product of two masks: we need sqrt p_keep to ensure correct resulting p_keep
sqrt_p_keep = np.sqrt(p_keep)
combined_masks = generate_time_step_masks(
- input_data, number_of_combined_masks,
+ input_data_shape, number_of_combined_masks,
sqrt_p_keep, feature_res) * generate_channel_masks(
- input_data, number_of_combined_masks, sqrt_p_keep)
+ input_data_shape, number_of_combined_masks, sqrt_p_keep)
return np.concatenate([time_step_masks, channel_masks, combined_masks],
axis=0)
-def generate_channel_masks(input_data: np.ndarray, number_of_masks: int,
+def generate_channel_masks(input_data_shape: tuple[int], number_of_masks: int,
p_keep: float):
"""Generate masks that mask one or multiple channels independently at a time."""
- number_of_channels = input_data.shape[1]
+ number_of_channels = input_data_shape[1]
number_of_channels_masked = _determine_number_masked(
p_keep, number_of_channels)
- masked_data_shape = [number_of_masks] + list(input_data.shape)
+ masked_data_shape = [number_of_masks] + list(input_data_shape)
masks = np.ones(masked_data_shape, dtype=bool)
for i in range(number_of_masks):
channels_to_mask = np.random.choice(number_of_channels,
@@ -77,6 +107,43 @@ def generate_channel_masks(input_data: np.ndarray, number_of_masks: int,
return masks
+def mask_data_tabular(data: np.array, masks: np.array, training_data: np.array,
+ mask_type: Union[object, str]) -> np.array:
+ """Mask tabular data using a given set of masks.
+
+ Args:
+ data: Input data.
+ masks: an array with shape [number_of_masks] + data.shape
+ mask_type: Masking strategy. Can be 'most_frequent', 'mean' or a function f(data, masks, training_data).
+ training_data: Data used to fit the imputer that fills in masked values (passed on unchanged to a callable mask_type).
+
+ Returns:
+ Single array containing all masked input where the first dimension represents the batch.
+ """
+ if isinstance(mask_type, str):
+
+ def strategy(data, masks, training_data):
+ imputer = SimpleImputer(missing_values=np.nan, strategy=mask_type)
+ imputer.fit(training_data)
+ masked_data_list = []
+ for mask in masks:
+ current_data = np.array(data)
+ current_data[~mask] = np.nan
+ current_data_masked = imputer.transform(current_data[None,
+ ...])[0]
+ masked_data_list.append(current_data_masked)
+ masked_data = np.stack(masked_data_list)
+ return masked_data
+ elif callable(mask_type):
+ strategy = mask_type
+ else:
+ raise ValueError(
+ f'Mask type must be callable or type str but got type `{type(mask_type)}` instead.'
+ )
+
+ return strategy(data, masks, training_data)
+
+
def mask_data(data: np.array, masks: np.array, mask_type: Union[object, str]):
"""Mask data given using a set of masks.
@@ -107,7 +174,9 @@ def _get_mask_value(data: np.array, mask_type: object) -> int:
raise ValueError(f'Unknown mask_type selected: {mask_type}')
-def _determine_number_masked(p_keep: float, series_length: int) -> int:
+def _determine_number_masked(p_keep: float,
+ series_length: int,
+ element_name='feature') -> int:
"""Determine the number of time steps that need to be masked."""
mean = series_length * (1 - p_keep)
floor = np.floor(mean)
@@ -121,26 +190,27 @@ def _determine_number_masked(p_keep: float, series_length: int) -> int:
if user_requested_steps >= series_length:
warnings.warn(
- 'Warning: p_keep chosen too low. Continuing with leaving 1 time step unmasked per mask.'
+ f'Warning: p_keep chosen too low. Continuing with leaving 1 {element_name} unmasked per mask.'
)
return series_length - 1
if user_requested_steps <= 0:
warnings.warn(
- 'Warning: p_keep chosen too high. Continuing with masking 1 time step per mask.'
+ f'Warning: p_keep chosen too high. Continuing with masking 1 {element_name} per mask.'
)
return 1
return user_requested_steps
-def generate_time_step_masks(input_data: np.ndarray, number_of_masks: int,
- p_keep: float, number_of_features: int):
- """Generate masks that masks complete time steps at a time while masking time steps in a segmented fashion.
+def generate_time_step_masks(input_data_shape: tuple[int],
+ number_of_masks: int, p_keep: float,
+ number_of_features: int):
+ """Generate masks that mask all channels simultaneously for clusters of time steps.
For a conceptual description see:
https://medium.com/escience-center/masking-time-series-for-explainable-ai-90247ac252b4.
"""
- time_series_length = input_data.shape[0]
- number_of_channels = input_data.shape[1]
+ time_series_length = input_data_shape[0]
+ number_of_channels = input_data_shape[1]
float_masks = generate_interpolated_float_masks_for_timeseries(
[time_series_length, 1], number_of_masks, number_of_features)[:, :, 0]
@@ -169,7 +239,7 @@ def _mask_bottom_ratio(float_mask: np.ndarray, p_keep: float) -> np.ndarray:
flat = float_mask.flatten()
time_indices = list(range(len(flat)))
number_of_unmasked_cells = _determine_number_masked(
- p_keep, len(time_indices))
+ p_keep, len(time_indices), element_name='time step')
top_indices = heapq.nsmallest(number_of_unmasked_cells,
time_indices,
key=lambda time_step: flat[time_step])
diff --git a/docs/tutorials/rise_tabular_penguin.nblink b/docs/tutorials/rise_tabular_penguin.nblink
new file mode 100644
index 00000000..5aa4971d
--- /dev/null
+++ b/docs/tutorials/rise_tabular_penguin.nblink
@@ -0,0 +1,3 @@
+{
+ "path": "../../tutorials/explainers/RISE/rise_tabular_penguin.ipynb"
+}
diff --git a/tests/methods/test_lime_image.py b/tests/methods/test_lime_image.py
index ae63a294..16797e3f 100644
--- a/tests/methods/test_lime_image.py
+++ b/tests/methods/test_lime_image.py
@@ -4,7 +4,7 @@
import dianna
from dianna.methods.lime_image import LIMEImage
from tests.methods.test_onnx_runner import generate_data
-from tests.utils import run_model
+from tests.utils import get_dummy_model_function
class LimeOnImages(TestCase):
@@ -18,7 +18,7 @@ def test_lime_function():
labels = [1]
explainer = LIMEImage(random_state=42)
- heatmap = explainer.explain(run_model,
+ heatmap = explainer.explain(get_dummy_model_function(n_outputs=2),
input_data,
labels,
num_samples=100)
@@ -53,7 +53,7 @@ def test_lime_values():
labels = [1]
explainer = LIMEImage(random_state=42)
- heatmap = explainer.explain(run_model,
+ heatmap = explainer.explain(get_dummy_model_function(n_outputs=2),
input_data,
labels,
return_masks=False,
diff --git a/tests/methods/test_lime_tabular.py b/tests/methods/test_lime_tabular.py
deleted file mode 100644
index 3627712e..00000000
--- a/tests/methods/test_lime_tabular.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Test LIME tabular method."""
-from unittest import TestCase
-import numpy as np
-import dianna
-from dianna.methods.lime_tabular import LIMETabular
-from tests.utils import run_model
-
-
-class LIMEOnTabular(TestCase):
- """Suite of LIME tests for the tabular case."""
-
- def test_lime_tabular_classification_correct_output_shape(self):
- """Test the output of explainer."""
- training_data = np.random.random((10, 2))
- input_data = np.random.random(2)
- feature_names = ["feature_1", "feature_2"]
- explainer = LIMETabular(training_data,
- mode ='classification',
- feature_names=feature_names,
- class_names = ["class_1", "class_2"])
- exp = explainer.explain(
- run_model,
- input_data,
- )
- assert len(exp[0]) == len(feature_names)
-
- def test_lime_tabular_regression_correct_output_shape(self):
- """Test the output of explainer."""
- training_data = np.random.random((10, 2))
- input_data = np.random.random(2)
- feature_names = ["feature_1", "feature_2"]
- exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='lime',
- mode ='regression', training_data = training_data,
- feature_names=feature_names, class_names=['class_1'])
-
- assert len(exp) == len(feature_names)
diff --git a/tests/methods/test_lime_timeseries.py b/tests/methods/test_lime_timeseries.py
index 77ef85b2..1ca65d3b 100644
--- a/tests/methods/test_lime_timeseries.py
+++ b/tests/methods/test_lime_timeseries.py
@@ -3,7 +3,7 @@
from dianna.methods.lime_timeseries import LIMETimeseries
from dianna.utils.maskers import generate_time_series_masks
from dianna.utils.maskers import mask_data
-from tests.utils import run_model
+from tests.utils import get_dummy_model_function
class LIMEOnTimeseries(TestCase):
@@ -15,7 +15,7 @@ def test_lime_timeseries_correct_output_shape(self):
num_features = 10
explainer = LIMETimeseries()
exp = explainer.explain(
- run_model,
+ get_dummy_model_function(n_outputs=2),
input_data,
labels=(0, ),
class_names=("test", ),
@@ -30,7 +30,7 @@ def test_distance_shape(self):
"""Test the shape of returned distance array."""
dummy_timeseries = np.random.random((50, 1))
number_of_masks = 50
- masks = generate_time_series_masks(dummy_timeseries,
+ masks = generate_time_series_masks(dummy_timeseries.shape,
number_of_masks,
p_keep=0.9)
masked = mask_data(dummy_timeseries, masks, mask_type="mean")
diff --git a/tests/methods/test_maskers.py b/tests/methods/test_maskers.py
index b790f554..ad829cfe 100644
--- a/tests/methods/test_maskers.py
+++ b/tests/methods/test_maskers.py
@@ -4,27 +4,28 @@
from dianna.utils.maskers import generate_channel_masks
from dianna.utils.maskers import generate_interpolated_float_masks_for_image
from dianna.utils.maskers import generate_interpolated_float_masks_for_timeseries
+from dianna.utils.maskers import generate_tabular_masks
from dianna.utils.maskers import generate_time_series_masks
from dianna.utils.maskers import generate_time_step_masks
from dianna.utils.maskers import mask_data
-def test_mask_has_correct_shape_univariate():
+def test_timeseries_mask_has_correct_shape_univariate():
"""Test masked data has the correct shape for a univariate input."""
input_data = _get_univariate_time_series()
number_of_masks = 5
- result = generate_time_series_masks(input_data, number_of_masks)
+ result = generate_time_series_masks(input_data.shape, number_of_masks)
assert result.shape == tuple([number_of_masks] + list(input_data.shape))
-def test_mask_has_correct_type_univariate():
+def test_timeseries_mask_has_correct_type_univariate():
"""Test masked data has the correct dtype for a univariate input."""
input_data = _get_univariate_time_series()
number_of_masks = 5
- result = generate_time_series_masks(input_data,
+ result = generate_time_series_masks(input_data.shape,
number_of_masks=number_of_masks)
assert result.dtype == bool
@@ -35,7 +36,7 @@ def test_generate_time_step_masks_dtype_multivariate():
input_data = _get_multivariate_time_series()
number_of_masks = 5
- result = generate_time_step_masks(input_data,
+ result = generate_time_step_masks(input_data.shape,
number_of_masks=number_of_masks,
number_of_features=8,
p_keep=0.5)
@@ -48,7 +49,7 @@ def test_generate_segmented_time_step_masks_dtype_multivariate():
input_data = _get_multivariate_time_series()
number_of_masks = 5
- result = generate_time_step_masks(input_data,
+ result = generate_time_step_masks(input_data.shape,
number_of_masks=number_of_masks,
number_of_features=8,
p_keep=0.5)
@@ -56,13 +57,13 @@ def test_generate_segmented_time_step_masks_dtype_multivariate():
assert result.dtype == bool
-def test_mask_has_correct_shape_multivariate():
+def test_timeseries_mask_has_correct_shape_multivariate():
"""Test masked data has the correct shape for a multivariate input."""
input_data = _get_multivariate_time_series()
number_of_masks = 5
- result = _call_masking_function(input_data,
- number_of_masks=number_of_masks)
+ result = _call_timeseries_masking_function(input_data,
+ number_of_masks=number_of_masks)
assert result.shape == tuple([number_of_masks] + list(input_data.shape))
@@ -76,24 +77,24 @@ def test_mask_has_correct_shape_multivariate():
(0.5, 0.5),
(0.99, 0.9), # Mask only 1
])
-def test_mask_contains_correct_number_of_unmasked_parts(
+def test_timeseries_mask_contains_correct_number_of_unmasked_parts(
p_keep_and_expected_rate):
"""Number of unmasked parts should be conforming the given p_keep."""
p_keep, expected_rate = p_keep_and_expected_rate
input_data = _get_univariate_time_series()
- result = _call_masking_function(input_data, p_keep=p_keep)
+ result = _call_timeseries_masking_function(input_data, p_keep=p_keep)
assert np.sum(result == input_data) / np.product(
result.shape) == expected_rate
-def test_mask_contains_correct_parts_are_mean_masked():
+def test_timeseries_mask_contains_correct_parts_are_mean_masked():
"""All parts that are masked should now contain the mean of the input."""
input_data = _get_univariate_time_series()
mean = np.mean(input_data)
- result = _call_masking_function(input_data, mask_type='mean')
+ result = _call_timeseries_masking_function(input_data, mask_type='mean')
masked_parts = result[(result != input_data)]
assert np.alltrue(
@@ -115,7 +116,7 @@ def _get_multivariate_time_series(number_of_channels: int = 6) -> np.array:
])
-def _call_masking_function(
+def _call_timeseries_masking_function(
input_data,
number_of_masks=5,
p_keep=.3,
@@ -123,7 +124,7 @@ def _call_masking_function(
feature_res=5,
):
"""Helper function with some defaults to call the code under test."""
- masks = generate_time_series_masks(input_data,
+ masks = generate_time_series_masks(input_data.shape,
number_of_masks,
feature_res,
p_keep=p_keep)
@@ -135,7 +136,7 @@ def test_channel_mask_has_correct_shape_multivariate():
number_of_masks = 15
input_data = _get_multivariate_time_series()
- result = generate_channel_masks(input_data, number_of_masks, 0.5)
+ result = generate_channel_masks(input_data.shape, number_of_masks, 0.5)
assert result.shape == tuple([number_of_masks] + list(input_data.shape))
@@ -145,7 +146,7 @@ def test_channel_mask_has_does_not_contain_conflicting_values():
number_of_masks = 15
input_data = _get_multivariate_time_series()
- result = generate_channel_masks(input_data, number_of_masks, 0.5)
+ result = generate_channel_masks(input_data.shape, number_of_masks, 0.5)
unexpected_results = []
for mask_i, mask in enumerate(result):
@@ -165,33 +166,33 @@ def test_channel_mask_masks_correct_number_of_cells():
input_data = _get_multivariate_time_series(number_of_channels=10)
p_keep = 0.3
- result = generate_channel_masks(input_data, number_of_masks, p_keep)
+ result = generate_channel_masks(input_data.shape, number_of_masks, p_keep)
assert result.sum() / np.product(result.shape) == p_keep
-def test_masking_has_correct_shape_multivariate():
+def test_timeseries_masking_has_correct_shape_multivariate():
"""Test for the correct output shape for the general masking function."""
number_of_masks = 15
input_data = _get_multivariate_time_series()
- result = generate_time_series_masks(input_data, number_of_masks)
+ result = generate_time_series_masks(input_data.shape, number_of_masks)
assert result.shape == tuple([number_of_masks] + list(input_data.shape))
-def test_masking_univariate_leaves_anything_unmasked():
+def test_timeseries_masking_univariate_leaves_anything_unmasked():
"""Tests that something remains unmasked and some parts are masked for the univariate case."""
number_of_masks = 1
input_data = _get_univariate_time_series()
- result = generate_time_series_masks(input_data, number_of_masks)
+ result = generate_time_series_masks(input_data.shape, number_of_masks)
assert np.any(result)
assert np.any(~result)
-def test_masking_keep_first_instance():
+def test_timeseries_masking_keep_first_instance():
"""First instance must be the original data for Lime timeseries.
Required by `lime_base` explainer, the first instance of masked (or perturbed)
@@ -202,7 +203,9 @@ def test_masking_keep_first_instance():
"""
input_data = _get_multivariate_time_series()
number_of_masks = 5
- masks = generate_time_series_masks(input_data, number_of_masks, p_keep=0.9)
+ masks = generate_time_series_masks(input_data.shape,
+ number_of_masks,
+ p_keep=0.9)
masks[0, :, :] = 1.0
masked = mask_data(input_data, masks, mask_type="mean")
assert np.array_equal(masked[0, :, :], input_data)
@@ -216,7 +219,7 @@ def test_masks_approximately_correct_number_of_masked_parts_per_time_step(
number_of_masks = 500
input_data = _get_univariate_time_series(num_steps=num_steps)
- masks = generate_time_series_masks(input_data,
+ masks = generate_time_series_masks(input_data.shape,
number_of_masks=number_of_masks,
feature_res=num_steps,
p_keep=p_keep)[:, :, 0]
@@ -235,7 +238,7 @@ def test_masks_approximately_correct_number_of_masked_parts_per_time_step_projec
number_of_masks = 500
input_data = _get_univariate_time_series(num_steps=num_steps)
- masks = generate_time_series_masks(input_data,
+ masks = generate_time_series_masks(input_data.shape,
number_of_masks=number_of_masks,
feature_res=6,
p_keep=p_keep)[:, :, 0]
@@ -310,3 +313,66 @@ def test_generate_interpolated_mean_float_masks_for_image(
print('\n')
print(masks_mean)
assert np.allclose(masks_mean, p_keep, atol=0.1)
+
+
+def test_tabular_mask_has_correct_shape():
+ """Test whether tabular masks have the correct shape."""
+ input_data_shape = (10, )
+ number_of_masks = 30
+
+ masks = np.stack(
+ list(
+ generate_tabular_masks(
+ input_data_shape,
+ number_of_masks,
+ p_keep=0.2,
+ )))
+
+ assert masks.shape == (number_of_masks, *input_data_shape)
+
+
+@pytest.mark.parametrize('p_keep_and_n_unmasked', [
+ (0.2, 0.2),
+ (0.7, 0.7),
+ (0.75, 0.75),
+ (0.999, 0.9),
+ (0.01, 0.1),
+])
+def test_tabular_mask_has_correct_number_masked(p_keep_and_n_unmasked):
+ """Test whether the expected number of features was masked.
+
+ Also taking into account min=1 and max=n-1 of features per instance and edge cases where the exact p_keep can't
+ be met.
+ """
+ input_data_shape = (10, )
+ number_of_masks = 50
+ p_keep, n_unmasked = p_keep_and_n_unmasked
+
+ masks = np.stack(
+ list(
+ generate_tabular_masks(
+ input_data_shape,
+ number_of_masks,
+ p_keep=p_keep,
+ )))
+
+ mean_element = masks.sum() / (np.product(input_data_shape) *
+ number_of_masks)
+ assert np.isclose(mean_element, n_unmasked, atol=0.03)
+
+
+def test_tabular_mask_prob_masked_per_feature_correct():
+ """Test whether every feature has the same probability of being masked."""
+ input_data_shape = (10, )
+ number_of_masks = 1000
+ p_keep = 0.2
+
+ masks = np.stack(
+ list(
+ generate_tabular_masks(
+ input_data_shape,
+ number_of_masks,
+ p_keep=p_keep,
+ )))
+
+ assert np.allclose(masks.mean(axis=0), p_keep, atol=0.05)
diff --git a/tests/methods/test_rise_image.py b/tests/methods/test_rise_image.py
index 51ae742b..5a517f7c 100644
--- a/tests/methods/test_rise_image.py
+++ b/tests/methods/test_rise_image.py
@@ -5,8 +5,8 @@
from dianna.methods.rise_image import RISEImage
from dianna.utils import get_function
from tests.methods.test_onnx_runner import generate_data
+from tests.utils import get_dummy_model_function
from tests.utils import get_mnist_1_data
-from tests.utils import run_model
class RiseOnImages(TestCase):
@@ -22,7 +22,7 @@ def test_rise_function():
"tests/test_data/heatmap_rise_function.npy")
heatmaps = dianna.explain_image(
- run_model,
+ get_dummy_model_function(n_outputs=2),
input_data,
"RISE",
labels,
diff --git a/tests/methods/test_rise_timeseries.py b/tests/methods/test_rise_timeseries.py
index f2457951..3479ff0b 100644
--- a/tests/methods/test_rise_timeseries.py
+++ b/tests/methods/test_rise_timeseries.py
@@ -4,7 +4,7 @@
from dianna.methods.rise_timeseries import RISETimeseries
from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day
from tests.methods.time_series_test_case import run_expert_model_3_step
-from tests.utils import run_model
+from tests.utils import get_dummy_model_function
def test_rise_timeseries_correct_output_shape():
@@ -12,7 +12,7 @@ def test_rise_timeseries_correct_output_shape():
input_data = np.random.random((10, 1))
labels = [1]
- heatmaps = dianna.explain_timeseries(run_model,
+ heatmaps = dianna.explain_timeseries(get_dummy_model_function(n_outputs=2),
input_data,
"RISE",
labels,
diff --git a/tests/methods/test_shap_tabular.py b/tests/methods/test_shap_tabular.py
deleted file mode 100644
index f2ecc7fe..00000000
--- a/tests/methods/test_shap_tabular.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Test LIME tabular method."""
-from unittest import TestCase
-import numpy as np
-import dianna
-from dianna.methods.kernelshap_tabular import KERNELSHAPTabular
-from tests.utils import run_model
-
-
-class LIMEOnTabular(TestCase):
- """Suite of LIME tests for the tabular case."""
-
- def test_shap_tabular_classification_correct_output_shape(self):
- """Test whether the output of explainer has the correct shape."""
- training_data = np.random.random((10, 2))
- input_data = np.random.random(2)
- feature_names = ["feature_1", "feature_2"]
- explainer = KERNELSHAPTabular(training_data,
- mode ='classification',
- feature_names=feature_names,)
- exp = explainer.explain(
- run_model,
- input_data,
- )
- assert len(exp[0]) == len(feature_names)
-
- def test_shap_tabular_regression_correct_output_shape(self):
- """Test whether the output of explainer has the correct length."""
- training_data = np.random.random((10, 2))
- input_data = np.random.random(2)
- feature_names = ["feature_1", "feature_2"]
- exp = dianna.explain_tabular(run_model, input_tabular=input_data, method='kernelshap',
- mode ='regression', training_data = training_data,
- training_data_kmeans = 2, feature_names=feature_names)
-
- assert len(exp) == len(feature_names)
diff --git a/tests/methods/test_tabular.py b/tests/methods/test_tabular.py
new file mode 100644
index 00000000..08288adc
--- /dev/null
+++ b/tests/methods/test_tabular.py
@@ -0,0 +1,103 @@
+import numpy as np
+import pytest
+import dianna
+from dianna.methods.kernelshap_tabular import KERNELSHAPTabular
+from dianna.methods.lime_tabular import LIMETabular
+from dianna.methods.rise_tabular import RISETabular
+from tests.utils import get_dummy_model_function
+
+explainer_names = [
+ 'rise',
+ 'lime',
+ 'kernelshap',
+]
+
+explainer_classes = [
+ RISETabular,
+ LIMETabular,
+ KERNELSHAPTabular,
+]
+
+
+@pytest.mark.parametrize('method', explainer_names)
+def test_tabular_regression_correct_output_shape(method):
+ """Runs the explainer class with random data and asserts the output shape."""
+ number_of_features = 2
+ number_of_outputs = 3 # Only the first is used for regression explanation
+ training_data = np.random.random((10, number_of_features))
+ input_data = np.random.random(number_of_features)
+ feature_names = ["feature_1", "feature_2"]
+ exp = dianna.explain_tabular(
+ get_dummy_model_function(n_outputs=number_of_outputs),
+ input_tabular=input_data,
+ method=method,
+ mode='regression',
+ training_data=training_data,
+ feature_names=feature_names,
+ )
+ assert exp.shape == (number_of_features, )
+
+
+@pytest.mark.parametrize('explainer_class', explainer_classes)
+def test_tabular_classification_correct_output_shape(explainer_class):
+ """Runs the explainer class with random data and asserts the output shape."""
+ number_of_features = 3
+ number_of_classes = 2
+ training_data = np.random.random((10, number_of_features))
+ input_data = np.random.random(number_of_features)
+ feature_names = ["feature_1", "feature_2", "feature_3"]
+ explainer = explainer_class(
+ training_data,
+ mode='classification',
+ feature_names=feature_names,
+ )
+ exp = explainer.explain(
+ get_dummy_model_function(n_outputs=number_of_classes),
+ input_data,
+ )
+ assert exp.shape == (number_of_classes, number_of_features)
+
+
+def _pprint(explanations):
+ """Pretty prints the explanation for each class while classifying tabular data."""
+ print()
+ rows = [' '.join([f'{v:>4d}' for v in range(25)])]
+ rows += [
+ ' '.join([f'{v:.2f}' for v in explanation])
+ for explanation in explanations
+ ]
+ print('\n'.join(rows))
+
+
+@pytest.mark.parametrize('explainer_class', explainer_classes)
+def test_tabular_simple_dummy_model(explainer_class):
+ """Tests if the explainer can find the single important feature in otherwise random data."""
+ np.random.seed(0)
+ num_features = 25
+ input_data = np.array(num_features // 2 * [1.0] +
+ (num_features - num_features // 2) * [0.0])
+ training_data = np.stack([input_data for _ in range(len(input_data))]).T
+
+ feature_names = [f"feature_{i}" for i in range(num_features)]
+ important_feature_i = 2
+
+ def dummy_model(tabular_data):
+ """Model with output dependent on a single feature of each instance."""
+ prediction = tabular_data[:, important_feature_i]
+ return np.vstack([prediction, -prediction + 1]).T
+
+ explainer = explainer_class(
+ training_data,
+ mode='classification',
+ feature_names=feature_names,
+ )
+ explanations = explainer.explain(
+ dummy_model,
+ input_data,
+ labels=[0, 1],
+ )
+
+ _pprint(explanations)
+
+ assert np.argmax(explanations[0]) == important_feature_i
+ assert np.argmin(explanations[1]) == important_feature_i
diff --git a/tests/test_common_usage.py b/tests/test_common_usage.py
index 6db271bd..27e4df90 100644
--- a/tests/test_common_usage.py
+++ b/tests/test_common_usage.py
@@ -1,7 +1,7 @@
import numpy as np
import dianna
import dianna.visualization
-from tests.utils import run_model
+from tests.utils import get_dummy_model_function
def test_common_RISE_image_pipeline(): # noqa: N802 ignore case
@@ -10,7 +10,7 @@ def test_common_RISE_image_pipeline(): # noqa: N802 ignore case
axis_labels = {-1: 'channels'}
labels = [0, 1]
- heatmap = dianna.explain_image(run_model,
+ heatmap = dianna.explain_image(get_dummy_model_function(n_outputs=2),
input_image,
'RISE',
labels,
@@ -26,8 +26,8 @@ def test_common_RISE_timeseries_pipeline(): # noqa: N802 ignore case
input_timeseries = np.random.random((31, 1))
labels = [0]
- heatmap = dianna.explain_timeseries(run_model, input_timeseries, 'RISE',
- labels)[0]
+ heatmap = dianna.explain_timeseries(get_dummy_model_function(n_outputs=2),
+ input_timeseries, 'RISE', labels)[0]
segments = []
for channel_number in range(heatmap.shape[1]):
heatmap_channel = heatmap[:, channel_number]
diff --git a/tests/test_kwargs.py b/tests/test_kwargs.py
index 028be46e..166da1c9 100644
--- a/tests/test_kwargs.py
+++ b/tests/test_kwargs.py
@@ -3,8 +3,8 @@
import pytest
import dianna
from tests.methods.test_onnx_runner import generate_data
+from tests.utils import get_dummy_model_function
from tests.utils import load_movie_review_model
-from tests.utils import run_model
class ImageKwargs(TestCase):
@@ -17,24 +17,25 @@ def test_lime_image_correct_kwargs(self):
axis_labels = ('channels', 'y', 'x')
labels = [1]
- dianna.explain_image(model_filename,
- input_data,
- method='LIME',
- labels=labels,
- kernel=None,
- kernel_width=25,
- verbose=False,
- feature_selection='auto',
- random_state=None,
- axis_labels=axis_labels,
- preprocess_function=None,
- top_labels=None,
- num_features=10,
- num_samples=10,
- return_masks=True,
- positive_only=False,
- hide_rest=True,
- )
+ dianna.explain_image(
+ model_filename,
+ input_data,
+ method='LIME',
+ labels=labels,
+ kernel=None,
+ kernel_width=25,
+ verbose=False,
+ feature_selection='auto',
+ random_state=None,
+ axis_labels=axis_labels,
+ preprocess_function=None,
+ top_labels=None,
+ num_features=10,
+ num_samples=10,
+ return_masks=True,
+ positive_only=False,
+ hide_rest=True,
+ )
def test_lime_image_extra_kwarg(self):
"""Test to ensure extra kwargs to lime raise warnings."""
@@ -45,24 +46,24 @@ def test_lime_image_extra_kwarg(self):
error_message = "Error due to following unused kwargs: {'extra_kwarg': None}"
with pytest.raises(TypeError, match=error_message):
dianna.explain_image(model_filename,
- input_data,
- method='LIME',
- labels=labels,
- kernel=None,
- kernel_width=25,
- verbose=False,
- feature_selection='auto',
- random_state=None,
- axis_labels=axis_labels,
- preprocess_function=None,
- top_labels=None,
- num_features=10,
- num_samples=10,
- return_masks=True,
- positive_only=False,
- hide_rest=True,
- extra_kwarg=None
- )
+ input_data,
+ method='LIME',
+ labels=labels,
+ kernel=None,
+ kernel_width=25,
+ verbose=False,
+ feature_selection='auto',
+ random_state=None,
+ axis_labels=axis_labels,
+ preprocess_function=None,
+ top_labels=None,
+ num_features=10,
+ num_samples=10,
+ return_masks=True,
+ positive_only=False,
+ hide_rest=True,
+ extra_kwarg=None)
+
class TextKwargs(TestCase):
"""Suite of tests for kwargs to explainers for Images."""
@@ -71,8 +72,7 @@ def test_rise_text_correct_kwargs(self):
"""Test to ensure correct kwargs to lime run without issues."""
review = "such a bad movie"
- dianna.explain_text(
- self.runner,
+ dianna.explain_text(self.runner,
review,
tokenizer=self.runner.tokenizer,
method='RISE',
@@ -81,8 +81,7 @@ def test_rise_text_correct_kwargs(self):
feature_res=8,
p_keep=0.5,
preprocess_function=None,
- batch_size=100
- )
+ batch_size=100)
def test_rise_text_extra_kwarg(self):
"""Test to ensure extra kwargs to lime raise warnings."""
@@ -90,8 +89,7 @@ def test_rise_text_extra_kwarg(self):
error_message = "Error due to following unused kwargs: {'extra_kwarg': None}"
with pytest.raises(TypeError, match=error_message):
- dianna.explain_text(
- self.runner,
+ dianna.explain_text(self.runner,
review,
tokenizer=self.runner.tokenizer,
method='RISE',
@@ -101,14 +99,14 @@ def test_rise_text_extra_kwarg(self):
p_keep=0.5,
preprocess_function=None,
batch_size=100,
- extra_kwarg=None
- )
+ extra_kwarg=None)
def setUp(self) -> None:
"""Set seed and load runner."""
np.random.seed(0)
self.runner = load_movie_review_model()
+
class TimeseriesKwargs(TestCase):
"""Suite of tests for kwargs to explainers for Images."""
@@ -117,22 +115,22 @@ def test_lime_timeseries_correct_kwargs(self):
input_data = np.random.random((10, 1))
dianna.explain_timeseries(
- run_model,
- input_timeseries=input_data,
- method='LIME',
- labels=[0,1],
- class_names=["summer", "winter"],
- kernel_width=25,
- verbose=False,
- preprocess_function=None,
- feature_selection='auto',
- num_features=10,
- num_samples=10,
- num_slices=10,
- batch_size=10,
- mask_type='mean',
- distance_method='cosine',
- )
+ get_dummy_model_function(n_outputs=2),
+ input_timeseries=input_data,
+ method='LIME',
+ labels=[0, 1],
+ class_names=["summer", "winter"],
+ kernel_width=25,
+ verbose=False,
+ preprocess_function=None,
+ feature_selection='auto',
+ num_features=10,
+ num_samples=10,
+ num_slices=10,
+ batch_size=10,
+ mask_type='mean',
+ distance_method='cosine',
+ )
def test_lime_timeseries_extra_kwargs(self):
"""Test to ensure extra kwargs to lime raise warnings."""
@@ -140,24 +138,23 @@ def test_lime_timeseries_extra_kwargs(self):
error_message = "Error due to following unused kwargs: {'extra_kwarg': None}"
with pytest.raises(TypeError, match=error_message):
- dianna.explain_timeseries(
- run_model,
- input_timeseries=input_data,
- method='LIME',
- labels=[0,1],
- class_names=["summer", "winter"],
- kernel_width=25,
- verbose=False,
- preprocess_function=None,
- feature_selection='auto',
- num_features=10,
- num_samples=10,
- num_slices=10,
- batch_size=10,
- mask_type='mean',
- distance_method='cosine',
- extra_kwarg=None
- )
+ dianna.explain_timeseries(get_dummy_model_function(n_outputs=2),
+ input_timeseries=input_data,
+ method='LIME',
+ labels=[0, 1],
+ class_names=["summer", "winter"],
+ kernel_width=25,
+ verbose=False,
+ preprocess_function=None,
+ feature_selection='auto',
+ num_features=10,
+ num_samples=10,
+ num_slices=10,
+ batch_size=10,
+ mask_type='mean',
+ distance_method='cosine',
+ extra_kwarg=None)
+
class TabularKwargs(TestCase):
"""Suite of tests for kwargs to explainers for Images."""
diff --git a/tests/utils.py b/tests/utils.py
index c5ec7bb0..0bfd1313 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -42,20 +42,31 @@ def get_mnist_1_data():
return np.loadtxt(_mnist_1_data.split()).reshape((1, 1, 28, 28))
-def run_model(input_data):
- """Simulate a model that outputs 2-classes.
+def get_dummy_model_function(n_outputs):
+ """Create a function that simulates a model that outputs n_outputs.
Args:
- input_data: input data for the dummy model
+ n_outputs: number of outputs
Returns:
- semi random output
+ dummy model as a function
"""
- n_class = 2
- batch_size = input_data.shape[0]
- np.random.seed(42)
- return np.random.random((batch_size, n_class))
+ def run_model(input_data):
+ """Simulate a model that outputs n_outputs.
+
+ Args:
+ input_data: input data for the dummy model
+
+ Returns:
+ semi random output
+ """
+ batch_size = input_data.shape[0]
+
+ np.random.seed(42)
+ return np.random.random((batch_size, n_outputs))
+
+ return run_model
class ModelRunner:
diff --git a/tutorials/README.md b/tutorials/README.md
index 21020a66..5c0d7a38 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -52,7 +52,7 @@ The ONNX models used in the tutorials are available at [dianna/models](https://g
|*Text* |[](./explainers/RISE/rise_text.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_text.ipynb) |[
](./explainers/LIME/lime_text.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_text.ipynb) |[]()|
| *Time series*| [
](./explainers/RISE/rise_timeseries_weather.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_timeseries_weather.ipynb)| [
](./explainers/LIME/lime_timeseries_weather.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_timeseries_weather.ipynb)| |
| | | [
](./explainers/LIME/lime_timeseries_coffee.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_timeseries_coffee.ipynb) | |
-| *Tabular* | | [
](./explainers/LIME/lime_tabular_penguin.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_tabular_penguin.ipynb) |[
](./explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb) |
+| *Tabular* |
](./explainers/RISE/rise_tabular_penguin.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_tabular_penguin.ipynb) | [
](./explainers/LIME/lime_tabular_penguin.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_tabular_penguin.ipynb) |[
](./explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_penguin.ipynb) |
| | | [
](./explainers/LIME/lime_tabular_weather.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_weather.ipynb)|[
](./explainers/KernelSHAP/kernelshap_tabular_weather.ipynb) or [](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_weather.ipynb) |
To learn more about how we aproach the masking for time-series data, please read our [Masking time-series for XAI](https://blog.esciencecenter.nl/masking-time-series-for-explainable-ai-90247ac252b4) blog-post.
@@ -68,7 +68,7 @@ To learn more about how we aproach the masking for time-series data, please read
### IMPORTANT: Hyperparameters
The XAI methods (explainers) are sensitive to the choice of their hyperparameters! In this [master Thesis](https://staff.fnwi.uva.nl/a.s.z.belloum/MSctheses/MScthesis_Willem_van_der_Spec.pdf), this sensitivity is researched and useful conclusions are drawn.
-The default hyperparameters used in DIANNA for each explainer as well as the choices for some tutorials and their data modality (*i* - images, *txt* - text, *ts* - time series and *tab* - tabular) are given in the tables below.
+The default hyperparameters used in DIANNA for each explainer as well as the choices for some tutorials and their data modality (*i* - images, *txt* - text, *ts* - time series and *tab* - tabular) are given in the tables below.
Also the main conclusions (🠊) from the thesis (on images and text) about the hyperparameters effect are listed.
#### RISE
@@ -102,4 +102,4 @@ Also the main conclusions (🠊) from the thesis (on images and text) about the
🠊 The most crucial parameter is the nubmer of super-pixels $n_{segments}. Higher values led to higher sensitivity, however that observaiton was dependant on the evaluaiton metric.
-🠊 Regularization had only a marginal detrimental effect, the best results were obtained using no regularization (no smoothing, $sigma = 0$) or least squares regression.
+🠊 Regularization had only a marginal detrimental effect, the best results were obtained using no regularization (no smoothing, $sigma = 0$) or least squares regression.
diff --git a/tutorials/explainers/RISE/rise_tabular_penguin.ipynb b/tutorials/explainers/RISE/rise_tabular_penguin.ipynb
new file mode 100644
index 00000000..53974931
--- /dev/null
+++ b/tutorials/explainers/RISE/rise_tabular_penguin.ipynb
@@ -0,0 +1,403 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "### Model Interpretation using RISE for penguin dataset classifier\n",
+ "This notebook demonstrates the use of DIANNA with the RISE tabular method on the penguins dataset.\n",
+ "\n",
+ "RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates each feature's relevance to the model's decision empirically by probing the model with randomly masked versions of the input data and obtaining the corresponding outputs. More details about this method can be found in the [paper introducing RISE](https://arxiv.org/abs/1806.07421)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Colab setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2024-05-30T11:11:52.728074Z",
+ "end_time": "2024-05-30T11:11:52.815285Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "running_in_colab = 'google.colab' in str(get_ipython())\n",
+ "if running_in_colab:\n",
+ " # install dianna\n",
+ " !python3 -m pip install dianna[notebooks]\n",
+ " \n",
+ " # download data used in this demo\n",
+ " import os \n",
+ " base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/tutorials/'\n",
+ " paths_to_download = ['models/penguin_model.onnx']\n",
+ " for path in paths_to_download:\n",
+ " !wget {base_url + path} -P {os.path.dirname(path)}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 0 - Import libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2024-05-30T11:11:52.734426Z",
+ "end_time": "2024-05-30T11:11:52.815285Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import dianna\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import seaborn as sns\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from dianna.utils.onnx_runner import SimpleModelRunner\n",
+ "from pathlib import Path\n",
+ "\n",
+ "root_dir = Path(dianna.__file__).parent"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 1 - Loading the data\n",
+ "Load penguins dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2024-05-30T11:11:52.748720Z",
+ "end_time": "2024-05-30T11:11:52.831285Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "penguins = sns.load_dataset('penguins')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Prepare the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "start_time": "2024-05-30T11:11:52.756731Z",
+ "end_time": "2024-05-30T11:11:52.891178Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": " bill_length_mm bill_depth_mm flipper_length_mm body_mass_g\n0 39.1 18.7 181.0 3750.0\n1 39.5 17.4 186.0 3800.0\n2 40.3 18.0 195.0 3250.0\n4 36.7 19.3 193.0 3450.0\n5 39.3 20.6 190.0 3650.0\n.. ... ... ... ...\n338 47.2 13.7 214.0 4925.0\n340 46.8 14.3 215.0 4850.0\n341 50.4 15.7 222.0 5750.0\n342 45.2 14.8 212.0 5200.0\n343 49.9 16.1 213.0 5400.0\n\n[342 rows x 4 columns]",
+ "text/html": "
\n | bill_length_mm | \nbill_depth_mm | \nflipper_length_mm | \nbody_mass_g | \n
---|---|---|---|---|
0 | \n39.1 | \n18.7 | \n181.0 | \n3750.0 | \n
1 | \n39.5 | \n17.4 | \n186.0 | \n3800.0 | \n
2 | \n40.3 | \n18.0 | \n195.0 | \n3250.0 | \n
4 | \n36.7 | \n19.3 | \n193.0 | \n3450.0 | \n
5 | \n39.3 | \n20.6 | \n190.0 | \n3650.0 | \n
... | \n... | \n... | \n... | \n... | \n
338 | \n47.2 | \n13.7 | \n214.0 | \n4925.0 | \n
340 | \n46.8 | \n14.3 | \n215.0 | \n4850.0 | \n
341 | \n50.4 | \n15.7 | \n222.0 | \n5750.0 | \n
342 | \n45.2 | \n14.8 | \n212.0 | \n5200.0 | \n
343 | \n49.9 | \n16.1 | \n213.0 | \n5400.0 | \n
342 rows × 4 columns
\n