
Commit e472e5f

Muhammad Zaid Hameed authored and committed on Dec 12, 2023
style changes after review
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
1 parent 13c9e98 commit e472e5f

File tree

2 files changed: +15 -21 lines
 

art/defences/trainer/adversarial_trainer_oaat.py
+1 -1
@@ -63,7 +63,7 @@ def __init__(
         :param lpips_classifier: Weight averaging model for calculating activations.
         :param list_avg_models: list of models for weight averaging.
         :param attack: attack to use for data augmentation in adversarial training
-        :param train_params: parmaters' dictionary related to adversarial training
+        :param train_params: parameters' dictionary related to adversarial training
         """
         self._attack = attack
         self._proxy_classifier = proxy_classifier

art/defences/trainer/adversarial_trainer_oaat_pytorch.py
+14 -20
@@ -23,16 +23,17 @@
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
 
+from collections import OrderedDict
 import logging
+import os
 import time
 from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union
-from collections import OrderedDict
-import six
 
+import six
 import numpy as np
 from tqdm.auto import trange
-from art import config
 
+from art import config
 from art.defences.trainer.adversarial_trainer_oaat import AdversarialTrainerOAAT
 from art.estimators.classification.pytorch import PyTorchClassifier
 from art.data_generators import DataGenerator
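
After the reorder, the module header follows the conventional standard-library / third-party / first-party grouping. Reconstructed from the hunk above, the resulting import block reads:

from __future__ import absolute_import, division, print_function, unicode_literals

from collections import OrderedDict
import logging
import os
import time
from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union

import six
import numpy as np
from tqdm.auto import trange

from art import config
from art.defences.trainer.adversarial_trainer_oaat import AdversarialTrainerOAAT
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator

Promoting `import os` to module level also makes the method-local `import os` in `fit` and `fit_generator` redundant, which the next two hunks remove.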
@@ -71,7 +72,7 @@ def __init__(
         :param lpips_classifier: Weight averaging model for calculating activations.
         :param list_avg_models: list of models for weight averaging.
         :param attack: attack to use for data augmentation in adversarial training.
-        :param train_params: training parmaters' dictionary related to adversarial training
+        :param train_params: training parameters' dictionary related to adversarial training
         """
         super().__init__(classifier, proxy_classifier, lpips_classifier, list_avg_models, attack, train_params)
         self._classifier: PyTorchClassifier
@@ -104,7 +105,6 @@ def fit(
         :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
                        the target classifier.
         """
-        import os
         import torch
 
         logger.info("Performing adversarial training with OAAT protocol")
@@ -302,7 +302,6 @@ def fit_generator(
         :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
                        the target classifier.
         """
-        import os
         import torch
 
         logger.info("Performing adversarial training with OAAT protocol")
@@ -895,7 +894,7 @@ def update_learning_rate(
         else:
             raise ValueError(f"lr_schedule {lr_schedule} not supported")
 
-    def _attack_lpips(  # type: ignore
+    def _attack_lpips(
         self,
         x: np.ndarray,
         y: np.ndarray,
@@ -993,7 +992,7 @@ def _one_step_adv_example(
 
         return x_adv
 
-    def _compute_perturbation(  # pylint: disable=W0221
+    def _compute_perturbation(
         self, x: "torch.Tensor", x_init: "torch.Tensor", y: "torch.Tensor", training_mode: bool = False
     ) -> "torch.Tensor":
         """
@@ -1010,9 +1009,6 @@ def _compute_perturbation(  # pylint: disable=W0221
         """
         import torch
 
-        # Pick a small scalar to avoid division by 0
-        tol = 10e-8
-
         self._classifier.model.train(mode=training_mode)
         self._lpips_classifier.model.train(mode=training_mode)
 
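
The `tol` local removed here (and again in `_projection` below) is replaced by `EPS` in the normalisation and projection expressions that follow. The definition of `EPS` is not part of this diff; presumably it is a module-level constant carrying the same value, along the lines of:

# Assumed module-level constant; its definition is not shown in this diff.
# Note that the old literal 10e-8 equals 1e-7, not machine epsilon.
EPS = 10e-8  # small scalar to avoid division by zero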

@@ -1124,17 +1120,17 @@ def _compute_perturbation(  # pylint: disable=W0221
 
         elif self._train_params["norm"] == 1:
             ind = tuple(range(1, len(x.shape)))
-            grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol)  # type: ignore
+            grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + EPS)  # type: ignore
 
         elif self._train_params["norm"] == 2:
             ind = tuple(range(1, len(x.shape)))
-            grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + tol)  # type: ignore
+            grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + EPS)  # type: ignore
 
         assert x.shape == grad.shape
 
         return grad
 
-    def _apply_perturbation(  # pylint: disable=W0221
+    def _apply_perturbation(
         self, x: "torch.Tensor", perturbation: "torch.Tensor", eps_step: Union[int, float, np.ndarray]
     ) -> "torch.Tensor":
         """
@@ -1173,8 +1169,6 @@ def _projection(
         """
         import torch
 
-        # Pick a small scalar to avoid division by 0
-        tol = 10e-8
         values_tmp = values.reshape(values.shape[0], -1)
 
         if norm_p == 2:
@@ -1187,7 +1181,7 @@
                 values_tmp
                 * torch.min(
                     torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
-                    eps / (torch.norm(values_tmp, p=2, dim=1) + tol),
+                    eps / (torch.norm(values_tmp, p=2, dim=1) + EPS),
                 ).unsqueeze_(-1)
             )
 
@@ -1201,14 +1195,14 @@
                 values_tmp
                 * torch.min(
                     torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
-                    eps / (torch.norm(values_tmp, p=1, dim=1) + tol),
+                    eps / (torch.norm(values_tmp, p=1, dim=1) + EPS),
                 ).unsqueeze_(-1)
             )
 
         elif norm_p in [np.inf, "inf"]:
             if isinstance(eps, np.ndarray):
-                eps = eps * np.ones_like(values.cpu())
-                eps = eps.reshape([eps.shape[0], -1])  # type: ignore
+                eps_array = eps * np.ones_like(values.cpu())
+                eps = eps_array.reshape([eps_array.shape[0], -1])
 
             values_tmp = values_tmp.sign() * torch.min(
                 values_tmp.abs(), torch.tensor([eps], dtype=torch.float32).to(self._classifier.device)
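
Taken together, these `_projection` hunks implement projection onto an eps-ball: each flattened row is shrunk by min(1, eps / ||v||_p) for p in {1, 2} and clamped elementwise for p = inf. The `eps_array` rename appears to exist so that `eps` is no longer reassigned with a value of a different type, which is what previously required the `# type: ignore`. A self-contained sketch of the projection logic, with a hypothetical function name and the `.to(self._classifier.device)` calls dropped for brevity:

import numpy as np
import torch

EPS = 10e-8  # assumed to match the removed `tol`

def project(values: torch.Tensor, eps: float, norm_p) -> torch.Tensor:
    # Flatten each sample so every row can be rescaled independently.
    values_tmp = values.reshape(values.shape[0], -1)
    if norm_p in (1, 2):
        norms = torch.norm(values_tmp, p=norm_p, dim=1)
        # Shrink rows outside the eps-ball; rows inside are left unchanged.
        factor = torch.min(torch.ones_like(norms), eps / (norms + EPS))
        values_tmp = values_tmp * factor.unsqueeze(-1)
    elif norm_p in (np.inf, "inf"):
        # Clamp every coordinate into [-eps, eps].
        values_tmp = values_tmp.sign() * torch.min(
            values_tmp.abs(), torch.full_like(values_tmp, eps)
        )
    else:
        raise NotImplementedError(f"norm {norm_p} not supported")
    return values_tmp.reshape(values.shape)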
