earlystop.py
# from models import *
import torch
import numpy as np
from apex import amp, optimizers
import torch.nn as nn
import torch.nn.functional as F  # used by the "kl" loss branch (F.log_softmax / F.softmax)


def earlystop(args, model, optimizer, data, target, step_size, epsilon, perturb_steps, tau, randominit_type, loss_fn, rand_init=True, omega=0):
    '''
    The implementation of early-stopped PGD,
    following Alg. 1 in our FAT paper <https://arxiv.org/abs/2002.11242>.
    :param step_size: the PGD step size
    :param epsilon: the perturbation bound
    :param perturb_steps: the maximum number of PGD steps
    :param tau: the number of extra steps allowed once wrongly-classified adv data is found, i.e. how early we stop iterating
    :param randominit_type: the type of random initialization (random start for searching adv data)
    :param rand_init: whether to initialize the adversarial example with random noise (random start for searching adv data)
    :param omega: random-sample parameter for adv data generation (this is for escaping local minima)
    :return: output_adv (friendly adversarial data), output_target (targets), output_natural (the corresponding natural data), count (number of backward propagations used)
    '''
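    # Overall flow: run PGD for at most `perturb_steps` iterations; once an example
    # is misclassified it is only allowed `tau` further steps (tracked per example
    # in `control` below) before being moved to the output set, so "friendly"
    # adversarial data is generated with as few PGD steps as necessary.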
model.eval()
if args.is_softmax or args.is_gumbel or args.is_diri:
model.arch_param.requires_grad = False
K = perturb_steps
count = 0
output_target = []
output_adv = []
output_natural = []
control = (torch.ones(len(target)) * tau).cuda()
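    # control[i] is the remaining number of PGD steps example i may still take
    # after it has first been misclassified (the tau budget).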
# print(tau)
# Initialize the adversarial data with random noise
if rand_init:
if randominit_type == "normal_distribution_randominit":
iter_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach()
iter_adv = torch.clamp(iter_adv, 0.0, 1.0)
if randominit_type == "uniform_randominit":
iter_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon,
epsilon, data.shape)).float().cuda()
iter_adv = torch.clamp(iter_adv, 0.0, 1.0)
else:
iter_adv = data.cuda().detach()
iter_clean_data = data.cuda().detach()
iter_target = target.cuda().detach()
output_iter_clean_data = model(data)
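    # Cache the logits of the clean data once; they are reused as the reference
    # distribution by the "kl" loss branch inside the loop below.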
while K > 0:
iter_adv.requires_grad_()
output = model(iter_adv)
pred = output.max(1, keepdim=True)[1]
output_index = []
iter_index = []
        # Calculate the indices of the adversarial examples that still need to be iterated
for idx in range(len(pred)):
if pred[idx] != target[idx]:
if control[idx] == 0:
output_index.append(idx)
else:
control[idx] -= 1
iter_index.append(idx)
else:
iter_index.append(idx)
        # Move adversarial examples that need no further iteration into output_adv
if len(output_index) != 0:
if len(output_target) == 0:
                # misclassified adv data should not be iterated any further
output_adv = iter_adv[output_index].reshape(
-1, 3, 32, 32).cuda()
output_natural = iter_clean_data[output_index].reshape(
-1, 3, 32, 32).cuda()
output_target = iter_target[output_index].reshape(-1).cuda()
else:
                # misclassified adv data should not be iterated any further
output_adv = torch.cat(
(output_adv, iter_adv[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0)
output_natural = torch.cat(
(output_natural, iter_clean_data[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0)
output_target = torch.cat(
(output_target, iter_target[output_index].reshape(-1).cuda()), dim=0)
# calculate gradient
model.zero_grad()
with torch.enable_grad():
if loss_fn == "cent":
loss_adv = nn.CrossEntropyLoss(
reduction='mean')(output, iter_target)
if loss_fn == "kl":
                criterion_kl = nn.KLDivLoss(reduction='sum').cuda()
loss_adv = criterion_kl(F.log_softmax(
output, dim=1), F.softmax(output_iter_clean_data, dim=1))
if not args.use_amp:
loss_adv.backward()#retain_graph=True
else:
with amp.scale_loss(loss_adv, optimizer) as scaled_loss:
scaled_loss.backward()#retain_graph=True
grad = iter_adv.grad
        # Update the adversarial examples that still require further PGD iterations
if len(iter_index) != 0:
control = control[iter_index]
iter_adv = iter_adv[iter_index]
iter_clean_data = iter_clean_data[iter_index]
iter_target = iter_target[iter_index]
output_iter_clean_data = output_iter_clean_data[iter_index]
grad = grad[iter_index]
eta = step_size * grad.sign()
iter_adv = iter_adv.detach() + eta + omega * \
torch.randn(iter_adv.shape).detach().cuda()
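            # Project the updated examples back onto the L_inf ball of radius
            # epsilon around the clean data, then clip to the valid pixel range [0, 1].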
iter_adv = torch.min(
torch.max(iter_adv, iter_clean_data - epsilon), iter_clean_data + epsilon)
iter_adv = torch.clamp(iter_adv, 0, 1)
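            # Accumulate the number of backward propagations spent on the examples
            # that are still being perturbed.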
count += len(iter_target)
else:
output_adv = output_adv.detach()
return output_adv, output_target, output_natural, count
K = K - 1
if len(output_target) == 0:
output_target = iter_target.reshape(-1).squeeze().cuda()
output_adv = iter_adv.reshape(-1, 3, 32, 32).cuda()
output_natural = iter_clean_data.reshape(-1, 3, 32, 32).cuda()
else:
output_adv = torch.cat(
(output_adv, iter_adv.reshape(-1, 3, 32, 32)), dim=0).cuda()
output_target = torch.cat(
(output_target, iter_target.reshape(-1)), dim=0).squeeze().cuda()
output_natural = torch.cat(
(output_natural, iter_clean_data.reshape(-1, 3, 32, 32).cuda()), dim=0).cuda()
output_adv = output_adv.detach()
return output_adv, output_target, output_natural, count
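

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# example of how `earlystop` might be called inside one step of friendly
# adversarial training. `args`, `model`, and `optimizer` are assumed to be
# constructed elsewhere in this repository, and the hyper-parameter values
# below (epsilon 8/255, step size 2/255, 10 steps) are common CIFAR-10 choices,
# not values taken from this file.
def _train_step_sketch(args, model, optimizer, data, target):
    # Generate friendly adversarial examples with early-stopped PGD.
    output_adv, output_target, output_natural, count = earlystop(
        args, model, optimizer, data.cuda(), target.cuda(),
        step_size=2.0 / 255, epsilon=8.0 / 255, perturb_steps=10, tau=0,
        randominit_type="uniform_randominit", loss_fn="cent",
        rand_init=True, omega=0.001)
    model.train()
    optimizer.zero_grad()
    # Train on the adversarial examples with a standard cross-entropy loss.
    loss = nn.CrossEntropyLoss(reduction='mean')(model(output_adv), output_target)
    if not args.use_amp:
        loss.backward()
    else:
        # Mirror the apex amp handling used inside earlystop above.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    optimizer.step()
    return loss.item(), count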