Skip to content

Commit 7ae418d

Browse files
authored
Merge pull request #622 from dianna-ai/refactor_masker_for_images
refactor masker for RISEImage into free function
2 parents 72c27a7 + 4c9304e commit 7ae418d

File tree

2 files changed

+43
-40
lines changed

2 files changed

+43
-40
lines changed

dianna/methods/rise.py

+3-40
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import numpy as np
2-
from skimage.transform import resize
32
from dianna import utils
43

54
# To Do: remove this import when the method for different input type is splitted
65
from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
6+
from dianna.utils.maskers import generate_masks_for_images
77
from dianna.utils.predict import make_predictions
88

99

@@ -12,10 +12,6 @@ def normalize(saliency, n_masks, p_keep):
1212
return saliency / n_masks / p_keep
1313

1414

15-
def _upscale(grid_i, up_size):
16-
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)
17-
18-
1915
class RISEText:
2016
"""RISE implementation for text based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""
2117

@@ -199,7 +195,7 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
199195
# data shape without batch axis and channel axis
200196
img_shape = input_data.shape[1:3]
201197
# Expose masks for to make user inspection possible
202-
self.masks = self._generate_masks(img_shape, active_p_keep, self.n_masks)
198+
self.masks = generate_masks_for_images(self.feature_res, img_shape, active_p_keep, self.n_masks)
203199

204200
# Make sure multiplication is being done for correct axes
205201
masked = input_data * self.masks
@@ -257,45 +253,12 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):
257253

258254
def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
259255
img_shape = input_data.shape[1:3]
260-
masks = self._generate_masks(img_shape, p_keep, n_masks)
256+
masks = generate_masks_for_images(self.feature_res, img_shape, p_keep, n_masks)
261257
masked = input_data * masks
262258
predictions = make_predictions(masked, runner, batch_size=50)
263259
std_per_class = predictions.std(axis=0)
264260
return np.max(std_per_class)
265261

266-
def _generate_masks(self, input_size, p_keep, n_masks):
267-
"""Generates a set of random masks to mask the input data.
268-
269-
Args:
270-
input_size (int): Size of a single sample of input data, for images without the channel axis.
271-
p_keep: Fraction of input data to keep in each mask
272-
n_masks: Number of masks
273-
274-
Returns:
275-
The generated masks (np.ndarray)
276-
"""
277-
cell_size = np.ceil(np.array(input_size) / self.feature_res)
278-
up_size = (self.feature_res + 1) * cell_size
279-
280-
grid = np.random.choice(
281-
a=(True, False),
282-
size=(n_masks, self.feature_res, self.feature_res),
283-
p=(p_keep, 1 - p_keep),
284-
)
285-
grid = grid.astype("float32")
286-
287-
masks = np.empty((n_masks, *input_size), dtype=np.float32)
288-
289-
for i in range(n_masks):
290-
y = np.random.randint(0, cell_size[0])
291-
x = np.random.randint(0, cell_size[1])
292-
# Linear upsampling and cropping
293-
masks[i, :, :] = _upscale(grid[i], up_size)[
294-
y : y + input_size[0], x : x + input_size[1]
295-
]
296-
masks = masks.reshape(-1, *input_size, 1)
297-
return masks
298-
299262
def _prepare_image_data(self, input_data):
300263
"""Transforms the data to be of the shape and type RISE expects.
301264

dianna/utils/maskers.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from typing import Union
33
import numpy as np
4+
from skimage.transform import resize
45

56

67
def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0.5):
@@ -92,3 +93,42 @@ def _determine_number_masked(p_keep: float, series_length: int) -> int:
9293
warnings.warn('Warning: p_keep chosen too high. Continuing with masking 1 time step per mask.')
9394
return 1
9495
return user_requested_steps
96+
97+
98+
def _upscale(grid_i, up_size):
99+
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)
100+
101+
102+
def generate_masks_for_images(feature_res, input_size, p_keep, n_masks):
103+
"""Generates a set of random masks to mask the input data.
104+
105+
Args:
106+
feature_res (int): Resolution of features in masks.
107+
input_size (int): Size of a single sample of input data, for images without the channel axis.
108+
p_keep: Fraction of input data to keep in each mask
109+
n_masks: Number of masks
110+
111+
Returns:
112+
The generated masks (np.ndarray)
113+
"""
114+
cell_size = np.ceil(np.array(input_size) / feature_res)
115+
up_size = (feature_res + 1) * cell_size
116+
117+
grid = np.random.choice(
118+
a=(True, False),
119+
size=(n_masks, feature_res, feature_res),
120+
p=(p_keep, 1 - p_keep),
121+
)
122+
grid = grid.astype("float32")
123+
124+
masks = np.empty((n_masks, *input_size), dtype=np.float32)
125+
126+
for i in range(n_masks):
127+
y = np.random.randint(0, cell_size[0])
128+
x = np.random.randint(0, cell_size[1])
129+
# Linear upsampling and cropping
130+
masks[i, :, :] = _upscale(grid[i], up_size)[
131+
y : y + input_size[0], x : x + input_size[1]
132+
]
133+
masks = masks.reshape(-1, *input_size, 1)
134+
return masks

0 commit comments

Comments
 (0)