# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements adversarial training with the Adversarial Weight Perturbation (AWP) protocol.

| Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf

| This protocol uses a double perturbation mechanism, i.e. a perturbation of the input samples followed by a
perturbation of the model parameters. Consequently, framework-specific implementations are provided in ART.
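
A minimal, illustrative sketch of the weight-perturbation half of the protocol, assuming NumPy stand-ins ``weights``
and ``direction`` for the flattened model parameters and a raw perturbation direction (this is not the ART
implementation itself); the perturbation is rescaled so that its norm is ``gamma`` times the norm of the weights::

    import numpy as np

    def scale_perturbation(weights: np.ndarray, direction: np.ndarray, gamma: float = 0.01) -> np.ndarray:
        # Rescale the raw direction so that its norm equals `gamma` times the norm of the weights,
        # mirroring how AWP constrains the size of the weight perturbation.
        return gamma * np.linalg.norm(weights) / (np.linalg.norm(direction) + 1e-12) * direction

    rng = np.random.default_rng(0)
    weights = rng.normal(size=100)
    perturbed_weights = weights + scale_perturbation(weights, rng.normal(size=100))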
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import abc
from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np

from art.defences.trainer.trainer import Trainer
from art.attacks.attack import EvasionAttack
from art.data_generators import DataGenerator

if TYPE_CHECKING:
    from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


class AdversarialTrainerAWP(Trainer):
    """
    This is an abstract class for the different backend-specific implementations of the AWP protocol
    for adversarial training.

    | Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf
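
    A hypothetical usage sketch, assuming the PyTorch backend ``AdversarialTrainerAWPPyTorch`` together with a
    ``PyTorchClassifier`` instance ``classifier``, an identically configured ``proxy_classifier``, and NumPy
    training data ``x_train``/``y_train``::

        from art.attacks.evasion import ProjectedGradientDescent
        from art.defences.trainer import AdversarialTrainerAWPPyTorch

        attack = ProjectedGradientDescent(classifier, eps=8 / 255, eps_step=2 / 255, max_iter=10)
        trainer = AdversarialTrainerAWPPyTorch(
            classifier, proxy_classifier, attack, mode="PGD", gamma=0.01, beta=6.0, warmup=5
        )
        trainer.fit(x_train, y_train, nb_epochs=50, batch_size=128)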
    """

    def __init__(
        self,
        classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
        proxy_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
        attack: EvasionAttack,
        mode: str = "PGD",
        gamma: float = 0.01,
        beta: float = 6.0,
        warmup: int = 0,
    ):
        """
        Create an :class:`.AdversarialTrainerAWP` instance.

        :param classifier: Model to train adversarially.
        :param proxy_classifier: Model used to compute the adversarial weight perturbation.
        :param attack: Attack to use for data augmentation in adversarial training.
        :param mode: Mode determining the optimization objective of the base adversarial training and the weight
                     perturbation step (e.g. "PGD" or "TRADES").
        :param gamma: The scaling factor controlling the norm of the weight perturbation relative to the norm of the
                      model parameters.
        :param beta: The scaling factor controlling the trade-off between the clean loss and the adversarial loss for
                     the TRADES protocol.
        :param warmup: The number of epochs after which the weight perturbation is applied.
        """
        self._attack = attack
        self._proxy_classifier = proxy_classifier
        self._mode = mode
        self._gamma = gamma
        self._beta = beta
        self._warmup = warmup
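        # Flag toggled by the backend implementations once `warmup` epochs have elapsed; until then only the
        # input perturbation (standard adversarial training) is applied.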
        self._apply_wp = False
        super().__init__(classifier)

    @abc.abstractmethod
    def fit(  # pylint: disable=W0221
        self,
        x: np.ndarray,
        y: np.ndarray,
        validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        batch_size: int = 128,
        nb_epochs: int = 20,
        **kwargs
    ):
        """
        Train a model adversarially with AWP. See class documentation for more information on the exact procedure.

        :param x: Training set.
        :param y: Labels for the training set.
        :param validation_data: Tuple consisting of the validation data, ``(x_val, y_val)``.
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
                       of the target classifier.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def fit_generator(  # pylint: disable=W0221
        self,
        generator: DataGenerator,
        validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        nb_epochs: int = 20,
        **kwargs
    ):
        """
        Train a model adversarially with AWP using a data generator.
        See class documentation for more information on the exact procedure.

        :param generator: Data generator.
        :param validation_data: Tuple consisting of the validation data, ``(x_val, y_val)``.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
                       of the target classifier.
        """
        raise NotImplementedError

    def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
        """
        Perform prediction using the adversarially trained classifier.

        :param x: Input samples.
        :param kwargs: Other parameters to be passed on to the `predict` function of the classifier.
        :return: Predictions for the input samples.
        """
        return self._classifier.predict(x, **kwargs)