Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use only best class for auto-tuning of p_keep #86

Merged
merged 1 commit into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def _calculate_mean_class_std_for_text(self, p_keep, runner, input_data, n_masks
for i in range(0, n_masks, batch_size):
current_input = masked[i:i + batch_size]
current_predictions = runner(current_input)
predictions.append(current_predictions)
predictions.append(current_predictions.max(axis=1))
predictions = np.concatenate(predictions)
std_per_class = predictions.std(axis=0)
std_per_class = predictions.std()
return np.mean(std_per_class)

def _generate_masks_for_text(self, input_shape, p_keep, n_masks):
Expand Down Expand Up @@ -123,15 +123,12 @@ def explain_image(self, model_or_function, input_data, batch_size=100):
Returns:
Explanation heatmap for each class (np.ndarray).
"""

runner = utils.get_function(model_or_function, preprocess_function=self.preprocess_function)
# convert data to xarray
input_data = utils.to_xarray(input_data, self.axes_labels, RISE.required_labels)
# batch axis should always be first
input_data = utils.move_axis(input_data, 'batch', 0)
# ensure channels axis is last and keep track of where it was so we can move it back
channels_axis_index = input_data.dims.index('channels')
input_data = utils.move_axis(input_data, 'channels', -1)
input_data, full_preprocess_function = self._prepare_image_data(input_data)
runner = utils.get_function(model_or_function, preprocess_function=full_preprocess_function)

p_keep = self._determine_p_keep_for_images(input_data, runner) if self.p_keep is None else self.p_keep

Expand All @@ -141,11 +138,7 @@ def explain_image(self, model_or_function, input_data, batch_size=100):
self.masks = self.generate_masks_for_images(img_shape, p_keep, self.n_masks)

# Make sure multiplication is being done for correct axes
masked = (input_data * self.masks)
# ensure channels axis is in original location again
masked = utils.move_axis(masked, 'channels', channels_axis_index)
# convert to numpy for onnx
masked = masked.values.astype(input_data.dtype)
masked = input_data * self.masks

batch_predictions = []
for i in tqdm(range(0, self.n_masks, batch_size), desc='Explaining'):
Expand Down Expand Up @@ -173,14 +166,14 @@ def _calculate_mean_class_std_for_images(self, p_keep, runner, input_data, n_mas
batch_size = 50
img_shape = input_data.shape[1:3]
masks = self.generate_masks_for_images(img_shape, p_keep, n_masks)
masked = (input_data * masks).astype(input_data.dtype)
masked = input_data * masks
predictions = []
for i in range(0, n_masks, batch_size):
current_input = masked[i:i + batch_size]
current_predictions = runner(current_input)
predictions.append(current_predictions)
predictions.append(current_predictions.max(axis=1))
predictions = np.concatenate(predictions)
std_per_class = predictions.std(axis=0)
std_per_class = predictions.std()
return np.mean(std_per_class)

def generate_masks_for_images(self, input_size, p_keep, n_masks):
Expand All @@ -198,7 +191,7 @@ def generate_masks_for_images(self, input_size, p_keep, n_masks):
p=(p_keep, 1 - p_keep))
grid = grid.astype('float32')

masks = np.empty((n_masks, *input_size))
masks = np.empty((n_masks, *input_size), dtype=np.float32)

for i in range(n_masks):
y = np.random.randint(0, cell_size[0])
Expand All @@ -207,3 +200,41 @@ def generate_masks_for_images(self, input_size, p_keep, n_masks):
masks[i, :, :] = _upscale(grid[i], up_size)[y:y + input_size[0], x:x + input_size[1]]
masks = masks.reshape(-1, *input_size, 1)
return masks

def _prepare_image_data(self, input_data):
"""
Transforms the data to be of the shape and type RISE expects

Args:
input_data (xarray): Data to be explained

Returns:
transformed input data, preprocessing function to use with utils.get_function()
"""
# ensure channels axis is last and keep track of where it was so we can move it back
channels_axis_index = input_data.dims.index('channels')
input_data = utils.move_axis(input_data, 'channels', -1)
# create preprocessing function that puts model input generated by RISE into the right shape and dtype,
# followed by running the user's preprocessing function
full_preprocess_function = self._get_full_preprocess_function(channels_axis_index, input_data.dtype)
return input_data, full_preprocess_function

def _get_full_preprocess_function(self, channel_axis_index, dtype):
"""
Create a preprocessing function that incorporates both the (optional) user's
preprocessing function, as well as any needed dtype and shape conversions

Args:
channel_axis_index (int): Axis index of the channels in the input data
dtype (type): Data type of the input data (e.g. np.float32)

Returns:
Function that first ensures the data has the same shape and type as the input data,
then runs the users' preprocessing function
"""
def moveaxis_function(data):
return utils.move_axis(data, 'channels', channel_axis_index).astype(dtype).values

if self.preprocess_function is None:
return moveaxis_function
return lambda data: self.preprocess_function(moveaxis_function(data))
30 changes: 4 additions & 26 deletions tests/test_rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
from .test_onnx_runner import generate_data


def make_channels_first(data):
return data.transpose((0, 3, 1, 2))


class RiseOnImages(TestCase):

def test_rise_function(self):
Expand All @@ -36,24 +32,15 @@ def test_rise_filename(self):
assert heatmaps[0].shape == input_data[0].shape[1:]

def test_rise_determine_p_keep_for_images(self):
"""
When using the large sample size of 10000, the mean STD for each class for the following p_keeps
[ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
is as follows:
[2.069784, 2.600222, 2.8940516, 2.9950087, 2.9579144, 2.8919978, 2.6288269, 2.319147, 1.763127]
So best p_keep should be .4 or .5 ( or at least between .3 and .6).
"""
np.random.seed(0)
expected_p_keeps = [.3, .4, .5, .6]
expected_p_exact_keep = .4
expected_p_exact_keep = .1
model_filename = 'tests/test_data/mnist_model.onnx'
data = get_mnist_1_data().astype(np.float32)

p_keep = rise.RISE()._determine_p_keep_for_images( # pylint: disable=protected-access
data, get_function(model_filename))

assert p_keep in expected_p_keeps # Sanity check: is the outcome in the acceptable range?
assert p_keep == expected_p_exact_keep # Exact test: is the outcome the same as before?
assert np.isclose(p_keep, expected_p_exact_keep)


class RiseOnText(TestCase):
Expand All @@ -78,16 +65,8 @@ def test_rise_text(self):
assert np.allclose(positive_scores, expected_positive_scores)

def test_rise_determine_p_keep_for_text(self):
'''
When using the large sample size of 10000, the mean STD for each class for the following p_keeps
[ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
is as follows:
[0.18085817, 0.239386, 0.27801532, 0.30555934, 0.31592548, 0.31345606, 0.2901688, 0.2539522, 0.19383237]
So best p_keep should be .4 or .5 ( or at least between .4 and .7).
'''
np.random.seed(0)
expected_p_keeps = [.3, .4, .5, .6]
expected_p_exact_keep = .5
expected_p_exact_keep = .3
model_path = 'tests/test_data/movie_review_model.onnx'
word_vector_file = 'tests/test_data/word_vectors.txt'
runner = ModelRunner(model_path, word_vector_file, max_filter_size=5)
Expand All @@ -97,5 +76,4 @@ def test_rise_determine_p_keep_for_text(self):

p_keep = rise.RISE()._determine_p_keep_for_text(input_tokens, runner) # pylint: disable=protected-access

assert p_keep in expected_p_keeps # Sanity check: is the outcome in the acceptable range?
assert p_keep == expected_p_exact_keep # Exact test: is the outcome the same as before?
assert np.isclose(p_keep, expected_p_exact_keep)
Loading