Skip to content

Commit 1db4e0a

Browse files
committed Oct 28, 2022
Added cutout support for all video data
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
1 parent 15b68be commit 1db4e0a

File tree

6 files changed

+254
-79
lines changed

6 files changed

+254
-79
lines changed
 

‎art/defences/preprocessor/cutout/cutout.py

+36-12
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,35 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
7878
"""
7979
Apply Cutout data augmentation to sample `x`.
8080
81-
:param x: Sample to compress with shape of `NCHW` or `NHWC`. The `x` values are expected to be in
82-
the data range [0, 1] or [0, 255].
81+
:param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
82+
`x` values are expected to be in the data range [0, 1] or [0, 255].
8383
:param y: Labels of the sample `x`. This function does not affect them in any way.
8484
:return: Data augmented sample.
8585
"""
8686
x_ndim = len(x.shape)
8787

88+
# NCHW/NCFHW/NFHWC --> NHWC
8889
if x_ndim == 4:
8990
if self.channels_first:
90-
# NCHW
91-
n, _, height, width = x.shape
91+
# NCHW --> NHWC
92+
x_nhwc = np.transpose(x, (0, 2, 3, 1))
9293
else:
93-
# NHWC
94-
n, height, width, _ = x.shape
94+
x_nhwc = x
95+
elif x_ndim == 5:
96+
if self.channels_first:
97+
# NCFHW --> NFHWC --> NHWC
98+
nb_clips, channels, clip_size, height, width = x.shape
99+
x_nfhwc = np.transpose(x, (0, 2, 3, 4, 1))
100+
x_nhwc = np.reshape(x_nfhwc, (nb_clips * clip_size, height, width, channels))
101+
else:
102+
# NFHWC --> NHWC
103+
nb_clips, clip_size, height, width, channels = x.shape
104+
x_nhwc = np.reshape(x, (nb_clips * clip_size, height, width, channels))
95105
else:
96-
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image data.")
106+
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")
97107

98-
masks = np.ones_like(x)
108+
n, height, width, _ = x_nhwc.shape
109+
masks = np.ones_like(x_nhwc)
99110

100111
# generate a random bounding box per image
101112
for idx in trange(n, desc="Cutout", disable=not self.verbose):
@@ -108,12 +119,25 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
108119
bbx2 = np.clip(center_x + self.length // 2, 0, width)
109120

110121
# zero out the bounding box
122+
masks[idx, bbx1:bbx2, bby1:bby2, :] = 0
123+
124+
x_nhwc = x_nhwc * masks
125+
126+
# NCHW/NCFHW/NFHWC <-- NHWC
127+
if x_ndim == 4:
111128
if self.channels_first:
112-
masks[idx, :, bbx1:bbx2, bby1:bby2] = 0
129+
# NCHW <-- NHWC
130+
x_aug = np.transpose(x_nhwc, (0, 3, 1, 2))
113131
else:
114-
masks[idx, bbx1:bbx2, bby1:bby2, :] = 0
115-
116-
x_aug = x * masks
132+
x_aug = x_nhwc
133+
elif x_ndim == 5: # lgtm [py/redundant-comparison]
134+
if self.channels_first:
135+
# NCFHW <-- NFHWC <-- NHWC
136+
x_nfhwc = np.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))
137+
x_aug = np.transpose(x_nfhwc, (0, 4, 1, 2, 3))
138+
else:
139+
# NFHWC <-- NHWC
140+
x_aug = np.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))
117141

118142
return x_aug, y
119143

‎art/defences/preprocessor/cutout/cutout_pytorch.py

