Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor masker for RISEImage into free function #622

Merged
merged 1 commit into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 3 additions & 40 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
from skimage.transform import resize
from dianna import utils

# To Do: remove this import when the method for different input type is splitted
from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
from dianna.utils.maskers import generate_masks_for_images
from dianna.utils.predict import make_predictions


Expand All @@ -12,10 +12,6 @@ def normalize(saliency, n_masks, p_keep):
return saliency / n_masks / p_keep


def _upscale(grid_i, up_size):
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)


class RISEText:
"""RISE implementation for text based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""

Expand Down Expand Up @@ -199,7 +195,7 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
# data shape without batch axis and channel axis
img_shape = input_data.shape[1:3]
# Expose masks for to make user inspection possible
self.masks = self._generate_masks(img_shape, active_p_keep, self.n_masks)
self.masks = generate_masks_for_images(self.feature_res, img_shape, active_p_keep, self.n_masks)

# Make sure multiplication is being done for correct axes
masked = input_data * self.masks
Expand Down Expand Up @@ -257,45 +253,12 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):

def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
img_shape = input_data.shape[1:3]
masks = self._generate_masks(img_shape, p_keep, n_masks)
masks = generate_masks_for_images(self.feature_res, img_shape, p_keep, n_masks)
masked = input_data * masks
predictions = make_predictions(masked, runner, batch_size=50)
std_per_class = predictions.std(axis=0)
return np.max(std_per_class)

def _generate_masks(self, input_size, p_keep, n_masks):
"""Generates a set of random masks to mask the input data.

Args:
input_size (int): Size of a single sample of input data, for images without the channel axis.
p_keep: Fraction of input data to keep in each mask
n_masks: Number of masks

Returns:
The generated masks (np.ndarray)
"""
cell_size = np.ceil(np.array(input_size) / self.feature_res)
up_size = (self.feature_res + 1) * cell_size

grid = np.random.choice(
a=(True, False),
size=(n_masks, self.feature_res, self.feature_res),
p=(p_keep, 1 - p_keep),
)
grid = grid.astype("float32")

masks = np.empty((n_masks, *input_size), dtype=np.float32)

for i in range(n_masks):
y = np.random.randint(0, cell_size[0])
x = np.random.randint(0, cell_size[1])
# Linear upsampling and cropping
masks[i, :, :] = _upscale(grid[i], up_size)[
y : y + input_size[0], x : x + input_size[1]
]
masks = masks.reshape(-1, *input_size, 1)
return masks

def _prepare_image_data(self, input_data):
"""Transforms the data to be of the shape and type RISE expects.

Expand Down
40 changes: 40 additions & 0 deletions dianna/utils/maskers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Union
import numpy as np
from skimage.transform import resize


def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0.5):
Expand Down Expand Up @@ -92,3 +93,42 @@ def _determine_number_masked(p_keep: float, series_length: int) -> int:
warnings.warn('Warning: p_keep chosen too high. Continuing with masking 1 time step per mask.')
return 1
return user_requested_steps


def _upscale(grid_i, up_size):
return resize(grid_i, up_size, order=1, mode="reflect", anti_aliasing=False)


def generate_masks_for_images(feature_res, input_size, p_keep, n_masks):
"""Generates a set of random masks to mask the input data.
Args:
feature_res (int): Resolution of features in masks.
input_size (int): Size of a single sample of input data, for images without the channel axis.
p_keep: Fraction of input data to keep in each mask
n_masks: Number of masks
Returns:
The generated masks (np.ndarray)
"""
cell_size = np.ceil(np.array(input_size) / feature_res)
up_size = (feature_res + 1) * cell_size

grid = np.random.choice(
a=(True, False),
size=(n_masks, feature_res, feature_res),
p=(p_keep, 1 - p_keep),
)
grid = grid.astype("float32")

masks = np.empty((n_masks, *input_size), dtype=np.float32)

for i in range(n_masks):
y = np.random.randint(0, cell_size[0])
x = np.random.randint(0, cell_size[1])
# Linear upsampling and cropping
masks[i, :, :] = _upscale(grid[i], up_size)[
y : y + input_size[0], x : x + input_size[1]
]
masks = masks.reshape(-1, *input_size, 1)
return masks