1
1
import numpy as np
2
+ from lime .lime_image import ImageExplanation
2
3
from lime .lime_image import LimeImageExplainer
3
4
from lime .lime_text import LimeTextExplainer
4
5
from dianna import utils
@@ -164,6 +165,7 @@ def explain(self,
164
165
top_labels = None ,
165
166
num_features = 10 ,
166
167
num_samples = 5000 ,
168
+ return_masks = True ,
167
169
positive_only = False ,
168
170
hide_rest = True ,
169
171
** kwargs ,
@@ -179,6 +181,7 @@ def explain(self,
179
181
top_labels: Top labels
180
182
num_features (int): Number of features
181
183
num_samples (int): Number of samples
184
+ return_masks (bool): If true, return discretized masks. Otherwise, return LIME scores
182
185
positive_only (bool): Positive only
183
186
hide_rest (bool): Hide rest
184
187
kwargs: These parameters are passed on
@@ -205,12 +208,14 @@ def explain(self,
205
208
num_samples = num_samples ,
206
209
** explain_instance_kwargs ,
207
210
)
208
-
209
- get_image_and_mask_kwargs = utils .get_kwargs_applicable_to_function (explanation .get_image_and_mask , kwargs )
210
- masks = [explanation .get_image_and_mask (label , positive_only = positive_only , hide_rest = hide_rest ,
211
- num_features = num_features , ** get_image_and_mask_kwargs )[1 ]
212
- for label in labels ]
213
- return masks
211
+ if return_masks :
212
+ get_image_and_mask_kwargs = utils .get_kwargs_applicable_to_function (explanation .get_image_and_mask , kwargs )
213
+ maps = [explanation .get_image_and_mask (label , positive_only = positive_only , hide_rest = hide_rest ,
214
+ num_features = num_features , ** get_image_and_mask_kwargs )[1 ]
215
+ for label in labels ]
216
+ else :
217
+ maps = [self ._get_explanation_values (label , explanation ) for label in labels ]
218
+ return maps
214
219
215
220
def _prepare_image_data (self , input_data ):
216
221
"""Transforms the data to be of the shape and type LIME expects.
@@ -276,3 +281,27 @@ def moveaxis_function(data):
276
281
if self .preprocess_function is None :
277
282
return moveaxis_function
278
283
return lambda data : self .preprocess_function (moveaxis_function (data ))
284
+
285
+ def _get_explanation_values (self , label : int , explanation : ImageExplanation ) -> np .array :
286
+ """Get the importance scores from LIME in a salience map.
287
+
288
+ Leverages the `ImageExplanation` class from LIME to generate salience maps.
289
+ These salience maps are constructed using the segmentation masks from
290
+ the explanation and fills these with the scores from the surrogate model
291
+ (default for LIME is Ridge regression) used for the explanation.
292
+
293
+ Args:
294
+ label: The class label for the given explanation
295
+ explanation: An Image Explanation generated by LIME
296
+
297
+ Returns:
298
+ A salience map containing the feature importances from LIME
299
+ """
300
+ class_explanation = explanation .local_exp [label ]
301
+ salience_map = np .zeros (explanation .segments .shape ,
302
+ dtype = class_explanation [0 ][1 ].dtype ) # Ensure same dataype for segment values
303
+
304
+ # Fill segments
305
+ for segment_id , segment_val in class_explanation :
306
+ salience_map [segment_id == explanation .segments ] = segment_val
307
+ return salience_map
0 commit comments