33
33
from art .estimators .classification .pytorch import PyTorchClassifier
34
34
from art .data_generators import DataGenerator
35
35
from art .attacks .attack import EvasionAttack
36
+ from art .utils import check_and_transform_label_format
36
37
37
38
if TYPE_CHECKING :
38
39
import torch
@@ -97,6 +98,15 @@ def fit(
97
98
ind = np .arange (len (x ))
98
99
99
100
logger .info ("Adversarial Training TRADES" )
101
+ y = check_and_transform_label_format (y , nb_classes = self .classifier .nb_classes )
102
+
103
+ if validation_data is not None :
104
+ (x_test , y_test ) = validation_data
105
+ y_test = check_and_transform_label_format (y_test , nb_classes = self .classifier .nb_classes )
106
+
107
+ x_preprocessed_test , y_preprocessed_test = self ._classifier ._apply_preprocessing ( # pylint: disable=W0212
108
+ x_test , y_test , fit = True
109
+ )
100
110
101
111
for i_epoch in trange (nb_epochs , desc = "Adversarial Training TRADES - Epochs" ):
102
112
# Shuffle the examples
@@ -107,7 +117,6 @@ def fit(
107
117
train_n = 0.0
108
118
109
119
for batch_id in range (nb_batches ):
110
-
111
120
# Create batch data
112
121
x_batch = x [ind [batch_id * batch_size : min ((batch_id + 1 ) * batch_size , x .shape [0 ])]].copy ()
113
122
y_batch = y [ind [batch_id * batch_size : min ((batch_id + 1 ) * batch_size , x .shape [0 ])]]
@@ -125,9 +134,9 @@ def fit(
125
134
126
135
# compute accuracy
127
136
if validation_data is not None :
128
- ( x_test , y_test ) = validation_data
129
- output = np .argmax (self . predict ( x_test ) , axis = 1 )
130
- nb_correct_pred = np . sum ( output == np . argmax ( y_test , axis = 1 ))
137
+ output = np . argmax ( self . predict ( x_preprocessed_test ), axis = 1 )
138
+ nb_correct_pred = np . sum ( output == np .argmax (y_preprocessed_test , axis = 1 ) )
139
+
131
140
logger .info (
132
141
"epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f" ,
133
142
i_epoch ,
@@ -188,7 +197,6 @@ def fit_generator(
188
197
train_n = 0.0
189
198
190
199
for batch_id in range (nb_batches ): # pylint: disable=W0612
191
-
192
200
# Create batch data
193
201
x_batch , y_batch = generator .get_batch ()
194
202
x_batch = x_batch .copy ()
@@ -232,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
232
240
x_batch_pert = self ._attack .generate (x_batch , y = y_batch )
233
241
234
242
# Apply preprocessing
243
+ y_batch = check_and_transform_label_format (y_batch , nb_classes = self .classifier .nb_classes )
244
+
235
245
x_preprocessed , y_preprocessed = self ._classifier ._apply_preprocessing ( # pylint: disable=W0212
236
246
x_batch , y_batch , fit = True
237
247
)
0 commit comments