|
1 | 1 | import numpy as np
|
2 |
| -from skimage.transform import resize |
3 | 2 | from dianna import utils
|
4 | 3 |
|
5 | 4 | # To Do: remove this import when the method for different input type is splitted
|
6 | 5 | from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
|
7 | 6 | from dianna.utils.predict import make_predictions
|
| 7 | +from dianna.utils.maskers import generate_masks_for_images |
8 | 8 |
|
9 | 9 |
|
10 | 10 | def normalize(saliency, n_masks, p_keep):
|
11 | 11 | """Normalizes salience by number of masks and keep probability."""
|
12 | 12 | return saliency / n_masks / p_keep
|
13 | 13 |
|
14 | 14 |
|
15 |
| -def _upscale(grid_i, up_size): |
16 |
| - return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False) |
17 |
| - |
18 |
| - |
19 | 15 | class RISEText:
|
20 | 16 | """RISE implementation for text based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""
|
21 | 17 |
|
@@ -199,7 +195,7 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
|
199 | 195 | # data shape without batch axis and channel axis
|
200 | 196 | img_shape = input_data.shape[1:3]
|
201 | 197 | # Expose masks for to make user inspection possible
|
202 |
| - self.masks = self._generate_masks(img_shape, active_p_keep, self.n_masks) |
| 198 | + self.masks = generate_masks_for_images(self.feature_res, img_shape, active_p_keep, self.n_masks) |
203 | 199 |
|
204 | 200 | # Make sure multiplication is being done for correct axes
|
205 | 201 | masked = input_data * self.masks
|
@@ -257,45 +253,12 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):
|
257 | 253 |
|
258 | 254 | def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
|
259 | 255 | img_shape = input_data.shape[1:3]
|
260 |
| - masks = self._generate_masks(img_shape, p_keep, n_masks) |
| 256 | + masks = generate_masks_for_images(self.feature_res, img_shape, p_keep, n_masks) |
261 | 257 | masked = input_data * masks
|
262 | 258 | predictions = make_predictions(masked, runner, batch_size=50)
|
263 | 259 | std_per_class = predictions.std(axis=0)
|
264 | 260 | return np.max(std_per_class)
|
265 | 261 |
|
266 |
| - def _generate_masks(self, input_size, p_keep, n_masks): |
267 |
| - """Generates a set of random masks to mask the input data. |
268 |
| -
|
269 |
| - Args: |
270 |
| - input_size (int): Size of a single sample of input data, for images without the channel axis. |
271 |
| - p_keep: Fraction of input data to keep in each mask |
272 |
| - n_masks: Number of masks |
273 |
| -
|
274 |
| - Returns: |
275 |
| - The generated masks (np.ndarray) |
276 |
| - """ |
277 |
| - cell_size = np.ceil(np.array(input_size) / self.feature_res) |
278 |
| - up_size = (self.feature_res + 1) * cell_size |
279 |
| - |
280 |
| - grid = np.random.choice( |
281 |
| - a=(True, False), |
282 |
| - size=(n_masks, self.feature_res, self.feature_res), |
283 |
| - p=(p_keep, 1 - p_keep), |
284 |
| - ) |
285 |
| - grid = grid.astype("float32") |
286 |
| - |
287 |
| - masks = np.empty((n_masks, *input_size), dtype=np.float32) |
288 |
| - |
289 |
| - for i in range(n_masks): |
290 |
| - y = np.random.randint(0, cell_size[0]) |
291 |
| - x = np.random.randint(0, cell_size[1]) |
292 |
| - # Linear upsampling and cropping |
293 |
| - masks[i, :, :] = _upscale(grid[i], up_size)[ |
294 |
| - y : y + input_size[0], x : x + input_size[1] |
295 |
| - ] |
296 |
| - masks = masks.reshape(-1, *input_size, 1) |
297 |
| - return masks |
298 |
| - |
299 | 262 | def _prepare_image_data(self, input_data):
|
300 | 263 | """Transforms the data to be of the shape and type RISE expects.
|
301 | 264 |
|
|
0 commit comments