
SmoothMix Addition #1705

Closed
wants to merge 11 commits
moved smoothmix training to fit method of smoothmix estimator and created a separate estimator for smoothmix

Signed-off-by: Ajay Sarjoo <ajay.sarjoo@outlook.com>
asarj committed May 19, 2022
commit fe79acbde278db56e616707af7a002c6ff913fa0
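For reviewers, a minimal usage sketch of the refactored API introduced in this commit: SmoothMix training now lives in the standalone PyTorchSmoothMix estimator rather than behind a train_method switch on the randomized smoothing classifier. The toy model, optimizer, scheduler, and hyperparameter values below are placeholders, not taken from this PR; the constructor arguments follow the new signature in the diff.

import torch
from art.estimators.certification.smoothmix import PyTorchSmoothMix

# Placeholder model and training utilities (hypothetical, for illustration only)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)

estimator = PyTorchSmoothMix(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    input_shape=(1, 28, 28),
    nb_classes=10,
    optimizer=optimizer,
    scheduler=scheduler,
    scale=0.25,  # std of the Gaussian smoothing noise
    eta=1.0,     # relative strength of the mixup loss
)
# estimator.fit(x_train, y_train, batch_size=128, nb_epochs=10)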
117 changes: 7 additions & 110 deletions art/estimators/certification/randomized_smoothing/pytorch.py
@@ -30,13 +30,11 @@
from art.config import ART_NUMPY_DTYPE
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin
from art.estimators.certification.randomized_smoothing.smoothmix.train_smoothmix import fit_pytorch
from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation


if TYPE_CHECKING:
# pylint: disable=C0412
import torch

from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
from art.defences.preprocessor import Preprocessor
from art.defences.postprocessor import Postprocessor
@@ -52,25 +50,7 @@ class PyTorchRandomizedSmoothing(RandomizedSmoothingMixin, PyTorchClassifier):
| Paper link: https://arxiv.org/abs/1902.02918
"""

estimator_params = PyTorchClassifier.estimator_params + [
"sample_size",
"scale",
"alpha",
"num_noise_vec",
"train_multi_noise",
"attack_type",
"epsilon",
"num_steps",
"warmup",
"lbd",
"gamma",
"beta",
"gauss_num",
"eta",
"mix_step",
"maxnorm_s",
"maxnorm"
]
estimator_params = PyTorchClassifier.estimator_params + ["sample_size", "scale", "alpha"]

def __init__(
self,
@@ -79,7 +59,6 @@ def __init__(
input_shape: Tuple[int, ...],
nb_classes: int,
optimizer: Optional["torch.optim.Optimizer"] = None, # type: ignore
scheduler: Optional["torch.optim.lr_scheduler"] = None, # type: ignore
channels_first: bool = True,
clip_values: Optional["CLIP_VALUES_TYPE"] = None,
preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
@@ -89,21 +68,6 @@ def __init__(
sample_size: int = 32,
scale: float = 0.1,
alpha: float = 0.001,
num_noise_vec: int = 1,
train_multi_noise: bool = False,
attack_type: str = "PGD",
epsilon: float = 64.0,
num_steps: int = 10,
warmup: int = 1,
lbd: float = 12.0,
gamma: float = 8.0,
beta: float = 16.0,
gauss_num: int = 16,
eta: float = 1.0,
mix_step: int = 0,
maxnorm_s: Optional[float] = None,
maxnorm: Optional[float] = None,
**kwargs
):
"""
Create a randomized smoothing classifier.
@@ -115,7 +79,6 @@ def __init__(
:param input_shape: The shape of one input instance.
:param nb_classes: The number of classes of the model.
:param optimizer: The optimizer used to train the classifier.
:param scheduler: The learning rate scheduler used to train the classifier.
:param channels_first: Set channels first or last.
:param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
maximum values allowed for features. If floats are provided, these will be used as the range of all
@@ -130,21 +93,6 @@ def __init__(
:param sample_size: Number of samples for smoothing.
:param scale: Standard deviation of Gaussian noise added.
:param alpha: The failure probability of smoothing.
:param num_noise_vec: Number of noise vectors.
:param train_multi_noise: Whether to use all the noise samples during training.
:param attack_type: The type of attack to use.
:param epsilon: Maximum perturbation that the attacker can introduce.
:param num_steps: Number of attack updates.
:param warmup: Number of initial epochs over which epsilon is gradually increased up to its full value.
:param lbd: Weight of the robustness loss in MACER.
:param gamma: Factor by which the learning rate is multiplied.
:param beta: The inverse temperature in MACER.
:param gauss_num: Number of Gaussian samples per input.
:param eta: Hyperparameter controlling the relative strength of the mixup loss in SmoothMix.
:param mix_step: Determines which sample to use for the clean side in SmoothMix.
:param maxnorm_s: Initial value of alpha * mix_step.
:param maxnorm: Initial value of alpha * mix_step for adversarial examples.
"""
super().__init__(
model=model,
@@ -161,37 +109,15 @@ def __init__(
sample_size=sample_size,
scale=scale,
alpha=alpha,
num_noise_vec=num_noise_vec,
train_multi_noise=train_multi_noise,
attack_type=attack_type,
epsilon=epsilon,
num_steps=num_steps,
warmup=warmup,
lbd=lbd,
gamma=gamma,
beta=beta,
gauss_num=gauss_num,
eta=eta,
mix_step=mix_step,
maxnorm_s=maxnorm_s,
maxnorm=maxnorm,
**kwargs
)
self.scheduler = scheduler

def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray:
x = x.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.predict(self, x=x, batch_size=batch_size, training_mode=training_mode, **kwargs)

def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
if "train_method" in kwargs:
if kwargs.get("train_method") == "smoothmix":
return fit_pytorch(self, x, y, batch_size, nb_epochs, **kwargs)

g_a = GaussianAugmentation(sigma=self.scale, augmentation=False)
x_rs, _ = g_a(x)
x_rs = x_rs.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.fit(self, x_rs, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
x = x.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

def fit( # pylint: disable=W0221
self,
@@ -247,7 +173,8 @@ def loss_gradient( # type: ignore
:type sampling: `bool`
:return: Array of gradients of the same shape as `x`.
"""
import torch # lgtm [py/repeated-import]
import torch # lgtm [py/repeated-import]

sampling = kwargs.get("sampling")

if sampling:
@@ -272,8 +199,7 @@ def loss_gradient( # type: ignore
inputs_noise_t = inputs_repeat_t + noise
if self.clip_values is not None:
inputs_noise_t = inputs_noise_t.clamp(
torch.tensor(self.clip_values[0]),
torch.tensor(self.clip_values[1]),
torch.tensor(self.clip_values[0]), torch.tensor(self.clip_values[1]),
)

model_outputs = self._model(inputs_noise_t)[-1]
@@ -315,32 +241,3 @@ def class_gradient(
`(batch_size, 1, input_shape)` when `label` parameter is specified.
"""
raise NotImplementedError

# pylint: disable=R0201
def _requires_grad_(
self, model: "torch.nn.Module", requires_grad: bool
) -> None:
"""
Enables or disables gradients for the given model.

:param model: The model to enable or disable gradients for.
:param requires_grad: Boolean to enable or disable gradients for all layers in the model.
"""
for param in model.parameters():
param.requires_grad_(requires_grad)

# pylint: disable=R0201
def _get_minibatches(
self, x: np.ndarray, y: np.ndarray, num_batches: int
):
"""
Generate batches of the training data and target values.

:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
(nb_samples,).
:param num_batches: The number of batches to generate.
"""
batch_size = len(x) // num_batches
for i in range(num_batches):
yield x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size]
@@ -31,6 +31,7 @@
from tqdm.auto import tqdm

from art.config import ART_NUMPY_DTYPE
from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation

logger = logging.getLogger(__name__)

@@ -43,68 +44,18 @@ class RandomizedSmoothingMixin(ABC):
| Paper link: https://arxiv.org/abs/1902.02918
"""

def __init__(
self,
sample_size: int,
*args,
scale: float = 0.1,
alpha: float = 0.001,
num_noise_vec: int = 1,
train_multi_noise: bool = False,
attack_type: str = "PGD",
epsilon: float = 64.0,
num_steps: int = 10,
warmup: int = 1,
lbd: float = 12.0,
gamma: float = 8.0,
beta: float = 16.0,
gauss_num: int = 16,
eta: float = 1.0,
mix_step: int = 0,
maxnorm_s: Optional[float] = None,
maxnorm: Optional[float] = None,
**kwargs,
) -> None:
def __init__(self, sample_size: int, *args, scale: float = 0.1, alpha: float = 0.001, **kwargs) -> None:
"""
Create a randomized smoothing wrapper.

:param sample_size: Number of samples for smoothing.
:param scale: Standard deviation of Gaussian noise added.
:param alpha: The failure probability of smoothing.
:param num_noise_vec: Number of noise vectors.
:param train_multi_noise: Whether to use all the noise samples during training.
:param attack_type: The type of attack to use.
:param epsilon: Maximum perturbation that the attacker can introduce.
:param num_steps: Number of attack updates.
:param warmup: Number of initial epochs over which epsilon is gradually increased up to its full value.
:param lbd: Weight of the robustness loss in MACER.
:param gamma: Factor by which the learning rate is multiplied.
:param beta: The inverse temperature in MACER.
:param gauss_num: Number of Gaussian samples per input.
:param eta: Hyperparameter controlling the relative strength of the mixup loss in SmoothMix.
:param mix_step: Determines which sample to use for the clean side in SmoothMix.
:param maxnorm_s: Initial value of alpha * mix_step.
:param maxnorm: Initial value of alpha * mix_step for adversarial examples.
"""
super().__init__(*args, **kwargs) # type: ignore
self.sample_size = sample_size
self.scale = scale
self.alpha = alpha
self.num_noise_vec = num_noise_vec
self.train_multi_noise = train_multi_noise
self.attack_type = attack_type
self.epsilon = epsilon
self.num_steps = num_steps
self.warmup = warmup
self.lbd = lbd
self.gamma = gamma
self.beta = beta
self.gauss_num = gauss_num
self.eta = eta
self.mix_step = mix_step
self.maxnorm_s = maxnorm_s
self.maxnorm = maxnorm

def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray:
"""
@@ -183,7 +134,9 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it has no effect.
"""
self._fit_classifier(x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
g_a = GaussianAugmentation(sigma=self.scale, augmentation=False)
x_rs, _ = g_a(x)
self._fit_classifier(x_rs, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
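With this change the mixin, rather than each framework subclass, applies the Gaussian noise before fitting. With augmentation=False, GaussianAugmentation perturbs every sample in place instead of appending noisy copies; conceptually it reduces to the following numpy sketch (a stand-in for illustration, not the ART implementation):

import numpy as np

def gaussian_perturb(x: np.ndarray, sigma: float) -> np.ndarray:
    # Add zero-mean Gaussian noise with standard deviation sigma to each sample;
    # the number of samples is unchanged (no augmentation).
    return (x + np.random.normal(0.0, sigma, size=x.shape)).astype(x.dtype)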

def certify(self, x: np.ndarray, n: int, batch_size: int = 32) -> Tuple[np.ndarray, np.ndarray]:
"""

This file was deleted.

5 changes: 5 additions & 0 deletions art/estimators/certification/smoothmix/__init__.py
@@ -0,0 +1,5 @@
"""
SmoothMix estimators
"""
from art.estimators.certification.smoothmix.smoothmix import SmoothMixMixin
from art.estimators.certification.smoothmix.pytorch import PyTorchSmoothMix
330 changes: 330 additions & 0 deletions art/estimators/certification/smoothmix/pytorch.py
@@ -0,0 +1,330 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements SmoothMix training for smoothed classifiers.
| Paper link: https://arxiv.org/pdf/2111.09277.pdf
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np

from art.config import ART_NUMPY_DTYPE
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.certification.smoothmix.smoothmix import SmoothMixMixin
from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation


if TYPE_CHECKING:
# pylint: disable=C0412
import torch
from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
from art.defences.preprocessor import Preprocessor
from art.defences.postprocessor import Postprocessor

logger = logging.getLogger(__name__)


class PyTorchSmoothMix(SmoothMixMixin, PyTorchClassifier):
"""
Implementation of SmoothMix training, as introduced in Jeong et al. (2021)
| Paper link: https://arxiv.org/pdf/2111.09277.pdf
"""

estimator_params = PyTorchClassifier.estimator_params + [
"sample_size",
"scale",
"alpha",
"num_noise_vec",
"attack_type",
"epsilon",
"num_steps",
"warmup",
"lbd",
"gamma",
"beta",
"gauss_num",
"eta",
"mix_step",
"maxnorm_s",
"maxnorm",
]

def __init__(
self,
model: "torch.nn.Module",
loss: "torch.nn.modules.loss._Loss",
input_shape: Tuple[int, ...],
nb_classes: int,
optimizer: Optional["torch.optim.Optimizer"] = None, # type: ignore
scheduler: Optional["torch.optim.lr_scheduler"] = None, # type: ignore
channels_first: bool = True,
clip_values: Optional["CLIP_VALUES_TYPE"] = None,
preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
device_type: str = "gpu",
sample_size: int = 32,
scale: float = 0.1,
alpha: float = 0.001,
num_noise_vec: int = 1,
attack_type: str = "PGD",
epsilon: float = 64.0,
num_steps: int = 10,
warmup: int = 1,
lbd: float = 12.0,
gamma: float = 8.0,
beta: float = 16.0,
gauss_num: int = 16,
eta: float = 1.0,
mix_step: int = 0,
maxnorm_s: Optional[float] = None,
maxnorm: Optional[float] = None,
**kwargs
):
"""
Create a SmoothMix classifier.
:param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
output should be preferred where possible to ensure attack efficiency.
:param loss: The loss function for which to compute gradients for training. The target label must be raw
categorical, i.e. not converted to one-hot encoding.
:param input_shape: The shape of one input instance.
:param nb_classes: The number of classes of the model.
:param optimizer: The optimizer used to train the classifier.
:param scheduler: The learning rate scheduler used to train the classifier.
:param channels_first: Set channels first or last.
:param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
maximum values allowed for features. If floats are provided, these will be used as the range of all
features. If arrays are provided, each value will be considered the bound for a feature, thus
the shape of clip values needs to match the total number of features.
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
used for data preprocessing. The first value will be subtracted from the input. The input will then
be divided by the second one.
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
:param sample_size: Number of samples for smoothing.
:param scale: Standard deviation of Gaussian noise added.
:param alpha: The failure probability of smoothing.
:param num_noise_vec: Number of noise vectors.
:param attack_type: The type of attack to use.
:param epsilon: Maximum perturbation that the attacker can introduce.
:param num_steps: Number of attack updates.
:param warmup: Number of initial epochs over which epsilon is gradually increased up to its full value.
:param lbd: Weight of the robustness loss in MACER.
:param gamma: Factor by which the learning rate is multiplied.
:param beta: The inverse temperature in MACER.
:param gauss_num: Number of Gaussian samples per input.
:param eta: Hyperparameter controlling the relative strength of the mixup loss in SmoothMix.
:param mix_step: Determines which sample to use for the clean side in SmoothMix.
:param maxnorm_s: Initial value of alpha * mix_step.
:param maxnorm: Initial value of alpha * mix_step for adversarial examples.
"""
super().__init__(
model=model,
loss=loss,
input_shape=input_shape,
nb_classes=nb_classes,
optimizer=optimizer,
channels_first=channels_first,
clip_values=clip_values,
preprocessing_defences=preprocessing_defences,
postprocessing_defences=postprocessing_defences,
preprocessing=preprocessing,
device_type=device_type,
sample_size=sample_size,
scale=scale,
alpha=alpha,
num_noise_vec=num_noise_vec,
attack_type=attack_type,
epsilon=epsilon,
num_steps=num_steps,
warmup=warmup,
lbd=lbd,
gamma=gamma,
beta=beta,
gauss_num=gauss_num,
eta=eta,
mix_step=mix_step,
maxnorm_s=maxnorm_s,
maxnorm=maxnorm,
**kwargs
)
self.scheduler = scheduler

def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray:
x = x.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.predict(self, x=x, batch_size=batch_size, training_mode=training_mode, **kwargs)

def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
x = x.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

def fit( # pylint: disable=W0221
self,
x: np.ndarray,
y: np.ndarray,
batch_size: int = 128,
nb_epochs: int = 10,
training_mode: bool = True,
**kwargs
):
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
(nb_samples,).
:param batch_size: Batch size.
:param nb_epochs: Number of epochs to use for training.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it has no effect.
:type kwargs: `dict`
:return: `None`
"""

# Set model mode
self._model.train(mode=training_mode)

SmoothMixMixin.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: # type: ignore
"""
Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.
:param x: Input samples.
:param batch_size: Batch size.
:param is_abstain: True if function will abstain from prediction and return 0s. Default: True
:type is_abstain: `boolean`
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
"""
return SmoothMixMixin.predict(self, x, batch_size=batch_size, training_mode=False, **kwargs)

def loss_gradient( # type: ignore
self, x: np.ndarray, y: np.ndarray, training_mode: bool = False, **kwargs
) -> np.ndarray:
"""
Compute the gradient of the loss function w.r.t. `x`.
:param x: Sample input with shape as expected by the model.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
(nb_samples,).
:param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
:param sampling: True if loss gradients should be determined with Monte Carlo sampling.
:type sampling: `bool`
:return: Array of gradients of the same shape as `x`.
"""
import torch # lgtm [py/repeated-import]

sampling = kwargs.get("sampling")

if sampling:
self._model.train(mode=training_mode)

# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=False)

# Check label shape
if self._reduce_labels:
y_preprocessed = np.argmax(y_preprocessed, axis=1)

# Convert the inputs to Tensors
inputs_t = torch.from_numpy(x_preprocessed).to(self._device)
inputs_t.requires_grad = True

# Convert the labels to Tensors
labels_t = torch.from_numpy(y_preprocessed).to(self._device)
inputs_repeat_t = inputs_t.repeat_interleave(self.sample_size, 0)

noise = torch.randn_like(inputs_repeat_t, device=self._device) * self.scale
inputs_noise_t = inputs_repeat_t + noise
if self.clip_values is not None:
# Tensor.clamp is not in-place; assign the result so the clipping takes effect
inputs_noise_t = inputs_noise_t.clamp(
torch.tensor(self.clip_values[0]), torch.tensor(self.clip_values[1]),
)

model_outputs = self._model(inputs_noise_t)[-1]
softmax = torch.nn.functional.softmax(model_outputs, dim=1)
average_softmax = (
softmax.reshape(-1, self.sample_size, model_outputs.shape[-1]).mean(1, keepdim=True).squeeze(1)
)
log_softmax = torch.log(average_softmax.clamp(min=1e-20))
loss = torch.nn.functional.nll_loss(log_softmax, labels_t)

# Clean gradients
self._model.zero_grad()

# Compute gradients
loss.backward()
gradients = inputs_t.grad.cpu().numpy().copy() # type: ignore
gradients = self._apply_preprocessing_gradient(x, gradients)
assert gradients.shape == x.shape

else:
gradients = PyTorchClassifier.loss_gradient(self, x=x, y=y, training_mode=training_mode, **kwargs)

return gradients

def class_gradient(
self, x: np.ndarray, label: Union[int, List[int], None] = None, training_mode: bool = False, **kwargs
) -> np.ndarray:
"""
Compute per-class derivatives of the given classifier w.r.t. `x` of original classifier.
:param x: Sample input with shape as expected by the model.
:param label: Index of a specific per-class derivative. If an integer is provided, the gradient of that class
output is computed for all samples. If multiple values as provided, the first dimension should
match the batch size of `x`, and each value will be used as target for its corresponding sample in
`x`. If `None`, then gradients for all classes will be computed for each sample.
:param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
:return: Array of gradients of input features w.r.t. each class in the form
`(batch_size, nb_classes, input_shape)` when computing for all classes, otherwise shape becomes
`(batch_size, 1, input_shape)` when `label` parameter is specified.
"""
raise NotImplementedError

# pylint: disable=R0201
def _requires_grad_(self, model: "torch.nn.Module", requires_grad: bool) -> None:
"""
Enables or disables gradients for the given model.
:param model: The model to enable or disable gradients for.
:param requires_grad: Boolean to enable or disable gradients for all layers in the model.
"""
for param in model.parameters():
param.requires_grad_(requires_grad)

# pylint: disable=R0201
def _get_minibatches(self, x: np.ndarray, y: np.ndarray, num_batches: int):
"""
Generate batches of the training data and target values.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
(nb_samples,).
:param num_batches: The number of batches to generate.
"""
batch_size = len(x) // num_batches
for i in range(num_batches):
yield x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size]
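A standalone sketch of the generator above. Note that the integer division silently drops the remainder samples whenever len(x) is not divisible by num_batches; for example, with 10 samples and 3 batches the last sample is never yielded:

import numpy as np

def get_minibatches(x: np.ndarray, y: np.ndarray, num_batches: int):
    # Mirrors _get_minibatches: fixed-size slices, remainder dropped
    batch_size = len(x) // num_batches
    for i in range(num_batches):
        yield x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size]

for x_batch, _ in get_minibatches(np.arange(10), np.arange(10), num_batches=3):
    print(x_batch)  # [0 1 2], [3 4 5], [6 7 8]; element 9 is dropped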
@@ -37,13 +37,14 @@ class SmoothMixPGD(object):
"""
Author's original implementation of a Smooth PGD attacker
"""

def __init__(
self,
steps: int,
mix_step: int,
alpha: float = 1.0,
maxnorm_s: Optional[float] = None,
maxnorm: Optional[float] = None
maxnorm: Optional[float] = None,
) -> None:
"""
Creates a Smooth PGD attacker
@@ -64,13 +65,7 @@ def __init__(
else:
self.maxnorm_s = maxnorm_s

def attack(
self,
model: torch.nn.Module,
inputs: torch.Tensor,
labels: torch.Tensor,
noises: List[Any]
):
def attack(self, model: torch.nn.Module, inputs: torch.Tensor, labels: torch.Tensor, noises: List[Any]):
"""
Attacks the model with the given inputs
@@ -80,7 +75,7 @@ def attack(
:param noises: The noise applied to each input in the attack
"""
if inputs.min() < 0 or inputs.max() > 1:
raise ValueError('Input values should be in the [0, 1] range.')
raise ValueError("Input values should be in the [0, 1] range.")

def _batch_l2norm(x: torch.Tensor) -> torch.Tensor:
"""
@@ -124,7 +119,7 @@ def _project(x: torch.Tensor, x_0: torch.Tensor, maxnorm: Optional[float] = None

logsoftmax = torch.log(avg_softmax.clamp(min=1e-20))

loss = F.nll_loss(logsoftmax, labels, reduction='sum')
loss = F.nll_loss(logsoftmax, labels, reduction="sum")

grad = torch.autograd.grad(loss, [adv])[0]
grad_norm = _batch_l2norm(grad).view(-1, 1, 1, 1)
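The body of _batch_l2norm is elided by the hunk above; a plausible implementation of the per-sample L2 norm it computes (an assumption based on the usual flatten-then-norm pattern, not copied from this diff):

import torch

def _batch_l2norm(x: torch.Tensor) -> torch.Tensor:
    # Flatten all non-batch dimensions and take the L2 norm of each sample
    return x.reshape(x.size(0), -1).norm(p=2, dim=1)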
423 changes: 423 additions & 0 deletions art/estimators/certification/smoothmix/smoothmix.py

Large diffs are not rendered by default.

@@ -60,16 +60,18 @@ def fit_pytorch(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs:
mix_step=self.mix_step,
alpha=self.alpha,
maxnorm=self.maxnorm,
maxnorm_s=self.maxnorm_s
maxnorm_s=self.maxnorm_s,
)

if self.optimizer is None: # pragma: no cover
raise ValueError(f"An optimizer is needed to train the model, but none for provided: {self.optimizer}")
if self.scheduler is None: # pragma: no cover
raise ValueError(f"A scheduler is needed to train the model, but none for provided: {self.scheduler}")
if attacker is None:
raise ValueError(f"A attacker is needed to smooth adversarially train the model, but \
none for provided: {self.attack_type}")
raise ValueError(
f"An attacker is needed to adversarially train the smoothed model, "
f"but none was provided: {self.attack_type}"
)

num_batch = int(np.ceil(len(x) / float(batch_size)))
ind = np.arange(len(x))
@@ -96,9 +98,7 @@ def fit_pytorch(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs:
).to(self.device)

# pylint: disable=W0212
mini_batches = self._get_minibatches(
input_batch, output_batch, self.num_noise_vec
)
mini_batches = self._get_minibatches(input_batch, output_batch, self.num_noise_vec)

for (inputs, targets) in mini_batches:
inputs, targets = inputs.to(self.device), targets.to(self.device)
@@ -108,12 +108,7 @@ def fit_pytorch(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs:
# Attack and find adversarial examples
self._requires_grad_(self.model, False) # pylint: disable=W0212
self.model.eval()
inputs, inputs_adv = attacker.attack(
self.model,
inputs=inputs,
labels=targets,
noises=noises
)
inputs, inputs_adv = attacker.attack(self.model, inputs=inputs, labels=targets, noises=noises)
self.model.train()
self._requires_grad_(self.model, True) # pylint: disable=W0212

@@ -127,7 +122,7 @@ def fit_pytorch(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs:
if isinstance(clean_avg_sm, float):
clean_avg_sm = torch.tensor(clean_avg_sm)  # torch.Tensor(float) is invalid; torch.tensor wraps the scalar

loss_xent = F.cross_entropy(logits_c, targets_c, reduction='none')
loss_xent = F.cross_entropy(logits_c, targets_c, reduction="none")

in_mix, targets_mix = _mixup_data(inputs, inputs_adv, clean_avg_sm, self.nb_classes)

@@ -139,7 +134,7 @@ def fit_pytorch(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs:
ind_correct = (top1_idx[:, 0] == targets).float()
ind_correct = ind_correct.repeat(self.num_noise_vec)

loss_mixup = F.kl_div(logits_mix_c, targets_mix_c, reduction='none').sum(1)
loss_mixup = F.kl_div(logits_mix_c, targets_mix_c, reduction="none").sum(1)
loss = loss_xent.mean() + self.eta * warmup_v * (ind_correct * loss_mixup).mean()

# compute gradient and do SGD step
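To summarize the objective assembled above: the final loss is the clean cross-entropy plus the mixup KL term, gated per sample by whether the clean prediction was correct and scaled by eta and the warm-up factor. A self-contained sketch with dummy tensors (shapes and values are hypothetical):

import torch
import torch.nn.functional as F

batch, num_classes = 8, 10
eta, warmup_v = 1.0, 0.5  # mixup weight and warm-up factor

logits_c = torch.randn(batch, num_classes)                                # logits on clean noisy inputs
targets_c = torch.randint(0, num_classes, (batch,))                       # class indices
logits_mix_c = torch.log_softmax(torch.randn(batch, num_classes), dim=1)  # log-probs on mixed inputs
targets_mix_c = torch.softmax(torch.randn(batch, num_classes), dim=1)     # mixed soft labels
ind_correct = torch.ones(batch)                                           # 1 where the clean prediction was correct

loss_xent = F.cross_entropy(logits_c, targets_c, reduction="none")
loss_mixup = F.kl_div(logits_mix_c, targets_mix_c, reduction="none").sum(1)
loss = loss_xent.mean() + eta * warmup_v * (ind_correct * loss_mixup).mean()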