+67-25
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def __init__(
7272
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
7373
:param verbose: Show progress bars.
7474
"""
75+
import torch # lgtm [py/repeated-import]
76+
from torch.autograd import Function
77+
7578
super().__init__(
7679
device_type=device_type,
7780
is_fitted=True,
@@ -83,50 +86,89 @@ def __init__(
8386
self.verbose = verbose
8487
self._check_params()
8588

89+
class RandomCutout(Function):  # pylint: disable=W0223
    """
    Autograd function that zeroes one randomly placed box per image of a batch.

    `self`, `torch` and `trange` are not attributes of this class: they are resolved from the enclosing scope
    in which the class is defined, so `self.length` and `self.verbose` belong to the surrounding preprocessor.
    """

    @staticmethod
    def forward(ctx, input):  # pylint: disable=W0622,W0221
        """
        Apply Cutout to a batch.

        :param ctx: Autograd context. Nothing is saved on it because `backward` reads no saved tensors.
        :param input: Batch whose shape unpacks as (samples, channels, height, width).
        :return: `input` multiplied by a 0/1 mask holding one zeroed box per sample.
        """
        n, _, height, width = input.shape
        masks = torch.ones_like(input)
        half_length = self.length // 2

        # generate a random bounding box per image
        for idx in trange(n, desc="Cutout", disable=not self.verbose):
            # uniform sampling: x runs along the height axis, y along the width axis
            center_x = torch.randint(0, height, (1,))
            center_y = torch.randint(0, width, (1,))

            # clamp every coordinate to the extent of the axis it slices, otherwise boxes on
            # non-square inputs are truncated (or empty) in the region beyond the shorter side
            bbx1 = torch.clamp(center_x - half_length, 0, height)
            bbx2 = torch.clamp(center_x + half_length, 0, height)
            bby1 = torch.clamp(center_y - half_length, 0, width)
            bby2 = torch.clamp(center_y + half_length, 0, width)

            # zero out the bounding box
            masks[idx, :, bbx1:bbx2, bby1:bby2] = 0  # type: ignore

        return input * masks

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W0221
        """
        Return the incoming gradient unchanged; the cutout mask is not applied to it.

        :param ctx: Autograd context (unused).
        :param grad_output: Gradient with respect to the output of `forward`.
        :return: `grad_output` itself.
        """
        return grad_output
118+
119+
self._random_cutout = RandomCutout
120+
86121
def forward(
87122
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
88123
) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
89124
"""
90125
Apply Cutout data augmentation to sample `x`.
91126
92-
:param x: Sample to compress with shape of `NCHW` or `NHWC`. The `x` values are expected to be in
93-
the data range [0, 1] or [0, 255].
127+
:param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
128+
`x` values are expected to be in the data range [0, 1] or [0, 255].
94129
:param y: Labels of the sample `x`. This function does not affect them in any way.
95130
:return: Data augmented sample.
96131
"""
97-
import torch # lgtm [py/repeated-import]
98-
99132
x_ndim = len(x.shape)
100133

134+
# NHWC/NCFHW/NFHWC --> NCHW.
101135
if x_ndim == 4:
102136
if self.channels_first:
103-
# NCHW
104-
n, _, height, width = x.shape
137+
x_nchw = x
105138
else:
106-
# NHWC
107-
n, height, width, _ = x.shape
139+
# NHWC --> NCHW
140+
x_nchw = x.permute(0, 3, 1, 2)
141+
elif x_ndim == 5:
142+
if self.channels_first:
143+
# NCFHW --> NFCHW --> NCHW
144+
nb_clips, channels, clip_size, height, width = x.shape
145+
x_nchw = x.permute(0, 2, 1, 3, 4).reshape(nb_clips * clip_size, channels, height, width)
146+
else:
147+
# NFHWC --> NHWC --> NCHW
148+
nb_clips, clip_size, height, width, channels = x.shape
149+
x_nchw = x.reshape(nb_clips * clip_size, height, width, channels).permute(0, 3, 1, 2)
108150
else:
109-
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image data.")
110-
111-
masks = torch.ones_like(x)
151+
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")
112152

113-
# generate a random bounding box per image
114-
for idx in trange(n, desc="Cutout", disable=not self.verbose):
115-
# uniform sampling
116-
center_x = torch.randint(0, height, (1,))
117-
center_y = torch.randint(0, width, (1,))
118-
bby1 = torch.clamp(center_y - self.length // 2, 0, height)
119-
bbx1 = torch.clamp(center_x - self.length // 2, 0, width)
120-
bby2 = torch.clamp(center_y + self.length // 2, 0, height)
121-
bbx2 = torch.clamp(center_x + self.length // 2, 0, width)
153+
# apply random cutout
154+
x_nchw = self._random_cutout.apply(x_nchw)
122155

123-
# zero out the bounding box
156+
# NHWC/NCFHW/NFHWC <-- NCHW.
157+
if x_ndim == 4:
124158
if self.channels_first:
125-
masks[idx, :, bbx1:bbx2, bby1:bby2] = 0 # type: ignore
159+
x_aug = x_nchw
126160
else:
127-
masks[idx, bbx1:bbx2, bby1:bby2, :] = 0 # type: ignore
128-
129-
x_aug = x * masks
161+
# NHWC <-- NCHW
162+
x_aug = x_nchw.permute(0, 2, 3, 1)
163+
elif x_ndim == 5: # lgtm [py/redundant-comparison]
164+
if self.channels_first:
165+
# NCFHW <-- NFCHW <-- NCHW
166+
x_nfchw = x_nchw.reshape(nb_clips, clip_size, channels, height, width)
167+
x_aug = x_nfchw.permute(0, 2, 1, 3, 4)
168+
else:
169+
# NFHWC <-- NHWC <-- NCHW
170+
x_nhwc = x_nchw.permute(0, 2, 3, 1)
171+
x_aug = x_nhwc.reshape(nb_clips, clip_size, height, width, channels)
130172

131173
return x_aug, y
132174

‎art/defences/preprocessor/cutout/cutout_tensorflow.py

+35-25
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import logging
3030
from typing import Optional, Tuple, TYPE_CHECKING
3131

32-
from tqdm.auto import trange
33-
3432
from art.defences.preprocessor.preprocessor import PreprocessorTensorFlowV2
3533

3634
if TYPE_CHECKING:
@@ -80,45 +78,57 @@ def forward(self, x: "tf.Tensor", y: Optional["tf.Tensor"] = None) -> Tuple["tf.
8078
"""
8179
Apply Cutout data augmentation to sample `x`.
8280
83-
:param x: Sample to compress with shape of `NCHW` or `NHWC`. The `x` values are expected to be in
84-
the data range [0, 1] or [0, 255].
81+
:param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
82+
`x` values are expected to be in the data range [0, 1] or [0, 255].
8583
:param y: Labels of the sample `x`. This function does not affect them in any way.
8684
:return: Data augmented sample.
8785
"""
8886
import tensorflow as tf # lgtm [py/repeated-import]
87+
import tensorflow_addons as tfa
8988

9089
x_ndim = len(x.shape)
9190

91+
# NCHW/NCFHW/NFHWC --> NHWC
9292
if x_ndim == 4:
9393
if self.channels_first:
94-
# NCHW
95-
n, _, height, width = x.shape
94+
# NCHW --> NHWC
95+
x_nhwc = tf.transpose(x, (0, 2, 3, 1))
96+
else:
97+
x_nhwc = x
98+
elif x_ndim == 5:
99+
if self.channels_first:
100+
# NCFHW --> NFHWC --> NHWC
101+
nb_clips, channels, clip_size, height, width = x.shape
102+
x_nfhwc = tf.transpose(x, (0, 2, 3, 4, 1))
103+
x_nhwc = tf.reshape(x_nfhwc, (nb_clips * clip_size, height, width, channels))
96104
else:
97-
# NHWC
98-
n, height, width, _ = x.shape
105+
# NFHWC --> NHWC
106+
nb_clips, clip_size, height, width, channels = x.shape
107+
x_nhwc = tf.reshape(x, (nb_clips * clip_size, height, width, channels))
99108
else:
100-
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image data.")
109+
raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")
101110

102-
masks = tf.Variable(tf.ones_like(x), trainable=False)
111+
# round down length to be divisible by 2
112+
length = self.length if self.length % 2 == 0 else max(self.length - 1, 2)
103113

104-
# generate a random bounding box per image
105-
for idx in trange(n, desc="Cutout", disable=not self.verbose):
106-
# uniform sampling
107-
center_y = tf.random.uniform(shape=[], maxval=height, dtype=tf.int32) # pylint: disable=E1123
108-
center_x = tf.random.uniform(shape=[], maxval=width, dtype=tf.int32) # pylint: disable=E1123
109-
bby1 = tf.clip_by_value(center_y - self.length // 2, 0, height)
110-
bbx1 = tf.clip_by_value(center_x - self.length // 2, 0, width)
111-
bby2 = tf.clip_by_value(center_y + self.length // 2, 0, height)
112-
bbx2 = tf.clip_by_value(center_x + self.length // 2, 0, width)
114+
# apply random cutout
115+
x_nhwc = tfa.image.random_cutout(x_nhwc, (length, length))
113116

114-
# zero out the bounding box
117+
# NCHW/NCFHW/NFHWC <-- NHWC
118+
if x_ndim == 4:
115119
if self.channels_first:
116-
bbox = masks[idx, :, bbx1:bbx2, bby1:bby2]
120+
# NCHW <-- NHWC
121+
x_aug = tf.transpose(x_nhwc, (0, 3, 1, 2))
117122
else:
118-
bbox = masks[idx, bbx1:bbx2, bby1:bby2, :]
119-
bbox.assign(tf.zeros_like(bbox))
120-
121-
x_aug = x * masks
123+
x_aug = x_nhwc
124+
elif x_ndim == 5: # lgtm [py/redundant-comparison]
125+
if self.channels_first:
126+
# NCFHW <-- NFHWC <-- NHWC
127+
x_nfhwc = tf.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))
128+
x_aug = tf.transpose(x_nfhwc, (0, 4, 1, 2, 3))
129+
else:
130+
# NFHWC <-- NHWC
131+
x_aug = tf.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))
122132

123133
return x_aug, y
124134

‎tests/defences/preprocessor/cutout/test_cutout.py

+40-7
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,23 @@ def image_batch(request, channels_first):
3838
channels = request.param
3939

4040
if channels_first:
41-
data_shape = (2, channels, 16, 16)
41+
data_shape = (2, channels, 12, 8)
4242
else:
43-
data_shape = (2, 16, 16, channels)
43+
data_shape = (2, 12, 8, channels)
44+
return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
45+
46+
47+
@pytest.fixture(params=[1, 3], ids=["grayscale", "RGB"])
def video_batch(request, channels_first):
    """
    Video fixtures of shape NFHWC and NCFHW.

    :param request: Pytest request object; `request.param` is the number of channels (1 or 3).
    :param channels_first: Selects the NCFHW layout when true, NFHWC otherwise.
    :return: Array of the selected shape filled with 255 and cast to `ART_NUMPY_DTYPE`.
    """
    channels = request.param

    if channels_first:
        # NCFHW: the channel axis comes before the frame axis
        data_shape = (2, channels, 2, 12, 8)
    else:
        # NFHWC
        data_shape = (2, 2, 12, 8, channels)
    return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
4559

4660

@@ -52,14 +66,14 @@ def empty_image(request, channels_first):
5266
channels = request.param
5367

5468
if channels_first:
55-
data_shape = (2, channels, 16, 16)
69+
data_shape = (2, channels, 12, 8)
5670
else:
57-
data_shape = (2, 16, 16, channels)
71+
data_shape = (2, 12, 8, channels)
5872
return np.zeros(data_shape).astype(ART_NUMPY_DTYPE)
5973

6074

6175
@pytest.mark.framework_agnostic
62-
@pytest.mark.parametrize("length", [2, 4])
76+
@pytest.mark.parametrize("length", [4, 5])
6377
@pytest.mark.parametrize("channels_first", [True, False])
6478
def test_cutout_image_data(art_warning, image_batch, length, channels_first):
6579
try:
@@ -76,6 +90,25 @@ def test_cutout_image_data(art_warning, image_batch, length, channels_first):
7690
art_warning(e)
7791

7892

93+
@pytest.mark.framework_agnostic
@pytest.mark.parametrize("length", [4])
@pytest.mark.parametrize("channels_first", [True, False])
def test_cutout_video_data(art_warning, video_batch, length, channels_first):
    """
    Check that Cutout changes no more elements of a video batch than the summed area of its boxes.
    """
    try:
        augmented = Cutout(length=length, channels_first=channels_first)(video_batch)[0]
        changed = np.not_equal(augmented, video_batch).sum()

        shape = video_batch.shape
        # axis 1 together with axis 2 (channels first) or the last axis (channels last)
        # covers the frame and channel extents; their order does not matter for the product
        other_axis = shape[2] if channels_first else shape[-1]
        assert changed <= shape[0] * shape[1] * other_axis * length * length
    except ARTTestException as e:
        art_warning(e)
110+
111+
79112
@pytest.mark.framework_agnostic
80113
@pytest.mark.parametrize("length", [4])
81114
@pytest.mark.parametrize("channels_first", [True])
@@ -91,9 +124,9 @@ def test_cutout_empty_data(art_warning, empty_image, length, channels_first):
91124
def test_non_image_data_error(art_warning, tabular_batch):
92125
try:
93126
test_input = tabular_batch
94-
cutout = Cutout(length=8, channels_first=True)
127+
cutout = Cutout(length=4, channels_first=True)
95128

96-
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image data."
129+
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image and video data."
97130
with pytest.raises(ValueError, match=exc_msg):
98131
cutout(test_input)
99132
except ARTTestException as e:

‎tests/defences/preprocessor/cutout/test_cutout_pytorch.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,23 @@ def image_batch(request, channels_first):
3838
channels = request.param
3939

4040
if channels_first:
41-
data_shape = (2, channels, 16, 16)
41+
data_shape = (2, channels, 12, 8)
4242
else:
43-
data_shape = (2, 16, 16, channels)
43+
data_shape = (2, 12, 8, channels)
44+
return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
45+
46+
47+
@pytest.fixture(params=[1, 3], ids=["grayscale", "RGB"])
def video_batch(request, channels_first):
    """
    Video fixtures of shape NFHWC and NCFHW.

    :param request: Pytest request object; `request.param` is the number of channels (1 or 3).
    :param channels_first: Selects the NCFHW layout when true, NFHWC otherwise.
    :return: Array of the selected shape filled with 255 and cast to `ART_NUMPY_DTYPE`.
    """
    channels = request.param

    if channels_first:
        # NCFHW: the channel axis comes before the frame axis
        data_shape = (2, channels, 2, 12, 8)
    else:
        # NFHWC
        data_shape = (2, 2, 12, 8, channels)
    return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
4559

4660

@@ -52,9 +66,9 @@ def empty_image(request, channels_first):
5266
channels = request.param
5367

5468
if channels_first:
55-
data_shape = (2, channels, 16, 16)
69+
data_shape = (2, channels, 12, 8)
5670
else:
57-
data_shape = (2, 16, 16, channels)
71+
data_shape = (2, 12, 8, channels)
5872
return np.zeros(data_shape).astype(ART_NUMPY_DTYPE)
5973

6074

@@ -76,6 +90,25 @@ def test_cutout_image_data(art_warning, image_batch, length, channels_first):
7690
art_warning(e)
7791

7892

93+
@pytest.mark.only_with_platform("pytorch")
@pytest.mark.parametrize("length", [4])
@pytest.mark.parametrize("channels_first", [True, False])
def test_cutout_video_data(art_warning, video_batch, length, channels_first):
    """
    Check that CutoutPyTorch changes no more elements of a video batch than the summed area of its boxes.
    """
    try:
        augmented = CutoutPyTorch(length=length, channels_first=channels_first)(video_batch)[0]
        changed = np.not_equal(augmented, video_batch).sum()

        shape = video_batch.shape
        # axis 1 together with axis 2 (channels first) or the last axis (channels last)
        # covers the frame and channel extents; their order does not matter for the product
        other_axis = shape[2] if channels_first else shape[-1]
        assert changed <= shape[0] * shape[1] * other_axis * length * length
    except ARTTestException as e:
        art_warning(e)
110+
111+
79112
@pytest.mark.only_with_platform("pytorch")
80113
@pytest.mark.parametrize("length", [4])
81114
@pytest.mark.parametrize("channels_first", [True])
@@ -93,7 +126,7 @@ def test_non_image_data_error(art_warning, tabular_batch):
93126
test_input = tabular_batch
94127
cutout = CutoutPyTorch(length=8, channels_first=True)
95128

96-
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image data."
129+
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image and video data."
97130
with pytest.raises(ValueError, match=exc_msg):
98131
cutout(test_input)
99132
except ARTTestException as e:

‎tests/defences/preprocessor/cutout/test_cutout_tensorflow.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,23 @@ def image_batch(request, channels_first):
3838
channels = request.param
3939

4040
if channels_first:
41-
data_shape = (2, channels, 16, 16)
41+
data_shape = (2, channels, 12, 8)
4242
else:
43-
data_shape = (2, 16, 16, channels)
43+
data_shape = (2, 12, 8, channels)
44+
return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
45+
46+
47+
@pytest.fixture(params=[1, 3], ids=["grayscale", "RGB"])
def video_batch(request, channels_first):
    """
    Video fixtures of shape NFHWC and NCFHW.

    :param request: Pytest request object; `request.param` is the number of channels (1 or 3).
    :param channels_first: Selects the NCFHW layout when true, NFHWC otherwise.
    :return: Array of the selected shape filled with 255 and cast to `ART_NUMPY_DTYPE`.
    """
    channels = request.param

    if channels_first:
        # NCFHW: the channel axis comes before the frame axis
        data_shape = (2, channels, 2, 12, 8)
    else:
        # NFHWC
        data_shape = (2, 2, 12, 8, channels)
    return (255 * np.ones(data_shape)).astype(ART_NUMPY_DTYPE)
4559

4660

@@ -52,9 +66,9 @@ def empty_image(request, channels_first):
5266
channels = request.param
5367

5468
if channels_first:
55-
data_shape = (2, channels, 16, 16)
69+
data_shape = (2, channels, 12, 8)
5670
else:
57-
data_shape = (2, 16, 16, channels)
71+
data_shape = (2, 12, 8, channels)
5872
return np.zeros(data_shape).astype(ART_NUMPY_DTYPE)
5973

6074

@@ -76,6 +90,25 @@ def test_cutout_image_data(art_warning, image_batch, length, channels_first):
7690
art_warning(e)
7791

7892

93+
@pytest.mark.only_with_platform("tensorflow2")
@pytest.mark.parametrize("length", [4])
@pytest.mark.parametrize("channels_first", [True, False])
def test_cutout_video_data(art_warning, video_batch, length, channels_first):
    """
    Check that CutoutTensorFlowV2 changes no more elements of a video batch than the summed area of its boxes.
    """
    try:
        augmented = CutoutTensorFlowV2(length=length, channels_first=channels_first)(video_batch)[0]
        changed = np.not_equal(augmented, video_batch).sum()

        shape = video_batch.shape
        # axis 1 together with axis 2 (channels first) or the last axis (channels last)
        # covers the frame and channel extents; their order does not matter for the product
        other_axis = shape[2] if channels_first else shape[-1]
        assert changed <= shape[0] * shape[1] * other_axis * length * length
    except ARTTestException as e:
        art_warning(e)
110+
111+
79112
@pytest.mark.only_with_platform("tensorflow2")
80113
@pytest.mark.parametrize("length", [4])
81114
@pytest.mark.parametrize("channels_first", [True])
@@ -93,7 +126,7 @@ def test_non_image_data_error(art_warning, tabular_batch):
93126
test_input = tabular_batch
94127
cutout = CutoutTensorFlowV2(length=8, channels_first=True)
95128

96-
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image data."
129+
exc_msg = "Unrecognized input dimension. Cutout can only be applied to image and video data."
97130
with pytest.raises(ValueError, match=exc_msg):
98131
cutout(test_input)
99132
except ARTTestException as e:

0 commit comments

Comments
 (0)
Please sign in to comment.