-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdetect_safetynet_transferability.py
115 lines (94 loc) · 4.48 KB
/
detect_safetynet_transferability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Detect adv/clean from the hidden feature"""
from __future__ import absolute_import
from __future__ import print_function
import os
import argparse
from datasets import get_data
from models import get_model
import numpy as np
import sklearn.metrics
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold, train_test_split
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor
DATASETS = ['dr', 'cxr', 'derm']
ATTACKS = ['fgsm', 'bim', 'jsma', 'cw-l2', 'clean']
TEST_SIZE = {'dr': 0.2, 'cxr': 0.2, 'derm': 0.2}
def balance_data(X, y, sub_sample=False): # assume that the positive samples is less than negative ones
idx_pos, = np.where(y>0)
idx_neg, = np.where(y<1)
if sub_sample:
idx_neg = np.random.choice(idx_neg, len(idx_pos))
else: # over sample
idx_pos = np.random.choice(idx_pos, len(idx_neg))
idx_resample = np.concatenate([idx_neg, idx_pos])
return X[idx_resample], y[idx_resample]
def detect(args):
assert args.dataset in ['mnist', 'cifar-10', 'svhn', 'dr', 'cxr', 'derm'], \
"Dataset parameter must be either 'mnist', 'cifar-10', 'svhn', 'dr', 'cxr', or 'derm'"
# assert args.attack in ['fgsm', 'bim', 'jsma', 'deepfool', 'pgd', 'ead', 'cw-l2', 'cw-lid', 'cw-li',
# 'fgsm_bb', 'bim_bb', 'jsma_bb', 'deepfool_bb', 'pgd_bb', 'ead_bb', 'cw-l2_bb', 'cw-lid_bb', 'cw-li_bb',], \
# "Attack parameter must be either 'fgsm', 'bim', 'jsma', 'deepfool', " \
# "'pgd', 'ead', 'cw-l2', 'cw-lid'"
attacks = ['fgsm', 'bim', 'pgd']
# load feature/label data and balance data
# clean examples
_, _, cX_test, cy_test = get_data(args.dataset, onehot=False, load_feat='clean') # clean feat
cX_test, cy_test = balance_data(cX_test, cy_test) # balance over positive/negative examples
# adv examples
aX, ay = [], []
for i, attack in enumerate(attacks):
_, _, aX_test, ay_test = get_data(args.dataset, onehot=False, load_feat=attack) # attack feat
aX_test, ay_test = balance_data(aX_test, ay_test) # balance over positive/negative examples
aX.append(aX_test)
ay.append(ay_test)
auc_matrix = np.zeros([len(attacks), len(attacks)])
for aid_train in range(len(attacks)): # attack ID for training
print('training on attack:', attacks[aid_train])
X = np.concatenate([cX_test, aX[aid_train]])
y = np.concatenate([np.zeros(len(cX_test)), np.ones(len(aX[aid_train]))])
threshold = np.mean(X, axis=0, keepdims=True)
# quantization
X = (X > threshold).astype('float32')
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE[args.dataset])
clf = SVC(gamma=2.8, probability=True) #RandomForestClassifier(30) # SVC(gamma=2.8, probability=True)
clf.fit(X_train, y_train)
auc = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
print('AUC:', auc)
auc_matrix[aid_train, aid_train] = auc
for aid_test in range(len(attacks)):
if aid_train == aid_test:
continue
X_test = np.concatenate([cX_test, aX[aid_test]])
y_test = np.concatenate([np.zeros(len(cX_test)), np.ones(len(aX[aid_test]))])
X_test = (X_test > threshold).astype('float32')
auc = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
print('AUC:', auc, '(test on', attacks[aid_test]+')')
auc_matrix[aid_train, aid_test] = auc
print(auc_matrix)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'-d', '--dataset',
help="Dataset to use",
required=True, type=str
)
# parser.add_argument(
# '-a', '--attack',
# help="Attack to use train the discriminator; either 'fgsm', 'bim-a', 'bim-b', 'jsma', 'cw-l2'",
# required=True, type=str
# )
args = parser.parse_args()
detect(args)
# with ThreadPoolExecutor(24) as e:
# for ds in ['dr', 'cxr', 'derm']:
# for atk in ['pgd']: # 'fgsm', 'bim', 'deepfool', 'pgd']:
# argv = ['-d', ds, '-a', atk]
# print('\n$> ', argv)
# args = parser.parse_args(argv)
# e.submit(detect, args) #detect(args)
# e.result()