|
| 1 | +# MIT License |
| 2 | +# |
| 3 | +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022 |
| 4 | +# |
| 5 | +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated |
| 6 | +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the |
| 7 | +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit |
| 8 | +# persons to whom the Software is furnished to do so, subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the |
| 11 | +# Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE |
| 14 | +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 15 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
| 16 | +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 17 | +# SOFTWARE. |
| 18 | +""" |
| 19 | +This module implements the Cutout data augmentation defence in PyTorch. |
| 20 | +
|
| 21 | +| Paper link: https://arxiv.org/abs/1708.04552 |
| 22 | +
|
| 23 | +| Please keep in mind the limitations of defences. For more information on the limitations of this defence, |
| 24 | + see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see |
| 25 | + https://arxiv.org/abs/1902.06705 |
| 26 | +""" |
| 27 | +from __future__ import absolute_import, division, print_function, unicode_literals |
| 28 | + |
| 29 | +import logging |
| 30 | +from typing import Optional, Tuple, TYPE_CHECKING |
| 31 | + |
| 32 | +from tqdm.auto import trange |
| 33 | + |
| 34 | +from art.defences.preprocessor.preprocessor import PreprocessorPyTorch |
| 35 | + |
| 36 | +if TYPE_CHECKING: |
| 37 | + # pylint: disable=C0412 |
| 38 | + import torch |
| 39 | + |
| 40 | +logger = logging.getLogger(__name__) |
| 41 | + |
| 42 | + |
class CutoutPyTorch(PreprocessorPyTorch):
    """
    Implement the Cutout data augmentation defence approach in PyTorch.

    Cutout zeroes out one randomly placed square region in every image (or in every frame of a video) of the input.

    | Paper link: https://arxiv.org/abs/1708.04552

    | Please keep in mind the limitations of defences. For more information on the limitations of this defence,
        see https://arxiv.org/abs/1803.09868 . For details on how to evaluate classifier security in general, see
        https://arxiv.org/abs/1902.06705
    """

    # Attribute names exposed as parameters of this preprocessor; how the base class consumes this list is not
    # visible in this module.
    params = ["length", "channels_first", "verbose"]

    def __init__(
        self,
        length: int,
        channels_first: bool = False,
        apply_fit: bool = False,
        apply_predict: bool = True,
        device_type: str = "gpu",
        verbose: bool = False,
    ):
        """
        Create an instance of a Cutout data augmentation object.

        :param length: Maximum length of the bounding box. Must be positive.
        :param channels_first: Set channels first or last.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        :param verbose: Show progress bars.
        :raises ValueError: If `length` is not positive.
        """
        super().__init__(
            device_type=device_type,
            is_fitted=True,
            apply_fit=apply_fit,
            apply_predict=apply_predict,
        )
        self.length = length
        self.channels_first = channels_first
        self.verbose = verbose
        self._check_params()

    def forward(
        self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
        """
        Apply Cutout data augmentation to sample `x`.

        For every image (or video frame) a centre pixel is sampled uniformly and the square spanning
        `length // 2` pixels on each side of it, clipped to the image borders, is set to zero across all channels.
        Because of the integer division the box side is `2 * (length // 2)`, so `length=1` leaves the input unchanged.

        :param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
                  `x` values are expected to be in the data range [0, 1] or [0, 255].
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Tuple of the data augmented sample, in the same layout as `x`, and the unchanged labels `y`.
        :raises ValueError: If `x` is neither 4- nor 5-dimensional.
        """
        import torch  # lgtm [py/repeated-import]

        x_ndim = len(x.shape)

        # Bring every supported layout into NCHW so that a single masking loop handles all of them.
        if x_ndim == 4:
            if self.channels_first:
                # NCHW
                x_nchw = x
            else:
                # NHWC --> NCHW
                x_nchw = x.permute(0, 3, 1, 2)
        elif x_ndim == 5:
            if self.channels_first:
                # NCFHW --> NFCHW --> NCHW (frames are folded into the batch dimension)
                nb_clips, channels, clip_size, height, width = x.shape
                x_nchw = x.permute(0, 2, 1, 3, 4).reshape(nb_clips * clip_size, channels, height, width)
            else:
                # NFHWC --> NHWC --> NCHW (frames are folded into the batch dimension)
                nb_clips, clip_size, height, width, channels = x.shape
                x_nchw = x.reshape(nb_clips * clip_size, height, width, channels).permute(0, 3, 1, 2)
        else:
            raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")

        n, _, height, width = x_nchw.shape

        # Mask a copy so that the in-place assignment below never writes into the caller's tensor.
        x_nchw = x_nchw.clone()

        # Loop invariant: the box extends this many pixels on each side of the sampled centre.
        half_length = self.length // 2

        # generate a random bounding box per image
        for idx in trange(n, desc="Cutout", disable=not self.verbose):
            # uniform sampling of the centre; the row is drawn before the column
            center_row = int(torch.randint(0, height, (1,)))
            center_col = int(torch.randint(0, width, (1,)))

            # Clip each bound against the extent of the axis it indexes: rows against the height and columns
            # against the width. Clipping against the other axis truncates or empties the box on non-square inputs.
            row_start = max(center_row - half_length, 0)
            row_stop = min(center_row + half_length, height)
            col_start = max(center_col - half_length, 0)
            col_stop = min(center_col + half_length, width)

            # zero out the bounding box in all channels
            x_nchw[idx, :, row_start:row_stop, col_start:col_stop] = 0

        # Restore the layout the input arrived in.
        if x_ndim == 4:
            if self.channels_first:
                # NCHW
                x_aug = x_nchw
            else:
                # NHWC <-- NCHW
                x_aug = x_nchw.permute(0, 2, 3, 1)
        else:
            # Only the 5-dimensional case can reach this branch; other ranks raised above.
            if self.channels_first:
                # NCFHW <-- NFCHW <-- NCHW
                x_nfchw = x_nchw.reshape(nb_clips, clip_size, channels, height, width)
                x_aug = x_nfchw.permute(0, 2, 1, 3, 4)
            else:
                # NFHWC <-- NHWC <-- NCHW
                x_nhwc = x_nchw.permute(0, 2, 3, 1)
                x_aug = x_nhwc.reshape(nb_clips, clip_size, height, width, channels)

        return x_aug, y

    def _check_params(self) -> None:
        """
        Validate the parameters of this preprocessor.

        :raises ValueError: If `length` is not positive.
        """
        if self.length <= 0:
            raise ValueError("Bounding box length must be positive.")
0 commit comments