Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

546 masking time step segmentation #562

Merged
merged 40 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d86462a
Merge branch 'main' into 546-masking-time-step-segmentation
cwmeijer Apr 18, 2023
db439a5
add draft working segmented time step masks (refs #546)
cwmeijer Apr 18, 2023
e944e4e
refactor time series maksing WIP
cwmeijer Apr 25, 2023
beeea4e
refactor timeseries masks
cwmeijer Apr 26, 2023
fb616b2
add some general masks tests that also print masks
cwmeijer Apr 26, 2023
2ed154c
remove old time step mask function WIP
cwmeijer Apr 26, 2023
e82022d
Merge branch 'main' into 546-masking-time-step-segmentation
cwmeijer Apr 26, 2023
aafaf8d
many segmented time step masking fixes
cwmeijer Jun 6, 2023
4bc1436
Merge remote-tracking branch 'refs/remotes/origin/546-masking-time-st…
cwmeijer Jun 6, 2023
3acb117
add feature_res configurable for rise timeseries
cwmeijer Jun 6, 2023
daefb59
fix bug swapped arguments p_keep and num_features and 2 test usages
cwmeijer Nov 20, 2023
2f66363
add failing test that checks number of masked cells in maskers
cwmeijer Nov 21, 2023
0acaeb8
make test case temperatures easier and configurable
cwmeijer Nov 27, 2023
8975ad9
add failing tests for masker
cwmeijer Nov 27, 2023
dd6af61
add printing to test for debugging (WIP)
cwmeijer Nov 27, 2023
4b29c1f
Merge branch 'main' into 546-masking-time-step-segmentation
cwmeijer Nov 30, 2023
9d0f821
parameterize rise time series test synthetic data
cwmeijer Dec 4, 2023
8de390b
split masking for image and time series and fix time series
cwmeijer Dec 12, 2023
47f663d
fix error in consistent p_keep
cwmeijer Dec 12, 2023
bddaf29
make mask number condition more general
cwmeijer Dec 12, 2023
8f15c58
make masked time steps number stochastic
cwmeijer Dec 13, 2023
d4789be
lower n_masks in test to save ~15 seconds testing
cwmeijer Dec 13, 2023
e5145b1
add final version of tests before switching mask gen approach
cwmeijer Jan 16, 2024
8044b05
add various changes (WIP)
cwmeijer Jan 16, 2024
005f402
Clean up mask time series
cwmeijer Jan 17, 2024
4ada437
Mask all channels with the same time step mask
cwmeijer Jan 17, 2024
e1be7ff
rename private function for cleaner code
cwmeijer Jan 18, 2024
8840d51
clean up and document code
cwmeijer Jan 24, 2024
31fc6fb
make saving masks, data and predictions optional
cwmeijer Jan 24, 2024
728139a
Merge remote-tracking branch 'origin/546-masking-time-step-segmentati…
cwmeijer Jan 24, 2024
463b91f
Merge branch 'main' into 546-masking-time-step-segmentation
cwmeijer Jan 25, 2024
906fe05
remove projection.ipynb temp notebook
cwmeijer Jan 25, 2024
ba79712
fix import error after merge
cwmeijer Jan 25, 2024
6537d9a
replace deprecated np.bool with bool
cwmeijer Jan 25, 2024
711d95b
add projected mask test
cwmeijer Jan 25, 2024
778da8a
fix: random offset is now independent for each mask
cwmeijer Jan 29, 2024
43bcd66
Apply suggestions from code review
cwmeijer Feb 20, 2024
943560b
make generate_interpolated_float_masks_for_image public
cwmeijer Feb 21, 2024
77720a9
add test, fix bug in image masking, rename some functions
cwmeijer Feb 21, 2024
c1e6a67
Merge branch 'main' into 546-masking-time-step-segmentation
cwmeijer Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dianna/methods/rise_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_masks_for_images
from dianna.utils.maskers import _generate_interpolated_float_masks
from dianna.utils.predict import make_predictions
from dianna.utils.rise_utils import normalize

Expand Down Expand Up @@ -60,8 +60,8 @@ def explain(self, model_or_function, input_data, labels, batch_size=100):
# data shape without batch axis and channel axis
img_shape = input_data.shape[1:3]
# Expose masks to make user inspection possible
self.masks = generate_masks_for_images(img_shape, self.n_masks,
active_p_keep, self.feature_res)
self.masks = _generate_interpolated_float_masks(
img_shape, active_p_keep, self.n_masks, self.feature_res)

# Make sure multiplication is being done for correct axes
masked = input_data * self.masks
Expand Down Expand Up @@ -117,8 +117,8 @@ def _determine_p_keep(self, input_data, runner, n_masks=100):

def _calculate_max_class_std(self, p_keep, runner, input_data, n_masks):
img_shape = input_data.shape[1:3]
masks = generate_masks_for_images(img_shape, n_masks, p_keep,
self.feature_res)
masks = _generate_interpolated_float_masks(img_shape, p_keep, n_masks,
self.feature_res)
masked = input_data * masks
predictions = make_predictions(masked, runner, batch_size=50)
std_per_class = predictions.std(axis=0)
Expand Down
8 changes: 6 additions & 2 deletions dianna/methods/rise_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self,
self.p_keep = p_keep
self.preprocess_function = preprocess_function
self.masks = None
self.masked = None
self.predictions = None

def explain(self,
Expand Down Expand Up @@ -54,10 +55,13 @@ def explain(self,
model_or_function, preprocess_function=self.preprocess_function)
self.masks = generate_masks(input_timeseries,
number_of_masks=self.n_masks,
feature_res=self.feature_res,
p_keep=self.p_keep)
masked = mask_data(input_timeseries, self.masks, mask_type=mask_type)
self.masked = mask_data(input_timeseries,
self.masks,
mask_type=mask_type)

self.predictions = make_predictions(masked, runner, batch_size)
self.predictions = make_predictions(self.masked, runner, batch_size)
n_labels = self.predictions.shape[1]

saliency = self.predictions.T.dot(self.masks.reshape(
Expand Down
237 changes: 182 additions & 55 deletions dianna/utils/maskers.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,58 @@
import heapq
import warnings
from typing import Union
import numpy as np
from numpy import ndarray
from skimage.transform import resize


def generate_masks(input_data: np.array,
number_of_masks: int,
p_keep: float = 0.5):
def generate_masks(
input_data: np.array,
number_of_masks: int,
feature_res: int = 8,
p_keep: float = 0.5,
):
"""Generate masks for time series data given a probability of keeping any time step or channel unmasked.

Args:
input_data: Timeseries data to be explained.
number_of_masks: Number of masks to generate.
p_keep: the probability that any value remains unmasked.
feature_res: Resolution of features in masks.

Returns:
Single array containing all masks where the first dimension represents the batch.
"""
if input_data.shape[-1] == 1: # univariate data
return generate_time_step_masks(input_data, number_of_masks, p_keep)
return generate_time_step_masks(input_data,
number_of_masks,
p_keep,
number_of_features=feature_res)

number_of_channel_masks = number_of_masks // 3
number_of_time_step_masks = number_of_channel_masks
number_of_combined_masks = number_of_masks - number_of_time_step_masks - number_of_channel_masks

time_step_masks = generate_time_step_masks(input_data,
number_of_time_step_masks,
p_keep)
p_keep, feature_res)
channel_masks = generate_channel_masks(input_data, number_of_channel_masks,
p_keep)
number_of_combined_masks = generate_time_step_masks(
input_data, number_of_combined_masks, p_keep) * generate_channel_masks(
input_data, number_of_combined_masks, p_keep)

return np.concatenate(
[time_step_masks, channel_masks, number_of_combined_masks], axis=0)
# Product of two masks: we need sqrt p_keep to ensure correct resulting p_keep
sqrt_p_keep = np.sqrt(p_keep)
combined_masks = generate_time_step_masks(
input_data, number_of_combined_masks,
sqrt_p_keep, feature_res) * generate_channel_masks(
input_data, number_of_combined_masks, sqrt_p_keep)

return np.concatenate([time_step_masks, channel_masks, combined_masks],
axis=0)


def generate_channel_masks(input_data: np.ndarray, number_of_masks: int,
p_keep: float):
"""Generate masks that mask one or multiple channels at a time."""
"""Generate masks that mask one or multiple channels independently at a time."""
number_of_channels = input_data.shape[1]
number_of_channels_masked = _determine_number_masked(
p_keep, number_of_channels)
Expand All @@ -52,20 +65,6 @@ def generate_channel_masks(input_data: np.ndarray, number_of_masks: int,
return masks


def generate_time_step_masks(input_data: np.ndarray, number_of_masks: int,
p_keep: float):
"""Generate masks that mask one or multiple time steps at a time."""
series_length = input_data.shape[0]
number_of_steps_masked = _determine_number_masked(p_keep, series_length)
masked_data_shape = [number_of_masks] + list(input_data.shape)
masks = np.ones(masked_data_shape, dtype=bool)
for i in range(number_of_masks):
steps_to_mask = np.random.choice(series_length, number_of_steps_masked,
False)
masks[i, steps_to_mask] = False
return masks


def mask_data(data: np.array, masks: np.array, mask_type: Union[object, str]):
"""Mask data given using a set of masks.

Expand All @@ -87,7 +86,7 @@ def mask_data(data: np.array, masks: np.array, mask_type: Union[object, str]):
return result


def _get_mask_value(data: np.array, mask_type: str) -> int:
def _get_mask_value(data: np.array, mask_type: object) -> int:
"""Calculates a masking value of the given type for the data."""
if callable(mask_type):
return mask_type(data)
Expand All @@ -97,57 +96,185 @@ def _get_mask_value(data: np.array, mask_type: str) -> int:


def _determine_number_masked(p_keep: float, series_length: int) -> int:
user_requested_steps = int(np.round(series_length * (1 - p_keep)))
if user_requested_steps == series_length:
"""Determine the number of time steps that need to be masked."""
mean = series_length * (1 - p_keep)
floor = np.floor(mean)
ceil = np.ceil(mean)
if floor != ceil:
user_requested_steps = int(
np.random.choice([floor, ceil], 1, p=[ceil - mean, mean - floor]))

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
else:
user_requested_steps = int(floor)

if user_requested_steps >= series_length:
warnings.warn(
'Warning: p_keep chosen too low. Continuing with leaving 1 time step unmasked per mask.'
)
return series_length - 1
if user_requested_steps == 0:
if user_requested_steps <= 0:
warnings.warn(
'Warning: p_keep chosen too high. Continuing with masking 1 time step per mask.'
)
return 1
return user_requested_steps


def _upscale(grid_i, up_size):
return resize(grid_i,
up_size,
order=1,
mode='reflect',
anti_aliasing=False)
def generate_time_step_masks(input_data: np.ndarray, number_of_masks: int,
                             p_keep: float, number_of_features: int):
    """Generate masks that mask complete time steps at a time, in a segmented fashion.

    Smooth float masks are generated along the time axis first; each of them is then turned into a bool mask by
    masking its lowest values. The resulting time step mask is repeated for every channel, so a time step is either
    masked in all channels or in none.

    Args:
        input_data: Time series data to be explained; axis 0 is read as the time axis and axis 1 as the channel axis.
        number_of_masks: Number of masks to generate.
        p_keep: the ratio of time steps that remain unmasked.
        number_of_features: Number of features (segments) in the time dimension of the underlying float masks.

    Returns:
        Array of bool masks where the first dimension represents the batch, followed by the time steps and channels.
    """
    time_series_length = input_data.shape[0]
    number_of_channels = input_data.shape[1]

    # Generate masks for a single channel only; the trailing axis is dropped again by the [:, :, 0] index.
    float_masks = _generate_interpolated_float_masks_for_timeseries(
        [time_series_length, 1], number_of_masks, number_of_features)[:, :, 0]
    # Builtin bool instead of np.bool: that alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    bool_masks = np.empty(shape=float_masks.shape, dtype=bool)

    # Convert float masks to bool masks using a dynamic threshold
    for i_mask, float_mask in enumerate(float_masks):
        bool_masks[i_mask] = _mask_bottom_ratio(float_mask, p_keep)

    # Mask all channels with the same time step mask
    return np.repeat(bool_masks, number_of_channels, axis=2)

def generate_masks_for_images(input_size, n_masks, p_keep, feature_res):

def _mask_bottom_ratio(float_mask: np.ndarray, p_keep: float) -> np.ndarray:
    """Return a bool mask given a mask of floats and a ratio.

    Return a mask containing bool values where the cells holding the lowest values of the float mask are masked
    (False) and the remaining, highest values stay unmasked (True). The number of masked cells is determined by
    _determine_number_masked from p_keep and the number of cells.

    Args:
        float_mask: a mask containing float values
        p_keep: the ratio of keeping cells unmasked

    Returns:
        a mask containing bool, with the same shape as float_mask
    """
    flat = float_mask.flatten()
    number_of_masked_cells = _determine_number_masked(p_keep, len(flat))
    # Indices of the lowest values; nsmallest with a key behaves like sorted(..., key=key)[:n].
    bottom_indices = heapq.nsmallest(number_of_masked_cells,
                                     range(len(flat)),
                                     key=lambda time_step: flat[time_step])
    # Builtin bool instead of np.bool: that alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    flat_mask = np.ones(flat.shape, dtype=bool)
    flat_mask[bottom_indices] = False
    return flat_mask.reshape(float_mask.shape)


def _generate_interpolated_float_masks(input_size: int, p_keep: float,
number_of_masks: int,
number_of_features: int):
"""Generates a set of random masks to mask the input data.

Args:
input_size (int): Size of a single sample of input data, for images without the channel axis.
n_masks: Number of masks
p_keep: Fraction of input data to keep in each mask
feature_res (int): Resolution of features in masks.
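p_keep: Probability of keeping each coarse grid cell unmasked (approximately the fraction of input kept per mask)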
number_of_masks: Number of masks
number_of_features: Number of features per dimension

Returns:
The generated masks (np.ndarray)
"""
cell_size = np.ceil(np.array(input_size) / feature_res)
up_size = (feature_res + 1) * cell_size
grid = np.random.choice(a=(True, False),

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
size=(number_of_masks, number_of_features,
number_of_features),
p=(p_keep, 1 - p_keep)).astype('float32')
cell_size = np.ceil(np.array(input_size) / number_of_features)
up_size = (number_of_features + 1) * cell_size
masks = np.empty((number_of_masks, *input_size), dtype=np.float32)
for i in range(masks.shape[0]):
y_offset = np.random.randint(0, cell_size[0])

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
x_offset = np.random.randint(0, cell_size[1])

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
# Linear upsampling and cropping
masks[i, :, :] = _upscale(grid[i],
up_size)[y_offset:y_offset + input_size[0],
x_offset:x_offset + input_size[1]]
masks = masks.reshape(-1, *input_size, 1)
return masks

grid = np.random.choice(
a=(True, False),
size=(n_masks, feature_res, feature_res),
p=(p_keep, 1 - p_keep),
)
grid = grid.astype('float32')

masks = np.empty((n_masks, *input_size), dtype=np.float32)
def _generate_interpolated_float_masks_for_timeseries(input_size: int, number_of_masks: int, number_of_features: int) \
-> ndarray:
"""Generates a set of random masks to mask the input data.

for i in range(n_masks):
y = np.random.randint(0, cell_size[0])
x = np.random.randint(0, cell_size[1])
# Linear upsampling and cropping
masks[i, :, :] = _upscale(grid[i], up_size)[y:y + input_size[0],
x:x + input_size[1]]
masks = masks.reshape(-1, *input_size, 1)
Args:
input_size (int): Size of a single sample of input time series.
number_of_masks: Number of masks
number_of_features: Number of features in the time dimension

Returns:
The generated masks (np.ndarray)
"""
grid = np.random.random(size=(number_of_masks, number_of_features,

Check notice

Code scanning / SonarCloud

numpy.random.Generator should be preferred to numpy.random.RandomState Low

Use a "numpy.random.Generator" here instead of this legacy function. See more on SonarCloud
1), ).astype('float32')

masks_shape = (number_of_masks, *input_size)

if grid.shape == masks_shape:
masks = grid
else:
masks = _project_grids_to_masks(grid, masks_shape)
return masks.reshape(-1, *input_size, 1)


def _project_grids_to_masks__old(grid: ndarray, masks_shape: tuple,
                                 number_of_features: int) -> ndarray:
    """Project coarse grids onto full-size masks by upscaling and randomly shifted cropping.

    Args:
        grid: One coarse grid per mask, indexed along the first axis.
        masks_shape: Shape of the result; the first entry is the number of masks, the rest the size of one mask.
        number_of_features: Number of grid cells per dimension.

    Returns:
        float32 array of shape masks_shape holding the cropped, upscaled grids.
    """
    target_shape = masks_shape[1:]
    cell_size = np.ceil(np.array(target_shape) / number_of_features)
    # One extra cell is added so that a crop shifted by up to one cell still fits.
    upscaled_shape = cell_size * (number_of_features + 1)
    projected = np.empty(masks_shape, dtype=np.float32)
    for index in range(masks_shape[0]):
        # The row offset is drawn before the column offset, separately for every mask.
        row_start = np.random.randint(0, cell_size[0])
        col_start = np.random.randint(0, cell_size[1])
        row_stop = row_start + target_shape[0]
        col_stop = col_start + target_shape[1]
        upscaled = _upscale(grid[index], upscaled_shape)
        projected[index, :, :] = upscaled[row_start:row_stop,
                                          col_start:col_stop]
    return projected


def _project_grids_to_masks(grids: ndarray,
                            masks_shape: tuple,
                            offset=None) -> ndarray:
    """Project coarse grids onto masks along the time axis using linear interpolation.

    For every mask, positions are spread evenly over the grid, starting at the offset and spanning
    (number of features - 2) grid cells plus the offset. Each time step gets the linear interpolation of its two
    neighbouring grid values. With an offset in [0, 1) the highest grid index read is the last one.

    Args:
        grids: Coarse grids; indexed as grids[i_mask, i_feature, 0].
        masks_shape: Shape of the result, (number of masks, number of time steps, number of channels).
        offset: Shift of the sampling positions in units of grid cells. If None, a random offset in [0, 1) is drawn
            independently for each mask; a given value is used for all masks.

    Returns:
        float32 array of shape masks_shape; all channels of a mask share the same time step values.
    """
    number_of_features = grids.shape[1]
    mask_len = masks_shape[1]

    masks = np.empty(masks_shape, dtype=np.float32)
    for i_mask in range(masks.shape[0]):
        # Draw per mask: a single shared offset would give every mask the same sub-cell shift.
        mask_offset = np.random.random() if offset is None else offset
        grid = grids[i_mask, :, 0]

        center_keys = np.linspace(start=mask_offset,
                                  stop=number_of_features - 2 + mask_offset,
                                  num=mask_len)
        floor_keys = np.floor(center_keys).astype(int)
        ceil_keys = np.ceil(center_keys).astype(int)
        interpolated = ((ceil_keys - center_keys) * grid[floor_keys] +
                        (center_keys - floor_keys) * grid[ceil_keys])
        # Positions that fall exactly on a grid cell have both weights equal to 0: take the cell value directly.
        time_step_values = np.where(ceil_keys == floor_keys, grid[ceil_keys],
                                    interpolated)

        # Mask all channels with the same time step mask
        masks[i_mask] = time_step_values[:, np.newaxis]
    return masks


def _upscale(grid_i, up_size):
    """Resize the grid to an array with size up_size.

    Delegates to skimage's resize with order=1, mode='reflect' and anti-aliasing disabled. No cropping happens
    here; callers crop the result themselves.
    """
    resize_options = {
        'order': 1,
        'mode': 'reflect',
        'anti_aliasing': False,
    }
    upscaled = resize(grid_i, up_size, **resize_options)
    return upscaled
Loading