Skip to content

Commit f85597f

Browse files
Muhammad Zaid HameedMuhammad Zaid Hameed
Muhammad Zaid Hameed
authored and
Muhammad Zaid Hameed
committed
style corrections
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
1 parent 1d65b5d commit f85597f

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

art/defences/trainer/adversarial_trainer_awp_pytorch.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
375375
return train_loss, train_acc, train_n
376376

377377
def _weight_perturbation(
378-
self, x_batch: np.ndarray, x_batch_pert: np.ndarray, y_batch: np.ndarray
378+
self, x_batch: "torch.Tensor", x_batch_pert: "torch.Tensor", y_batch: "torch.Tensor"
379379
) -> Dict[str, "torch.Tensor"]:
380380
"""
381381
Calculate wight perturbation for a batch of data.
@@ -416,15 +416,15 @@ def _weight_perturbation(
416416
"Incorrect mode provided for base adversarial training. 'mode' must be among 'PGD' and 'TRADES'."
417417
)
418418

419-
self._proxy_classifier._optimizer.zero_grad() # pylint: disable=W0212
419+
self._proxy_classifier._optimizer.zero_grad() # type: ignore # pylint: disable=W0212
420420
loss.backward()
421-
self._proxy_classifier._optimizer.step() # pylint: disable=W0212
421+
self._proxy_classifier._optimizer.step() # type: ignore # pylint: disable=W0212
422422

423423
params_dict_proxy, _ = self._calculate_model_params(self._proxy_classifier)
424424

425425
for name in list_keys:
426426
perturbation = params_dict_proxy[name]["param"] - params_dict[name]["param"]
427-
perturbation = perturbation.reshape(params_dict[name]["size"])
427+
perturbation = torch.reshape(perturbation, params_dict[name]["size"]) # type: ignore
428428
scale = params_dict[name]["norm"] / (perturbation.norm() + EPS)
429429
w_perturb[name] = scale * perturbation
430430

@@ -444,7 +444,7 @@ def _calculate_model_params(
444444

445445
import torch
446446

447-
params_dict = OrderedDict()
447+
params_dict = OrderedDict() # type: ignore
448448
list_params = []
449449
for name, param in p_classifier._model.state_dict().items(): # pylint: disable=W0212
450450
if len(param.size()) <= 1:

examples/adversarial_training_awp.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
from art.attacks.evasion import ProjectedGradientDescent
1919

2020
"""
21-
For this example we choose the PreActResNet18 model as used in the paper
21+
For this example we choose the PreActResNet18 model as used in the paper
2222
(https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf)
2323
The code for the model architecture has been adopted from
2424
https://github.com/csdongxian/AWP/blob/main/AT_AWP/preactresnet.py
2525
"""
2626

2727

2828
class PreActBlock(nn.Module):
29-
'''Pre-activation version of the BasicBlock.'''
29+
"""Pre-activation version of the BasicBlock."""
30+
3031
expansion = 1
3132

3233
def __init__(self, in_planes, planes, stride=1):
@@ -36,22 +37,23 @@ def __init__(self, in_planes, planes, stride=1):
3637
self.bn2 = nn.BatchNorm2d(planes)
3738
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
3839

39-
if stride != 1 or in_planes != self.expansion*planes:
40+
if stride != 1 or in_planes != self.expansion * planes:
4041
self.shortcut = nn.Sequential(
41-
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
42+
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
4243
)
4344

4445
def forward(self, x):
4546
out = F.relu(self.bn1(x))
46-
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
47+
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
4748
out = self.conv1(out)
4849
out = self.conv2(F.relu(self.bn2(out)))
4950
out += shortcut
5051
return out
5152

5253

5354
class PreActBottleneck(nn.Module):
54-
'''Pre-activation version of the original Bottleneck module.'''
55+
"""Pre-activation version of the original Bottleneck module."""
56+
5557
expansion = 4
5658

5759
def __init__(self, in_planes, planes, stride=1):
@@ -61,16 +63,16 @@ def __init__(self, in_planes, planes, stride=1):
6163
self.bn2 = nn.BatchNorm2d(planes)
6264
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
6365
self.bn3 = nn.BatchNorm2d(planes)
64-
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
66+
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
6567

66-
if stride != 1 or in_planes != self.expansion*planes:
68+
if stride != 1 or in_planes != self.expansion * planes:
6769
self.shortcut = nn.Sequential(
68-
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
70+
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
6971
)
7072

7173
def forward(self, x):
7274
out = F.relu(self.bn1(x))
73-
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
75+
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
7476
out = self.conv1(out)
7577
out = self.conv2(F.relu(self.bn2(out)))
7678
out = self.conv3(F.relu(self.bn3(out)))
@@ -89,10 +91,10 @@ def __init__(self, block, num_blocks, num_classes=10):
8991
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
9092
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
9193
self.bn = nn.BatchNorm2d(512 * block.expansion)
92-
self.linear = nn.Linear(512*block.expansion, num_classes)
94+
self.linear = nn.Linear(512 * block.expansion, num_classes)
9395

9496
def _make_layer(self, block, planes, num_blocks, stride):
95-
strides = [stride] + [1]*(num_blocks-1)
97+
strides = [stride] + [1] * (num_blocks - 1)
9698
layers = []
9799
for stride in strides:
98100
layers.append(block(self.in_planes, planes, stride))
@@ -113,7 +115,7 @@ def forward(self, x):
113115

114116

115117
def PreActResNet18(num_classes=10):
116-
return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
118+
return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes)
117119

118120

119121
class CIFAR10_dataset(Dataset):
@@ -202,7 +204,9 @@ def __len__(self):
202204
)
203205

204206
# Step 4: Create the trainer object - AdversarialTrainerAWPPyTorch
205-
trainer = AdversarialTrainerAWPPyTorch(classifier, proxy_classifier, attack, mode="PGD", gamma=gamma, beta=6.0, warmup=warmup)
207+
trainer = AdversarialTrainerAWPPyTorch(
208+
classifier, proxy_classifier, attack, mode="PGD", gamma=gamma, beta=6.0, warmup=warmup
209+
)
206210

207211

208212
# Build a Keras image augmentation object and wrap it in ART

0 commit comments

Comments
 (0)