@@ -23,16 +23,17 @@
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
 
+from collections import OrderedDict
 import logging
+import os
 import time
 from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union
-from collections import OrderedDict
-import six
 
+import six
 import numpy as np
 from tqdm.auto import trange
-from art import config
 
+from art import config
 from art.defences.trainer.adversarial_trainer_oaat import AdversarialTrainerOAAT
 from art.estimators.classification.pytorch import PyTorchClassifier
 from art.data_generators import DataGenerator
@@ -71,7 +72,7 @@ def __init__(
         :param lpips_classifier: Weight averaging model for calculating activations.
         :param list_avg_models: list of models for weight averaging.
         :param attack: attack to use for data augmentation in adversarial training.
-        :param train_params: training parmaters' dictionary related to adversarial training
+        :param train_params: training parameters' dictionary related to adversarial training
         """
         super().__init__(classifier, proxy_classifier, lpips_classifier, list_avg_models, attack, train_params)
         self._classifier: PyTorchClassifier
@@ -104,7 +105,6 @@ def fit(
         :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
                        the target classifier.
         """
-        import os
         import torch
 
         logger.info("Performing adversarial training with OAAT protocol")
@@ -302,7 +302,6 @@ def fit_generator(
         :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
                        the target classifier.
         """
-        import os
         import torch
 
         logger.info("Performing adversarial training with OAAT protocol")
@@ -895,7 +894,7 @@ def update_learning_rate(
         else:
             raise ValueError(f"lr_schedule {lr_schedule} not supported")
 
-    def _attack_lpips(  # type: ignore
+    def _attack_lpips(
         self,
         x: np.ndarray,
         y: np.ndarray,
@@ -993,7 +992,7 @@ def _one_step_adv_example(
 
         return x_adv
 
-    def _compute_perturbation(  # pylint: disable=W0221
+    def _compute_perturbation(
         self, x: "torch.Tensor", x_init: "torch.Tensor", y: "torch.Tensor", training_mode: bool = False
     ) -> "torch.Tensor":
         """
@@ -1010,9 +1009,6 @@ def _compute_perturbation( # pylint: disable=W0221
         """
         import torch
 
-        # Pick a small scalar to avoid division by 0
-        tol = 10e-8
-
         self._classifier.model.train(mode=training_mode)
         self._lpips_classifier.model.train(mode=training_mode)
 
@@ -1124,17 +1120,17 @@ def _compute_perturbation( # pylint: disable=W0221
 
         elif self._train_params["norm"] == 1:
             ind = tuple(range(1, len(x.shape)))
-            grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol)  # type: ignore
+            grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + EPS)  # type: ignore
 
         elif self._train_params["norm"] == 2:
             ind = tuple(range(1, len(x.shape)))
-            grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + tol)  # type: ignore
+            grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + EPS)  # type: ignore
 
         assert x.shape == grad.shape
 
         return grad
 
-    def _apply_perturbation(  # pylint: disable=W0221
+    def _apply_perturbation(
         self, x: "torch.Tensor", perturbation: "torch.Tensor", eps_step: Union[int, float, np.ndarray]
     ) -> "torch.Tensor":
         """
@@ -1173,8 +1169,6 @@ def _projection(
         """
         import torch
 
-        # Pick a small scalar to avoid division by 0
-        tol = 10e-8
         values_tmp = values.reshape(values.shape[0], -1)
 
         if norm_p == 2:
@@ -1187,7 +1181,7 @@ def _projection(
                 values_tmp
                 * torch.min(
                     torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
-                    eps / (torch.norm(values_tmp, p=2, dim=1) + tol),
+                    eps / (torch.norm(values_tmp, p=2, dim=1) + EPS),
                 ).unsqueeze_(-1)
             )
 
@@ -1201,14 +1195,14 @@ def _projection(
                 values_tmp
                 * torch.min(
                     torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
-                    eps / (torch.norm(values_tmp, p=1, dim=1) + tol),
+                    eps / (torch.norm(values_tmp, p=1, dim=1) + EPS),
                 ).unsqueeze_(-1)
             )
 
         elif norm_p in [np.inf, "inf"]:
             if isinstance(eps, np.ndarray):
-                eps = eps * np.ones_like(values.cpu())
-                eps = eps.reshape([eps.shape[0], -1])  # type: ignore
+                eps_array = eps * np.ones_like(values.cpu())
+                eps = eps_array.reshape([eps_array.shape[0], -1])
 
             values_tmp = values_tmp.sign() * torch.min(
                 values_tmp.abs(), torch.tensor([eps], dtype=torch.float32).to(self._classifier.device)
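
Note on the tol-to-EPS change running through these hunks: each touched method previously created its own local tol = 10e-8 to guard divisions against zero-norm inputs, and the diff replaces every use with a shared EPS constant, presumably defined once at module scope in a part of the file not shown here. Below is a minimal sketch of the two guarded operations, assuming EPS keeps the old value; normalize_grad_l2 and project_l2 are hypothetical stand-ins for the norm == 2 branches of _compute_perturbation and _projection, and the sketch drops the original's .to(self._classifier.device) calls for brevity.

import torch

EPS = 10e-8  # shared division-by-zero guard (old per-method value: tol = 10e-8)


def normalize_grad_l2(grad: torch.Tensor) -> torch.Tensor:
    # Per-sample L2 normalization, as in the norm == 2 branch of _compute_perturbation:
    # reduce over all non-batch dimensions, keep dims for broadcasting, and add EPS so
    # an all-zero gradient does not produce NaNs.
    ind = tuple(range(1, len(grad.shape)))
    return grad / (torch.sqrt(torch.sum(grad * grad, dim=ind, keepdim=True)) + EPS)


def project_l2(values: torch.Tensor, eps: float) -> torch.Tensor:
    # Projection onto the L2 ball of radius eps, as in the norm_p == 2 branch of
    # _projection: flatten per sample, shrink each row whose norm exceeds eps, and
    # leave rows already inside the ball untouched (the min with 1.0).
    values_tmp = values.reshape(values.shape[0], -1)
    factor = torch.min(
        torch.tensor([1.0], dtype=torch.float32),
        eps / (torch.norm(values_tmp, p=2, dim=1) + EPS),
    ).unsqueeze_(-1)
    return (values_tmp * factor).reshape(values.shape)

Hoisting the guard to module scope removes the duplicated locals and keeps gradient normalization and ball projection numerically consistent across all three call sites.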