Skip to content

Commit fd20b0f

Browse files
committed
refactor masker for RISEImage into free function
This is necessary for distance_explainer to be able to reuse it.
1 parent 72c27a7 commit fd20b0f

File tree

2 files changed

+42
-40
lines changed

2 files changed

+42
-40
lines changed

dianna/methods/rise.py

+3-40
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import numpy as np
2-
from skimage.transform import resize
32
from dianna import utils
43

54
# To Do: remove this import when the method for different input type is splitted
65
from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
76
from dianna.utils.predict import make_predictions
7+
from dianna.utils.maskers import generate_masks_for_images
88

99

1010
def normalize(saliency, n_masks, p_keep):
1111
"""Normalizes salience by number of masks and keep probability."""
1212
return saliency / n_masks / p_keep
1313

1414

15-
def _upscale(grid_i, up_size):
16-
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)
17-
18-
1915
class RISEText:
2016
"""RISE implementation for text based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""
2117

@@ -199,7 +195,7 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
199195
# data shape without batch axis and channel axis
200196
img_shape = input_data.shape[1:3]
201197
# Expose masks for to make user inspection possible
202-
self.masks = self._generate_masks(img_shape, active_p_keep, self.n_masks)
198+
self.masks = generate_masks_for_images(self.feature_res, img_shape, active_p_keep, self.n_masks)
203199

204200
# Make sure multiplication is being done for correct axes
205201
masked = input_data * self.masks
@@ -257,45 +253,12 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):
257253

258254
def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
259255
img_shape = input_data.shape[1:3]
260-
masks = self._generate_masks(img_shape, p_keep, n_masks)
256+
masks = generate_masks_for_images(self.feature_res, img_shape, p_keep, n_masks)
261257
masked = input_data * masks
262258
predictions = make_predictions(masked, runner, batch_size=50)
263259
std_per_class = predictions.std(axis=0)
264260
return np.max(std_per_class)
265261

266-
def _generate_masks(self, input_size, p_keep, n_masks):
267-
"""Generates a set of random masks to mask the input data.
268-
269-
Args:
270-
input_size (int): Size of a single sample of input data, for images without the channel axis.
271-
p_keep: Fraction of input data to keep in each mask
272-
n_masks: Number of masks
273-
274-
Returns:
275-
The generated masks (np.ndarray)
276-
"""
277-
cell_size = np.ceil(np.array(input_size) / self.feature_res)
278-
up_size = (self.feature_res + 1) * cell_size
279-
280-
grid = np.random.choice(
281-
a=(True, False),
282-
size=(n_masks, self.feature_res, self.feature_res),
283-
p=(p_keep, 1 - p_keep),
284-
)
285-
grid = grid.astype("float32")
286-
287-
masks = np.empty((n_masks, *input_size), dtype=np.float32)
288-
289-
for i in range(n_masks):
290-
y = np.random.randint(0, cell_size[0])
291-
x = np.random.randint(0, cell_size[1])
292-
# Linear upsampling and cropping
293-
masks[i, :, :] = _upscale(grid[i], up_size)[
294-
y : y + input_size[0], x : x + input_size[1]
295-
]
296-
masks = masks.reshape(-1, *input_size, 1)
297-
return masks
298-
299262
def _prepare_image_data(self, input_data):
300263
"""Transforms the data to be of the shape and type RISE expects.
301264

dianna/utils/maskers.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from typing import Union
33
import numpy as np
4+
from skimage.transform import resize
45

56

67
def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0.5):
@@ -92,3 +93,41 @@ def _determine_number_masked(p_keep: float, series_length: int) -> int:
9293
warnings.warn('Warning: p_keep chosen too high. Continuing with masking 1 time step per mask.')
9394
return 1
9495
return user_requested_steps
96+
97+
98+
def _upscale(grid_i, up_size):
99+
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)
100+
101+
102+
def generate_masks_for_images(feature_res, input_size, p_keep, n_masks):
103+
"""Generates a set of random masks to mask the input data.
104+
105+
Args:
106+
input_size (int): Size of a single sample of input data, for images without the channel axis.
107+
p_keep: Fraction of input data to keep in each mask
108+
n_masks: Number of masks
109+
110+
Returns:
111+
The generated masks (np.ndarray)
112+
"""
113+
cell_size = np.ceil(np.array(input_size) / feature_res)
114+
up_size = (feature_res + 1) * cell_size
115+
116+
grid = np.random.choice(
117+
a=(True, False),
118+
size=(n_masks, feature_res, feature_res),
119+
p=(p_keep, 1 - p_keep),
120+
)
121+
grid = grid.astype("float32")
122+
123+
masks = np.empty((n_masks, *input_size), dtype=np.float32)
124+
125+
for i in range(n_masks):
126+
y = np.random.randint(0, cell_size[0])
127+
x = np.random.randint(0, cell_size[1])
128+
# Linear upsampling and cropping
129+
masks[i, :, :] = _upscale(grid[i], up_size)[
130+
y : y + input_size[0], x : x + input_size[1]
131+
]
132+
masks = masks.reshape(-1, *input_size, 1)
133+
return masks

0 commit comments

Comments
 (0)