Skip to content

Commit c366b1d

Browse files
authored Nov 7, 2022
Merge pull request #1850 from f4str/cutout-implementation
Implement Cutout in Numpy, PyTorch, and TensorFlow
2 parents 89bf92f + 3156f32 commit c366b1d

File tree

8 files changed

+884
-0
lines changed

8 files changed

+884
-0
lines changed
 

‎art/defences/preprocessor/__init__.py

+3
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
"""
22
Module implementing preprocessing defences against adversarial attacks.
33
"""
4+
from art.defences.preprocessor.cutout.cutout import Cutout
5+
from art.defences.preprocessor.cutout.cutout_pytorch import CutoutPyTorch
6+
from art.defences.preprocessor.cutout.cutout_tensorflow import CutoutTensorFlowV2
47
from art.defences.preprocessor.feature_squeezing import FeatureSqueezing
58
from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation
69
from art.defences.preprocessor.inverse_gan import DefenseGAN, InverseGAN

‎art/defences/preprocessor/cutout/__init__.py

Whitespace-only changes.
+145
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,145 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements the Cutout data augmentation defence.
20+
21+
| Paper link: https://arxiv.org/abs/1708.04552
22+
23+
| Please keep in mind the limitations of defences. For more information on the limitations of this defence,
24+
see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see
25+
https://arxiv.org/abs/1902.06705
26+
"""
27+
from __future__ import absolute_import, division, print_function, unicode_literals
28+
29+
import logging
30+
from typing import Optional, Tuple
31+
32+
import numpy as np
33+
from tqdm.auto import trange
34+
35+
from art.defences.preprocessor.preprocessor import Preprocessor
36+
37+
logger = logging.getLogger(__name__)
38+
39+
40+
class Cutout(Preprocessor):
    """
    Implement the Cutout data augmentation defence approach.

    For every image (or every frame of every video clip) a square region with a uniformly sampled centre is set to
    zero. The region is clipped at the image border, so it may be smaller than the nominal size.

    | Paper link: https://arxiv.org/abs/1708.04552

    | Please keep in mind the limitations of defences. For more information on the limitations of this defence,
        see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see
        https://arxiv.org/abs/1902.06705
    """

    params = ["length", "channels_first", "verbose"]

    def __init__(
        self,
        length: int,
        channels_first: bool = False,
        apply_fit: bool = False,
        apply_predict: bool = True,
        verbose: bool = False,
    ) -> None:
        """
        Create an instance of a Cutout data augmentation object.

        :param length: Maximum length of the bounding box. The box extends `length // 2` pixels on either side of
                       its centre, before clipping at the image border.
        :param channels_first: Set channels first or last.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param verbose: Show progress bars.
        :raises ValueError: If `length` is not positive.
        """
        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
        self.length = length
        self.channels_first = channels_first
        self.verbose = verbose
        self._check_params()

    def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Apply Cutout data augmentation to sample `x`.

        The input array is not modified; the cut-out is applied to a copy. For video input every frame is treated as
        an individual image and receives its own bounding box.

        :param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
                  `x` values are expected to be in the data range [0, 1] or [0, 255].
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Data augmented sample and the unchanged labels.
        :raises ValueError: If `x` is neither 4- nor 5-dimensional.
        """
        x_ndim = len(x.shape)

        # NCHW/NCFHW/NFHWC --> NHWC
        if x_ndim == 4:
            if self.channels_first:
                # NCHW --> NHWC
                x_nhwc = np.transpose(x, (0, 2, 3, 1))
            else:
                # NHWC
                x_nhwc = x
        elif x_ndim == 5:
            if self.channels_first:
                # NCFHW --> NFHWC --> NHWC
                nb_clips, channels, clip_size, height, width = x.shape
                x_nfhwc = np.transpose(x, (0, 2, 3, 4, 1))
                x_nhwc = np.reshape(x_nfhwc, (nb_clips * clip_size, height, width, channels))
            else:
                # NFHWC --> NHWC
                nb_clips, clip_size, height, width, channels = x.shape
                x_nhwc = np.reshape(x, (nb_clips * clip_size, height, width, channels))
        else:
            raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")

        n, height, width, _ = x_nhwc.shape
        # work on a copy so that the caller's array is left untouched
        x_nhwc = x_nhwc.copy()
        half_length = self.length // 2

        # generate a random bounding box per image
        for idx in trange(n, desc="Cutout", disable=not self.verbose):
            # uniform sampling of the box centre
            # NOTE: the column centre is drawn before the row centre; keep this order so that seeded runs stay
            # reproducible.
            center_col = np.random.randint(width)
            center_row = np.random.randint(height)

            # rows are bounded by the height (axis 1), columns by the width (axis 2)
            row1 = np.clip(center_row - half_length, 0, height)
            row2 = np.clip(center_row + half_length, 0, height)
            col1 = np.clip(center_col - half_length, 0, width)
            col2 = np.clip(center_col + half_length, 0, width)

            # zero out the bounding box
            x_nhwc[idx, row1:row2, col1:col2, :] = 0

        # NCHW/NCFHW/NFHWC <-- NHWC
        if x_ndim == 4:
            if self.channels_first:
                # NCHW <-- NHWC
                x_aug = np.transpose(x_nhwc, (0, 3, 1, 2))
            else:
                # NHWC
                x_aug = x_nhwc
        else:
            # only the 5-dimensional case can reach this branch
            if self.channels_first:
                # NCFHW <-- NFHWC <-- NHWC
                x_nfhwc = np.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))
                x_aug = np.transpose(x_nfhwc, (0, 4, 1, 2, 3))
            else:
                # NFHWC <-- NHWC
                x_aug = np.reshape(x_nhwc, (nb_clips, clip_size, height, width, channels))

        return x_aug, y

    def _check_params(self) -> None:
        """
        Validate the constructor arguments.

        :raises ValueError: If `length` is not positive.
        """
        if self.length <= 0:
            raise ValueError("Bounding box length must be positive.")
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements the Cutout data augmentation defence in PyTorch.
20+
21+
| Paper link: https://arxiv.org/abs/1708.04552
22+
23+
| Please keep in mind the limitations of defences. For more information on the limitations of this defence,
24+
see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see
25+
https://arxiv.org/abs/1902.06705
26+
"""
27+
from __future__ import absolute_import, division, print_function, unicode_literals
28+
29+
import logging
30+
from typing import Optional, Tuple, TYPE_CHECKING
31+
32+
from tqdm.auto import trange
33+
34+
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
35+
36+
if TYPE_CHECKING:
37+
# pylint: disable=C0412
38+
import torch
39+
40+
logger = logging.getLogger(__name__)
41+
42+
43+
class CutoutPyTorch(PreprocessorPyTorch):
    """
    Implement the Cutout data augmentation defence approach in PyTorch.

    For every image (or every frame of every video clip) a square region with a uniformly sampled centre is set to
    zero. The region is clipped at the image border, so it may be smaller than the nominal size.

    | Paper link: https://arxiv.org/abs/1708.04552

    | Please keep in mind the limitations of defences. For more information on the limitations of this defence,
        see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see
        https://arxiv.org/abs/1902.06705
    """

    params = ["length", "channels_first", "verbose"]

    def __init__(
        self,
        length: int,
        channels_first: bool = False,
        apply_fit: bool = False,
        apply_predict: bool = True,
        device_type: str = "gpu",
        verbose: bool = False,
    ) -> None:
        """
        Create an instance of a Cutout data augmentation object.

        :param length: Maximum length of the bounding box. The box extends `length // 2` pixels on either side of
                       its centre, before clipping at the image border.
        :param channels_first: Set channels first or last.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        :param verbose: Show progress bars.
        :raises ValueError: If `length` is not positive.
        """
        super().__init__(
            device_type=device_type,
            is_fitted=True,
            apply_fit=apply_fit,
            apply_predict=apply_predict,
        )
        self.length = length
        self.channels_first = channels_first
        self.verbose = verbose
        self._check_params()

    def forward(
        self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
        """
        Apply Cutout data augmentation to sample `x`.

        The input tensor is not modified; the cut-out is applied to a clone. For video input every frame is treated
        as an individual image and receives its own bounding box.

        :param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
                  `x` values are expected to be in the data range [0, 1] or [0, 255].
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Data augmented sample and the unchanged labels.
        :raises ValueError: If `x` is neither 4- nor 5-dimensional.
        """
        import torch  # lgtm [py/repeated-import]

        x_ndim = len(x.shape)

        # NHWC/NCFHW/NFHWC --> NCHW.
        if x_ndim == 4:
            if self.channels_first:
                # NCHW
                x_nchw = x
            else:
                # NHWC --> NCHW
                x_nchw = x.permute(0, 3, 1, 2)
        elif x_ndim == 5:
            if self.channels_first:
                # NCFHW --> NFCHW --> NCHW
                nb_clips, channels, clip_size, height, width = x.shape
                x_nchw = x.permute(0, 2, 1, 3, 4).reshape(nb_clips * clip_size, channels, height, width)
            else:
                # NFHWC --> NHWC --> NCHW
                nb_clips, clip_size, height, width, channels = x.shape
                x_nchw = x.reshape(nb_clips * clip_size, height, width, channels).permute(0, 3, 1, 2)
        else:
            raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")

        n, _, height, width = x_nchw.shape
        # work on a clone so that the caller's tensor is left untouched
        x_nchw = x_nchw.clone()
        half_length = self.length // 2

        # generate a random bounding box per image
        for idx in trange(n, desc="Cutout", disable=not self.verbose):
            # uniform sampling of the box centre; converted to Python ints so they can be used as slice bounds
            center_row = int(torch.randint(0, height, (1,)).item())
            center_col = int(torch.randint(0, width, (1,)).item())

            # rows are bounded by the height (axis 2), columns by the width (axis 3)
            row1 = max(center_row - half_length, 0)
            row2 = min(center_row + half_length, height)
            col1 = max(center_col - half_length, 0)
            col2 = min(center_col + half_length, width)

            # zero out the bounding box
            x_nchw[idx, :, row1:row2, col1:col2] = 0

        # NHWC/NCFHW/NFHWC <-- NCHW.
        if x_ndim == 4:
            if self.channels_first:
                # NCHW
                x_aug = x_nchw
            else:
                # NHWC <-- NCHW
                x_aug = x_nchw.permute(0, 2, 3, 1)
        else:
            # only the 5-dimensional case can reach this branch
            if self.channels_first:
                # NCFHW <-- NFCHW <-- NCHW
                x_nfchw = x_nchw.reshape(nb_clips, clip_size, channels, height, width)
                x_aug = x_nfchw.permute(0, 2, 1, 3, 4)
            else:
                # NFHWC <-- NHWC <-- NCHW
                x_nhwc = x_nchw.permute(0, 2, 3, 1)
                x_aug = x_nhwc.reshape(nb_clips, clip_size, height, width, channels)

        return x_aug, y

    def _check_params(self) -> None:
        """
        Validate the constructor arguments.

        :raises ValueError: If `length` is not positive.
        """
        if self.length <= 0:
            raise ValueError("Bounding box length must be positive.")

0 commit comments

Comments
 (0)