Error in **attack_adversarial_patch_pytorch_yolo.ipynb** #2148

Closed
Louquinze opened this issue May 15, 2023 · 3 comments · Fixed by #2169

@Louquinze

Louquinze commented May 15, 2023

Describe the bug
I am trying to run the example notebook attack_adversarial_patch_pytorch_yolo.ipynb. In cell 7 I get a KeyError.

Adversarial Patch PyTorch:   0%|                                                                                            | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/ad_patch_torch.py", line 185, in <module>
    patch, patch_mask = ap.generate(x=x, y=target)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 606, in generate
    _ = self._train_step(images=images, target=target, mask=None)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 188, in _train_step
    loss = self._loss(images, target, mask)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 251, in _loss
    loss = self.estimator.compute_loss(x=patched_input, y=target)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 632, in compute_loss
    loss_components, _ = self._get_losses(x=x, y=y)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 380, in _get_losses
    x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 309, in _preprocess_and_convert_inputs
    if y is not None and isinstance(y[0]["boxes"], np.ndarray):
KeyError: 0

I was also checking the content of y:

{'boxes': tensor([[[  0.00000,   0.00000,   8.17627,   8.57283],
         [  0.00000,   0.00000,  27.70321,   8.76418],
         [  1.62518,   0.00000,  34.54520,   9.07174],
         ...,
         [495.75159, 549.26038, 636.07825, 640.00000],
         [539.20776, 543.38226, 640.00000, 640.00000],
         [557.95374, 531.07214, 640.00000, 640.00000]]]), 'labels': tensor([[0, 0, 0,  ..., 0, 0, 0]]), 'scores': tensor([[2.62016e-05, 2.61472e-05, 1.82523e-05,  ..., 3.00497e-06, 6.19944e-06, 4.10037e-06]])}

So it looks like y isn't a list or array but a dictionary, and therefore it has no key 0.
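The check `y[0]["boxes"]` in `_preprocess_and_convert_inputs` indexes `y` as a sequence with one dict per image, whereas the attack passed a single dict whose values carry the batch dimension first. A minimal sketch of converting between the two layouts, assuming every value in the dict is batched along its first axis. `batched_dict_to_list` is a hypothetical helper, not part of ART, and this is not necessarily how the issue was fixed upstream:

```python
from typing import Any, Dict, List


def batched_dict_to_list(y: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split one dict of batched values into one dict per image.

    Hypothetical helper (not part of ART). Assumes every value in ``y``
    (e.g. ``boxes``, ``labels``, ``scores``) has the batch dimension first,
    so ``value[i]`` selects the entries for image ``i``. Works for NumPy
    arrays, PyTorch tensors, or nested Python lists.
    """
    batch_size = len(y["boxes"])
    return [{key: value[i] for key, value in y.items()} for i in range(batch_size)]


# Toy target shaped like the dict printed above (batch of 1 image, 2 boxes).
y = {
    "boxes": [[[0.0, 0.0, 8.18, 8.57], [1.63, 0.0, 34.55, 9.07]]],
    "labels": [[0, 0]],
    "scores": [[2.6e-05, 1.8e-05]],
}
targets = batched_dict_to_list(y)
print(targets[0]["boxes"])  # per-image access, as in y[0]["boxes"], now works
```

Splitting on the first axis keeps the helper agnostic to the array type, since NumPy arrays, tensors and lists all support `len()` and integer indexing.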

When I change the condition to:

if y is not None and isinstance(y["boxes"], np.ndarray):
    ....

I get the following error:

Traceback (most recent call last):
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/ad_patch_torch.py", line 185, in <module>
    patch, patch_mask = ap.generate(x=x, y=target)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 606, in generate
    _ = self._train_step(images=images, target=target, mask=None)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 188, in _train_step
    loss = self._loss(images, target, mask)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py", line 251, in _loss
    loss = self.estimator.compute_loss(x=patched_input, y=target)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 634, in compute_loss
    loss_components, _ = self._get_losses(x=x, y=y)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 382, in _get_losses
    x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)
  File "/home/lukas/PycharmProjects/ExperimentalAttack/adversarial-robustness-toolbox/art/estimators/object_detection/pytorch_yolo.py", line 325, in _preprocess_and_convert_inputs
    x_tensor.requires_grad = True
RuntimeError: you can only change requires_grad flags of leaf variables.

To Reproduce
Steps to reproduce the behavior:
Simply run the Jupyter notebook.

Expected behavior
An adversarial patch is created.

System information (please complete the following information):

@beat-buesser
Collaborator

Hi @Louquinze, thank you very much for using ART and for the detailed investigation! We'll try to reproduce the issue as soon as possible and get back to you.

@beat-buesser
Collaborator

@kieranfraser Are you able to reproduce this issue?

@kieranfraser
Collaborator

Hi @Louquinze, @beat-buesser,

Yes, I was able to reproduce this issue. Due to recent updates to PyTorchYolo, the targets passed to the estimator in AdversarialPatchPyTorch were in the wrong format, and gradients of the input tensors were not being set. I have created a PR addressing this issue; it contains a notebook demonstrating patch generation for YOLO with the AdversarialPatchPyTorch and RobustDPatch classes.
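The `RuntimeError` in the traceback above comes from PyTorch's rule that `requires_grad` can only be assigned on leaf tensors. An image with a patch pasted in is the output of an operation, so it is a non-leaf tensor that autograd already tracks. A minimal sketch of a guard, assuming inputs arrive as PyTorch tensors; `mark_input_for_grad` is a hypothetical name, and this is not necessarily what the linked PR does:

```python
import torch


def mark_input_for_grad(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: ensure gradients with respect to ``x`` are tracked.

    Only leaf tensors accept a new ``requires_grad`` flag. A non-leaf tensor is
    the result of an operation on tensors that require grad, so autograd
    already tracks it and it is returned unchanged.
    """
    if x.is_leaf:
        x.requires_grad_(True)
    return x


# A trainable patch pasted into an image yields a non-leaf tensor.
patch = torch.rand(3, 8, 8, requires_grad=True)
patched = torch.zeros(3, 16, 16)
patched[:, :8, :8] = patch  # in-place copy records a grad_fn

try:
    patched.requires_grad = True  # the assignment that fails in the traceback
    assignment_failed = False
except RuntimeError as err:
    assignment_failed = True
    print(err)

x = mark_input_for_grad(patched)
x.sum().backward()  # gradients flow back to the patch
print(patch.grad.shape)
```

Checking `is_leaf` first avoids the error without detaching the patched image, which would cut the gradient path back to the patch.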

Whilst debugging this, I identified another possible issue related to the loss computation: the pytorchyolo and yolov5 libraries do not include a relatively recent fix for division by zero when calculating alpha. I'm investigating this further. In the meantime, you should be able to generate patches with the referenced notebook, @Louquinze. Let us know if you run into any issues with it.
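Assuming the "alpha" mentioned here is the trade-off weight of the CIoU box loss (the comment does not say so explicitly), the problem can be illustrated with its definition, alpha = v / ((1 - IoU) + v), which becomes 0 / 0 when a predicted box exactly matches its target (IoU = 1 and v = 0). A plain-Python sketch of the unguarded and eps-guarded forms; the function names and the `eps` value are illustrative assumptions, and the libraries operate on tensors, where 0 / 0 yields NaN rather than raising:

```python
import math


def ciou_v(w_gt: float, h_gt: float, w_pred: float, h_pred: float) -> float:
    """Aspect-ratio consistency term v of the CIoU loss."""
    return (4 / math.pi ** 2) * (math.atan(w_gt / h_gt) - math.atan(w_pred / h_pred)) ** 2


def ciou_alpha_unguarded(iou: float, v: float) -> float:
    # alpha = v / ((1 - IoU) + v): a 0 / 0 division when IoU == 1 and v == 0.
    return v / ((1 - iou) + v)


def ciou_alpha_guarded(iou: float, v: float, eps: float = 1e-7) -> float:
    # A small eps in the denominator keeps alpha finite for a perfect match.
    return v / ((1 - iou) + v + eps)


v_match = ciou_v(10, 20, 10, 20)  # identical boxes, so v is zero
print(ciou_alpha_guarded(1.0, v_match))
try:
    ciou_alpha_unguarded(1.0, v_match)
except ZeroDivisionError:
    print("unguarded alpha divides zero by zero")
```

For boxes that do not match exactly, the eps changes alpha only negligibly, so the guard affects the degenerate case alone.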
