Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2ddd96c

Browse files
committedSep 14, 2022
Fixed linting
1 parent dac6855 commit 2ddd96c

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed
 

‎art/defences/preprocessor/cutout/cutout.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
from typing import Optional, Tuple
2929

3030
import numpy as np
31-
from scipy.ndimage.filters import median_filter
3231

33-
from art.utils import CLIP_VALUES_TYPE
3432
from art.defences.preprocessor.preprocessor import Preprocessor
3533

3634
logger = logging.getLogger(__name__)
@@ -102,11 +100,11 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
102100
masks[i, :, bbx1:bbx2, bby1:bby2] = 0
103101
else:
104102
masks[i, bbx1:bbx2, bby1:bby2, :] = 0
105-
103+
106104
x_aug = x * masks
107105

108106
return x_aug, y
109107

110108
def _check_params(self) -> None:
111109
if self.length <= 0:
112-
raise ValueError('Bounding box length must be positive.')
110+
raise ValueError("Bounding box length must be positive.")

‎art/defences/preprocessor/cutout/cutout_pytorch.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636

3737
class CutoutPyTorch(PreprocessorPyTorch):
3838
"""
39-
Implement the Cutout data augmentation defense
39+
Implement the Cutout data augmentation defence
4040
"""
4141

42-
params = ['length', "channels_first"]
42+
params = ["length", "channels_first"]
4343

4444
def __init__(
4545
self,
@@ -98,22 +98,22 @@ def forward(
9898
masks = torch.ones(*x.shape)
9999
for i in range(n):
100100
# uniform sampling
101-
cy = torch.randint(h)
102-
cx = torch.randint(w)
103-
bby1 = torch.clamp(cy - self.length // 2, 0, h)
104-
bbx1 = torch.clamp(cx - self.length // 2, 0, w)
105-
bby2 = torch.clamp(cy + self.length // 2, 0, h)
106-
bbx2 = torch.clamp(cx + self.length // 2, 0, w)
101+
cy = np.random.randint(h)
102+
cx = np.random.randint(w)
103+
bby1 = np.clip(cy - self.length // 2, 0, h)
104+
bbx1 = np.clip(cx - self.length // 2, 0, w)
105+
bby2 = np.clip(cy + self.length // 2, 0, h)
106+
bbx2 = np.clip(cx + self.length // 2, 0, w)
107107

108108
if self.channels_first:
109109
masks[i, :, bbx1:bbx2, bby1:bby2] = 0
110110
else:
111111
masks[i, bbx1:bbx2, bby1:bby2, :] = 0
112-
112+
113113
x_aug = x * masks
114114

115115
return x_aug, y
116116

117117
def _check_params(self) -> None:
118118
if self.length <= 0:
119-
raise ValueError('Bounding box length must be positive.')
119+
raise ValueError("Bounding box length must be positive.")

0 commit comments

Comments
 (0)
Please sign in to comment.