Skip to content

Commit 90bf04b

Browse files
authored
Merge pull request #2224 from Zaid-Hameed/awp_adv
adding adversarial weight perturbation protocol
2 parents 19259d7 + 0a78cdb commit 90bf04b

File tree

5 files changed

+1163
-0
lines changed

5 files changed

+1163
-0
lines changed

art/defences/trainer/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@
1010
from art.defences.trainer.adversarial_trainer_fbf_pytorch import AdversarialTrainerFBFPyTorch
1111
from art.defences.trainer.adversarial_trainer_trades import AdversarialTrainerTRADES
1212
from art.defences.trainer.adversarial_trainer_trades_pytorch import AdversarialTrainerTRADESPyTorch
13+
from art.defences.trainer.adversarial_trainer_awp import AdversarialTrainerAWP
14+
from art.defences.trainer.adversarial_trainer_awp_pytorch import AdversarialTrainerAWPPyTorch
1315
from art.defences.trainer.dp_instahide_trainer import DPInstaHideTrainer
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements adversarial training with Adversarial Weight Perturbation (AWP) protocol.
20+
21+
| Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf
22+
23+
| It was noted that this protocol uses double perturbation mechanism i.e, perturbation on the input samples and then
24+
perturbation on the model parameters. Consequently, framework specific implementations are being provided in ART.
25+
"""
26+
from __future__ import absolute_import, division, print_function, unicode_literals
27+
28+
import abc
29+
from typing import Optional, Tuple, TYPE_CHECKING
30+
31+
import numpy as np
32+
33+
from art.defences.trainer.trainer import Trainer
34+
from art.attacks.attack import EvasionAttack
35+
from art.data_generators import DataGenerator
36+
37+
if TYPE_CHECKING:
38+
from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE
39+
40+
41+
class AdversarialTrainerAWP(Trainer):
42+
"""
43+
This is abstract class for different backend-specific implementations of AWP protocol
44+
for adversarial training.
45+
46+
| Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf
47+
"""
48+
49+
def __init__(
50+
self,
51+
classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
52+
proxy_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
53+
attack: EvasionAttack,
54+
mode: str = "PGD",
55+
gamma: float = 0.01,
56+
beta: float = 6.0,
57+
warmup: int = 0,
58+
):
59+
"""
60+
Create an :class:`.AdversarialTrainerAWP` instance.
61+
62+
:param classifier: Model to train adversarially.
63+
:param proxy_classifier: Model for adversarial weight perturbation.
64+
:param attack: attack to use for data augmentation in adversarial training
65+
:param mode: mode determining the optimization objective of base adversarial training and weight perturbation
66+
step
67+
:param gamma: The scaling factor controlling norm of weight perturbation relative to model parameters norm
68+
:param beta: The scaling factor controlling tradeoff between clean loss and adversarial loss for TRADES protocol
69+
:param warmup: The number of epochs after which weight perturbation is applied
70+
"""
71+
self._attack = attack
72+
self._proxy_classifier = proxy_classifier
73+
self._mode = mode
74+
self._gamma = gamma
75+
self._beta = beta
76+
self._warmup = warmup
77+
self._apply_wp = False
78+
super().__init__(classifier)
79+
80+
@abc.abstractmethod
81+
def fit( # pylint: disable=W0221
82+
self,
83+
x: np.ndarray,
84+
y: np.ndarray,
85+
validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
86+
batch_size: int = 128,
87+
nb_epochs: int = 20,
88+
**kwargs
89+
):
90+
"""
91+
Train a model adversarially with AWP. See class documentation for more information on the exact procedure.
92+
93+
:param x: Training set.
94+
:param y: Labels for the training set.
95+
:param validation_data: Tuple consisting of validation data, (x_val, y_val)
96+
:param batch_size: Size of batches.
97+
:param nb_epochs: Number of epochs to use for trainings.
98+
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
99+
the target classifier.
100+
"""
101+
raise NotImplementedError
102+
103+
@abc.abstractmethod
104+
def fit_generator( # pylint: disable=W0221
105+
self,
106+
generator: DataGenerator,
107+
validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
108+
nb_epochs: int = 20,
109+
**kwargs
110+
):
111+
"""
112+
Train a model adversarially with AWP using a data generator.
113+
See class documentation for more information on the exact procedure.
114+
115+
:param generator: Data generator.
116+
:param validation_data: Tuple consisting of validation data, (x_val, y_val)
117+
:param nb_epochs: Number of epochs to use for trainings.
118+
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
119+
the target classifier.
120+
"""
121+
raise NotImplementedError
122+
123+
def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
124+
"""
125+
Perform prediction using the adversarially trained classifier.
126+
127+
:param x: Input samples.
128+
:param kwargs: Other parameters to be passed on to the `predict` function of the classifier.
129+
:return: Predictions for test set.
130+
"""
131+
return self._classifier.predict(x, **kwargs)

0 commit comments

Comments
 (0)