Skip to content

Commit 72c27a7

Browse files
authored
Merge pull request #621 from dianna-ai/LIME_coefficients
Integrated LIME surrogate model scores segmentation filling as an option in LIME.
2 parents 6ee4461 + bd9aa19 commit 72c27a7

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

dianna/methods/lime.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from lime.lime_image import ImageExplanation
23
from lime.lime_image import LimeImageExplainer
34
from lime.lime_text import LimeTextExplainer
45
from dianna import utils
@@ -164,6 +165,7 @@ def explain(self,
164165
top_labels=None,
165166
num_features=10,
166167
num_samples=5000,
168+
return_masks=True,
167169
positive_only=False,
168170
hide_rest=True,
169171
**kwargs,
@@ -179,6 +181,7 @@ def explain(self,
179181
top_labels: Top labels
180182
num_features (int): Number of features
181183
num_samples (int): Number of samples
184+
return_masks (bool): If true, return discretized masks. Otherwise, return LIME scores
182185
positive_only (bool): Positive only
183186
hide_rest (bool): Hide rest
184187
kwargs: These parameters are passed on
@@ -205,12 +208,14 @@ def explain(self,
205208
num_samples=num_samples,
206209
**explain_instance_kwargs,
207210
)
208-
209-
get_image_and_mask_kwargs = utils.get_kwargs_applicable_to_function(explanation.get_image_and_mask, kwargs)
210-
masks = [explanation.get_image_and_mask(label, positive_only=positive_only, hide_rest=hide_rest,
211-
num_features=num_features, **get_image_and_mask_kwargs)[1]
212-
for label in labels]
213-
return masks
211+
if return_masks:
212+
get_image_and_mask_kwargs = utils.get_kwargs_applicable_to_function(explanation.get_image_and_mask, kwargs)
213+
maps = [explanation.get_image_and_mask(label, positive_only=positive_only, hide_rest=hide_rest,
214+
num_features=num_features, **get_image_and_mask_kwargs)[1]
215+
for label in labels]
216+
else:
217+
maps = [self._get_explanation_values(label, explanation) for label in labels]
218+
return maps
214219

215220
def _prepare_image_data(self, input_data):
216221
"""Transforms the data to be of the shape and type LIME expects.
@@ -276,3 +281,27 @@ def moveaxis_function(data):
276281
if self.preprocess_function is None:
277282
return moveaxis_function
278283
return lambda data: self.preprocess_function(moveaxis_function(data))
284+
285+
def _get_explanation_values(self, label: int, explanation: ImageExplanation) -> np.array:
286+
"""Get the importance scores from LIME in a salience map.
287+
288+
Leverages the `ImageExplanation` class from LIME to generate salience maps.
289+
These salience maps are constructed using the segmentation masks from
290+
the explanation and fills these with the scores from the surrogate model
291+
(default for LIME is Ridge regression) used for the explanation.
292+
293+
Args:
294+
label: The class label for the given explanation
295+
explanation: An Image Explanation generated by LIME
296+
297+
Returns:
298+
A salience map containing the feature importances from LIME
299+
"""
300+
class_explanation = explanation.local_exp[label]
301+
salience_map = np.zeros(explanation.segments.shape,
302+
dtype=class_explanation[0][1].dtype) # Ensure same dataype for segment values
303+
304+
# Fill segments
305+
for segment_id, segment_val in class_explanation:
306+
salience_map[segment_id == explanation.segments] = segment_val
307+
return salience_map
392 KB
Binary file not shown.

tests/test_lime.py

+17
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,23 @@ def test_lime_filename():
4848
assert heatmap[0].shape == input_data[0].shape
4949
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)
5050

51+
@staticmethod
52+
def test_lime_values():
53+
"""Test if get_explanation_values function works correctly."""
54+
input_data = np.random.random((224, 224, 3))
55+
heatmap_expected = np.load('tests/test_data/heatmap_lime_values.npy')
56+
labels = [1]
57+
58+
explainer = LIMEImage(random_state=42)
59+
heatmap = explainer.explain(run_model,
60+
input_data,
61+
labels,
62+
return_masks=False,
63+
num_samples=100)
64+
65+
assert heatmap[0].shape == input_data.shape[:2]
66+
assert np.allclose(heatmap, heatmap_expected, atol=1e-5)
67+
5168
def setUp(self) -> None:
5269
"""Set seed."""
5370
np.random.seed(42)

0 commit comments

Comments
 (0)