ObjectSeeker Defense Implementation #2246

Merged: 33 commits, merged Sep 8, 2023
Changes from 1 commit (33 commits in total):
567019c  implement iou and nms  (f4str, Jun 30, 2023)
5352e9f  template object seeker classes  (f4str, Jul 5, 2023)
78eee10  objectseeker masked predictions  (f4str, Jul 14, 2023)
4a487df  objectseeker masked prediction pruning  (f4str, Jul 14, 2023)
57fa85e  objectseeker implementation  (f4str, Jul 19, 2023)
29d6db6  move most objectseeker methods to the abstract base class  (f4str, Aug 1, 2023)
e4573db  fix style checks  (f4str, Aug 1, 2023)
216b8b4  add docstrings to methods  (f4str, Aug 1, 2023)
cba4709  objectseeker certify partial implementation  (f4str, Aug 9, 2023)
baa305f  complete objectseeker certify implementation  (f4str, Aug 10, 2023)
b0616a0  fix style checks  (f4str, Aug 10, 2023)
9e0ace0  add object seeker unit tests  (f4str, Aug 14, 2023)
854a0f6  convert randomized smoothing unit tests to pytest  (f4str, Jul 6, 2023)
1e776de  pytorch smooth mix implementation  (f4str, Jul 6, 2023)
4eed259  smooth mix unit tests  (f4str, Jul 6, 2023)
18598aa  cleanup smoothmix extraneous parameters  (f4str, Jul 6, 2023)
3d8e17b  pytorch macer implementation  (f4str, Jul 7, 2023)
793baf1  pytorch macer unit tests  (f4str, Jul 7, 2023)
2d40553  add learning rate schedulers to tensorflow estimators  (f4str, Jul 7, 2023)
3cef324  implement tensorflow macer  (f4str, Jul 7, 2023)
e51b823  unit tests for tensorflow macer  (f4str, Jul 7, 2023)
1057a08  pytorch smooth adv implementation  (f4str, Jul 7, 2023)
ade9c45  unit tests for pytorch smooth adv  (f4str, Jul 10, 2023)
0536ddb  tensorflow smooth adv implementation  (f4str, Jul 10, 2023)
7163110  linting and style checks  (f4str, Jul 10, 2023)
2204400  update randomized smoothing progress bars  (f4str, Jul 10, 2023)
ebb7341  address review comments  (f4str, Aug 4, 2023)
0c1bd00  object seeker unit tests  (f4str, Aug 15, 2023)
215558c  fix style checks  (f4str, Aug 15, 2023)
5b411f3  Merge branch 'dev_1.16.0' into object-seeker  (f4str, Aug 23, 2023)
a892488  address review comments  (f4str, Aug 29, 2023)
bf124b7  Merge branch 'dev_1.16.0' into object-seeker  (beat-buesser, Sep 4, 2023)
46af3c2  Merge branch 'dev_1.16.0' into object-seeker  (beat-buesser, Sep 8, 2023)
add learning rate schedulers to tensorflow estimators
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
f4str committed Aug 15, 2023
commit 2d40553c01ac5b51bc344d90473f59f4b929025a
@@ -177,7 +177,9 @@ def fit( # pylint: disable=W0221
         dataset = TensorDataset(x_tensor, y_tensor)
         dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

-        m = torch.distributions.normal.Normal(torch.tensor([0.0], device=self.device), torch.tensor([1.0], device=self.device))
+        m = torch.distributions.normal.Normal(
+            torch.tensor([0.0], device=self.device), torch.tensor([1.0], device=self.device)
+        )

         # Start training
         for _ in tqdm(range(nb_epochs)):
art/estimators/certification/randomized_smoothing/tensorflow.py (12 changes: 9 additions & 3 deletions)
@@ -137,8 +137,9 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
-                       TensorFlow and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports
+                       "scheduler" which is an optional function that will be called at the end of every
+                       epoch to adjust the learning rate.
         """
         import tensorflow as tf

@@ -165,6 +166,8 @@ def train_step(model, images, labels):
         else:
             train_step = self._train_step

+        scheduler = kwargs.get("scheduler")
+
         y = check_and_transform_label_format(y, nb_classes=self.nb_classes)

         # Apply preprocessing
@@ -176,12 +179,15 @@ def train_step(model, images, labels):

         train_ds = tf.data.Dataset.from_tensor_slices((x_preprocessed, y_preprocessed)).shuffle(10000).batch(batch_size)

-        for _ in tqdm(range(nb_epochs)):
+        for epoch in tqdm(range(nb_epochs)):
             for images, labels in train_ds:
                 # Add random noise for randomized smoothing
                 images += tf.random.normal(shape=images.shape, mean=0.0, stddev=self.scale)
                 train_step(self.model, images, labels)

+            if scheduler is not None:
+                scheduler(epoch)
+
     def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:  # type: ignore
         """
         Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.
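
Below is a minimal usage sketch of the new "scheduler" keyword on the randomized smoothing estimator. It is not part of this diff: the model, optimizer, training data, and the step_decay function are illustrative assumptions; the only behaviour taken from the change itself is that fit() reads kwargs.get("scheduler") and calls the function with the epoch index at the end of every epoch.

import numpy as np
import tensorflow as tf

from art.estimators.certification.randomized_smoothing import TensorFlowV2RandomizedSmoothing

# Illustrative model, loss, and optimizer (assumed names and values, not from the PR).
model = tf.keras.Sequential(
    [tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(10)]
)
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

classifier = TensorFlowV2RandomizedSmoothing(
    model=model,
    nb_classes=10,
    input_shape=(28, 28, 1),
    loss_object=loss_object,
    optimizer=optimizer,
    scale=0.25,
)

def step_decay(epoch: int) -> None:
    # Halve the learning rate every 10 epochs; fit() calls this once per epoch.
    if epoch > 0 and epoch % 10 == 0:
        optimizer.learning_rate = float(optimizer.learning_rate) * 0.5

# Dummy training data for illustration only.
x_train = np.random.rand(64, 28, 28, 1).astype(np.float32)
y_train = np.eye(10)[np.random.randint(0, 10, size=64)].astype(np.float32)

# The callable is picked up via kwargs.get("scheduler") inside fit().
classifier.fit(x_train, y_train, batch_size=32, nb_epochs=30, scheduler=step_decay)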
art/estimators/classification/tensorflow.py (24 changes: 18 additions & 6 deletions)
@@ -957,8 +957,9 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
-                       TensorFlow and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports
+                       "scheduler" which is an optional function that will be called at the end of every
+                       epoch to adjust the learning rate.
         """
         import tensorflow as tf

@@ -985,6 +986,8 @@ def train_step(model, images, labels):
         else:
             train_step = self._train_step

+        scheduler = kwargs.get("scheduler")
+
         y = check_and_transform_label_format(y, nb_classes=self.nb_classes)

         # Apply preprocessing
@@ -996,19 +999,23 @@ def train_step(model, images, labels):

         train_ds = tf.data.Dataset.from_tensor_slices((x_preprocessed, y_preprocessed)).shuffle(10000).batch(batch_size)

-        for _ in range(nb_epochs):
+        for epoch in range(nb_epochs):
             for images, labels in train_ds:
                 train_step(self.model, images, labels)

+            if scheduler is not None:
+                scheduler(epoch)
+
     def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in TensorFlow, it will.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
-                       TensorFlow and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports
+                       "scheduler" which is an optional function that will be called at the end of every
+                       epoch to adjust the learning rate.
         """
         import tensorflow as tf
         from art.data_generators import TensorFlowV2DataGenerator

@@ -1036,6 +1043,8 @@ def train_step(model, images, labels):
         else:
             train_step = self._train_step

+        scheduler = kwargs.get("scheduler")
+
         # Train directly in TensorFlow
         from art.preprocessing.standardisation_mean_std.tensorflow import StandardisationMeanStdTensorFlow

@@ -1050,11 +1059,14 @@ def train_step(model, images, labels):
                 == (0, 1)
            )
        ):
-            for _ in range(nb_epochs):
+            for epoch in range(nb_epochs):
                 for i_batch, o_batch in generator.iterator:
                     if self._reduce_labels:
                         o_batch = tf.math.argmax(o_batch, axis=1)
                     train_step(self._model, i_batch, o_batch)
+
+                if scheduler is not None:
+                    scheduler(epoch)
         else:
             # Fit a generic data generator through the API
             super().fit_generator(generator, nb_epochs=nb_epochs)
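
The same hook applies to the plain TensorFlow classifier: fit() and, when training natively in TensorFlow, fit_generator() both look up kwargs.get("scheduler") and invoke it once per epoch. A hedged sketch follows, reusing the assumed model, loss_object, optimizer, and training data from the previous example; exponential_decay and the constructor arguments shown are illustrative, not taken from the PR.

from art.estimators.classification import TensorFlowV2Classifier

classifier = TensorFlowV2Classifier(
    model=model,
    nb_classes=10,
    input_shape=(28, 28, 1),
    loss_object=loss_object,
    optimizer=optimizer,
)

def exponential_decay(epoch: int) -> None:
    # Recompute the learning rate from the epoch index and write it back.
    optimizer.learning_rate = 0.1 * (0.95 ** epoch)

classifier.fit(x_train, y_train, batch_size=32, nb_epochs=20, scheduler=exponential_decay)

# fit_generator() accepts the same kwarg but only applies it on the native
# TensorFlow branch shown in the diff above; the generic super().fit_generator()
# fallback does not receive it.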