1
1
import numpy as np
2
- from skimage .transform import resize
3
2
from dianna import utils
4
3
5
4
# To Do: remove this import when the method for different input type is splitted
6
5
from dianna .methods .rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
6
+ from dianna .utils .maskers import generate_masks_for_images
7
7
from dianna .utils .predict import make_predictions
8
8
9
9
@@ -12,10 +12,6 @@ def normalize(saliency, n_masks, p_keep):
12
12
return saliency / n_masks / p_keep
13
13
14
14
15
- def _upscale (grid_i , up_size ):
16
- return resize (grid_i , up_size , order = 1 , mode = "reflect" , anti_aliasing = False )
17
-
18
-
19
15
class RISEText :
20
16
"""RISE implementation for text based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""
21
17
@@ -199,7 +195,7 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
199
195
# data shape without batch axis and channel axis
200
196
img_shape = input_data .shape [1 :3 ]
201
197
# Expose masks for to make user inspection possible
202
- self .masks = self ._generate_masks ( img_shape , active_p_keep , self .n_masks )
198
+ self .masks = generate_masks_for_images ( self .feature_res , img_shape , active_p_keep , self .n_masks )
203
199
204
200
# Make sure multiplication is being done for correct axes
205
201
masked = input_data * self .masks
@@ -257,45 +253,12 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):
257
253
258
254
def _calculate_max_class_std (self , p_keep , runner , input_data , n_masks ):
259
255
img_shape = input_data .shape [1 :3 ]
260
- masks = self ._generate_masks ( img_shape , p_keep , n_masks )
256
+ masks = generate_masks_for_images ( self .feature_res , img_shape , p_keep , n_masks )
261
257
masked = input_data * masks
262
258
predictions = make_predictions (masked , runner , batch_size = 50 )
263
259
std_per_class = predictions .std (axis = 0 )
264
260
return np .max (std_per_class )
265
261
266
- def _generate_masks (self , input_size , p_keep , n_masks ):
267
- """Generates a set of random masks to mask the input data.
268
-
269
- Args:
270
- input_size (int): Size of a single sample of input data, for images without the channel axis.
271
- p_keep: Fraction of input data to keep in each mask
272
- n_masks: Number of masks
273
-
274
- Returns:
275
- The generated masks (np.ndarray)
276
- """
277
- cell_size = np .ceil (np .array (input_size ) / self .feature_res )
278
- up_size = (self .feature_res + 1 ) * cell_size
279
-
280
- grid = np .random .choice (
281
- a = (True , False ),
282
- size = (n_masks , self .feature_res , self .feature_res ),
283
- p = (p_keep , 1 - p_keep ),
284
- )
285
- grid = grid .astype ("float32" )
286
-
287
- masks = np .empty ((n_masks , * input_size ), dtype = np .float32 )
288
-
289
- for i in range (n_masks ):
290
- y = np .random .randint (0 , cell_size [0 ])
291
- x = np .random .randint (0 , cell_size [1 ])
292
- # Linear upsampling and cropping
293
- masks [i , :, :] = _upscale (grid [i ], up_size )[
294
- y : y + input_size [0 ], x : x + input_size [1 ]
295
- ]
296
- masks = masks .reshape (- 1 , * input_size , 1 )
297
- return masks
298
-
299
262
def _prepare_image_data (self , input_data ):
300
263
"""Transforms the data to be of the shape and type RISE expects.
301
264
0 commit comments