-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathorthogonal_pgd.py
215 lines (181 loc) · 11.6 KB
/
orthogonal_pgd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import math
import numpy as np
from tqdm.auto import tqdm
import torch
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class PGD():
    """Orthogonal Projected Gradient Descent (O-PGD) attack against a
    classifier/detector pair.

    Perturbs inputs so the classifier is fooled while the perturbation stays
    hidden from the detector, by projecting one model's input-gradient onto
    the orthogonal complement of the other's.
    """

    def __init__(self, classifier, detector, classifier_loss=None, detector_loss=None, steps=100,
                 alpha=1/255, eps=8/255, use_projection=True, projection_norm='linf', target=None, lmbd=0,
                 k=None, project_detector=False, project_classifier=False, img_min=0, img_max=1, device=None):
        '''
        :param classifier: model used for classification
        :param detector: model used for detection; called on a batch of images and
            expected to return a tuple ``(ssl_labels, aug_labels, rep_similarities)``
            (see attack_batch for the shapes assumed)
        :param classifier_loss: loss used for classification model
        :param detector_loss: loss used for detection model. Need to have __call__ method
            which outputs adversarial scores ranging from 0 to 1
            (0 if not adversarial and 1 if adversarial)
        :param steps: number of steps for which to perform gradient descent/ascent
        :param alpha: step size
        :param eps: constraint on noise that can be applied to images
        :param use_projection: True if gradients should be projected onto each other
        :param projection_norm: 'linf' or 'l2' for regularization of gradients
        :param target: target label to attack. if None, an untargeted attack is run
        :param lmbd: hyperparameter for 'f + lmbd * g' when 'use_projection' is False
        :param k: if not None, take gradients of g onto f every kth step
        :param project_detector: if True, take gradients of g onto f
        :param project_classifier: if True, take gradients of f onto g
        :param img_min: lower bound of the valid pixel range
        :param img_max: upper bound of the valid pixel range
        :param device: torch device the models run on (None = torch default)
        '''
        self.classifier = classifier
        self.detector = detector
        self.steps = steps
        self.alpha = alpha
        self.eps = eps
        self.classifier_loss = classifier_loss
        # NOTE(review): detector_loss is stored but never used by the methods
        # below — kept for interface compatibility.
        self.detector_loss = detector_loss
        self.use_projection = use_projection
        self.projection_norm = projection_norm
        self.project_classifier = project_classifier
        self.project_detector = project_detector
        self.target = target
        self.lmbd = lmbd
        self.k = k
        self.img_min = img_min
        self.img_max = img_max
        self.device = device
        # metrics to keep track of
        self.all_classifier_losses = []
        self.all_detector_losses = []

    def attack_batch(self, inputs, targets):
        """Run the attack on a single batch.

        :param inputs: batch of images (CPU tensor) within [img_min, img_max]
        :param targets: per-sample labels — ground-truth labels for an
            untargeted attack, attack targets when ``self.target`` is set
        :return: tensor holding, per sample, the stealthiest successful
            adversarial example found over all steps (the clean input if
            the attack never succeeded for that sample)
        """
        adv_images = inputs.clone().detach()
        original_inputs_numpy = inputs.clone().detach().numpy()
        # clean-input tensor used for the eps-ball projection each step;
        # built once here instead of from numpy on every iteration
        original_inputs = torch.tensor(original_inputs_numpy)
        batch_size = inputs.shape[0]

        # targeted attack
        if self.target is not None:
            targeted_targets = targets.to(self.device)

        # best adversarial example per sample, plus the detector-side scores
        # it achieved (higher similarity = stealthier w.r.t. the detector)
        advx_final = inputs.detach().numpy()
        rep_sim_final = np.zeros(inputs.shape[0])
        label_sim_final = np.zeros(inputs.shape[0])

        progress = tqdm(range(self.steps))
        for step in progress:
            adv_images.requires_grad = True

            # calculating gradient of classifier w.r.t. images
            outputs = self.classifier(adv_images.to(self.device))
            classifier_labels = outputs.max(-1)[1]
            if self.target is not None:
                loss_classifier = 1 * self.classifier_loss(outputs, targeted_targets)
            else:
                loss_classifier = self.classifier_loss(outputs, targets.to(self.device))
            loss_classifier.backward(retain_graph=True)
            grad_classifier = adv_images.grad.cpu().detach()

            # calculating gradient of detector w.r.t. images
            adv_images.grad = None
            # detector contract assumed from the usage below — TODO confirm:
            #   ssl_labels:       (batch, n_classes) logits
            #   aug_labels:       (batch, n_aug, n_classes) logits per augmented view
            #   rep_similarities: (batch, n_aug) representation similarities
            ssl_labels, aug_labels, rep_similarities = self.detector(adv_images.to(self.device))
            # how many augmented views agree with the classifier's prediction
            label_similarities = (classifier_labels.unsqueeze(dim=1) == aug_labels.max(-1)[1].to(self.device)).sum(-1)
            if self.target is not None:
                # push every augmented view towards the target class while
                # keeping representations similar (stealthy for the detector)
                ssl_cls_loss = self.classifier_loss(aug_labels.reshape(-1, 10).to(self.device), targeted_targets.repeat_interleave(aug_labels.shape[1]).to(self.device))
                ssl_rep_loss = rep_similarities.mean()
                loss_detector = ssl_rep_loss * -1 + ssl_cls_loss
            else:
                loss_detector = self.classifier_loss(ssl_labels.to(self.device), targets.to(self.device))
            loss_detector.requires_grad_(True)
            loss_detector.backward()
            grad_detector = adv_images.grad.cpu().detach()

            self.all_classifier_losses.append(loss_classifier.detach().data.item())
            self.all_detector_losses.append(loss_detector.detach().data.item())
            progress.set_description("Losses (%.3f/%.3f)" % (np.mean(self.all_classifier_losses[-10:]),
                                                             np.mean(self.all_detector_losses[-10:])))

            if self.target is not None:
                has_attack_succeeded = (outputs.cpu().detach().numpy().argmax(1)==targeted_targets.cpu().numpy())
            else:
                has_attack_succeeded = (outputs.cpu().detach().numpy().argmax(1)!=targets.numpy())

            # keep the stealthiest successful adversarial example seen so far
            adv_images_np = adv_images.cpu().detach().numpy()
            sim_mean = rep_similarities.mean(-1)
            for i in range(len(advx_final)):
                if has_attack_succeeded[i] and ((rep_sim_final[i] < sim_mean[i]) or (label_similarities[i] > label_sim_final[i])):
                    advx_final[i] = adv_images_np[i]
                    rep_sim_final[i] = sim_mean[i]
                    label_sim_final[i] = label_similarities[i]

            if not self.use_projection:
                # using hyperparameter to combine gradient of classifier and gradient of detector
                grad = grad_classifier + self.lmbd * grad_detector
            else:
                if self.project_detector:
                    # using Orthogonal Projected Gradient Descent:
                    # projection of gradient of detector on gradient of classifier,
                    # then grad_d' = grad_d - (project grad_d onto grad_c)
                    grad_detector_proj = grad_detector - torch.bmm((torch.bmm(grad_detector.view(batch_size, 1, -1), grad_classifier.view(batch_size, -1, 1)))/(1e-20+torch.bmm(grad_classifier.view(batch_size, 1, -1), grad_classifier.view(batch_size, -1, 1))).view(-1, 1, 1), grad_classifier.view(batch_size, 1, -1)).view(grad_detector.shape)
                else:
                    grad_detector_proj = grad_detector

                if self.project_classifier:
                    # projection of gradient of classifier on gradient of detector,
                    # then grad_c' = grad_c - (project grad_c onto grad_d)
                    grad_classifier_proj = grad_classifier - torch.bmm((torch.bmm(grad_classifier.view(batch_size, 1, -1), grad_detector.view(batch_size, -1, 1)))/(1e-20+torch.bmm(grad_detector.view(batch_size, 1, -1), grad_detector.view(batch_size, -1, 1))).view(-1, 1, 1), grad_detector.view(batch_size, 1, -1)).view(grad_classifier.shape)
                else:
                    grad_classifier_proj = grad_classifier

                # making sure adversarial images have crossed decision boundary:
                # add a small margin (.05) to the true/target logit before re-testing success
                outputs_perturbed = outputs.cpu().detach().numpy()
                if self.target is not None:
                    outputs_perturbed[np.arange(targeted_targets.shape[0]), targets] += .05
                    has_attack_succeeded = np.array((outputs_perturbed.argmax(1)==targeted_targets.cpu().numpy())[:,None,None,None],dtype=np.float32)
                else:
                    outputs_perturbed[np.arange(targets.shape[0]), targets] += .05
                    has_attack_succeeded = np.array((outputs_perturbed.argmax(1)!=targets.numpy())[:,None,None,None],dtype=np.float32)

                if self.k:
                    # take gradients of g onto f every kth step
                    # BUGFIX: the schedule must follow the optimization step,
                    # not `i`, the stale leftover index of the per-sample
                    # bookkeeping loop above (always len(advx_final)-1 here).
                    if step % self.k == 0:
                        grad = grad_detector_proj
                    else:
                        grad = grad_classifier_proj
                else:
                    # follow the classifier gradient until the attack succeeds,
                    # then the detector gradient to evade detection
                    grad = grad_classifier_proj * (1-has_attack_succeeded) + grad_detector_proj * has_attack_succeeded

            # abort on numerical blow-up rather than propagating NaNs
            if np.any(np.isnan(grad.numpy())):
                print(np.mean(np.isnan(grad.numpy())))
                print("ABORT")
                break

            if self.target is not None:
                grad = -grad  # descend towards the target instead of ascending

            # l2 regularization
            if self.projection_norm == 'l2':
                grad_norms = torch.norm(grad.view(batch_size, -1), p=2, dim=1) + 1e-20
                grad = grad / grad_norms.view(batch_size, 1, 1, 1)
            # linf regularization
            elif self.projection_norm == 'linf':
                grad = torch.sign(grad)
            else:
                raise Exception('Incorrect Projection Norm')

            # gradient step, then project back into the eps-ball and valid pixel range
            adv_images = adv_images.detach() + self.alpha * grad
            delta = torch.clamp(adv_images - original_inputs, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(original_inputs + delta, min=self.img_min, max=self.img_max).detach()

        return torch.tensor(advx_final)

    def attack(self, inputs, targets, batch_size=30):
        """Attack ``inputs`` in mini-batches and concatenate the results.

        :param inputs: full set of images
        :param targets: labels aligned with ``inputs``
        :param batch_size: samples handed to attack_batch per call
        :return: tensor of adversarial examples in the same order as ``inputs``
        """
        adv_images = []
        number_batch = int(math.ceil(len(inputs) / batch_size))
        for index in range(number_batch):
            start = index * batch_size
            end = min((index + 1) * batch_size, len(inputs))
            batch_samples, batch_targets = inputs[start:end], targets[start:end]
            batch_adv_images = self.attack_batch(batch_samples, batch_targets)
            adv_images.append(batch_adv_images)
        return torch.cat(adv_images)