@@ -72,6 +72,9 @@ def __init__(
         :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
         :param verbose: Show progress bars.
         """
+        import torch  # lgtm [py/repeated-import]
+        from torch.autograd import Function
+
         super().__init__(
             device_type=device_type,
             is_fitted=True,
@@ -83,50 +86,89 @@ def __init__(
         self.verbose = verbose
         self._check_params()

+        class RandomCutout(Function):  # pylint: disable=W0223
+            """
+            Function running Preprocessor.
+            """
+
+            @staticmethod
+            def forward(ctx, input):  # pylint: disable=W0622,W0221
+                ctx.save_for_backward(input)
+                n, _, height, width = input.shape
+                masks = torch.ones_like(input)
+
+                # generate a random bounding box per image
+                for idx in trange(n, desc="Cutout", disable=not self.verbose):
+                    # uniform sampling
+                    center_x = torch.randint(0, height, (1,))
+                    center_y = torch.randint(0, width, (1,))
+                    bby1 = torch.clamp(center_y - self.length // 2, 0, height)
+                    bbx1 = torch.clamp(center_x - self.length // 2, 0, width)
+                    bby2 = torch.clamp(center_y + self.length // 2, 0, height)
+                    bbx2 = torch.clamp(center_x + self.length // 2, 0, width)
+
+                    # zero out the bounding box
+                    masks[idx, :, bbx1:bbx2, bby1:bby2] = 0  # type: ignore
+
+                return input * masks
+
+            @staticmethod
+            def backward(ctx, grad_output):  # pylint: disable=W0221
+                return grad_output
+
+        self._random_cutout = RandomCutout
+
     def forward(
         self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
     ) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
         """
         Apply Cutout data augmentation to sample `x`.

-        :param x: Sample to compress with shape of `NCHW` or `NHWC`. The `x` values are expected to be in
-                  the data range [0, 1] or [0, 255].
+        :param x: Sample to cut out with shape of `NCHW`, `NHWC`, `NCFHW` or `NFHWC`.
+                  `x` values are expected to be in the data range [0, 1] or [0, 255].
         :param y: Labels of the sample `x`. This function does not affect them in any way.
         :return: Data augmented sample.
         """
-        import torch  # lgtm [py/repeated-import]
-
         x_ndim = len(x.shape)

+        # NHWC/NCFHW/NFHWC --> NCHW.
         if x_ndim == 4:
             if self.channels_first:
-                # NCHW
-                n, _, height, width = x.shape
+                x_nchw = x
             else:
-                # NHWC
-                n, height, width, _ = x.shape
+                # NHWC --> NCHW
+                x_nchw = x.permute(0, 3, 1, 2)
+        elif x_ndim == 5:
+            if self.channels_first:
+                # NCFHW --> NFCHW --> NCHW
+                nb_clips, channels, clip_size, height, width = x.shape
+                x_nchw = x.permute(0, 2, 1, 3, 4).reshape(nb_clips * clip_size, channels, height, width)
+            else:
+                # NFHWC --> NHWC --> NCHW
+                nb_clips, clip_size, height, width, channels = x.shape
+                x_nchw = x.reshape(nb_clips * clip_size, height, width, channels).permute(0, 3, 1, 2)
         else:
-            raise ValueError("Unrecognized input dimension. Cutout can only be applied to image data.")
-
-        masks = torch.ones_like(x)
+            raise ValueError("Unrecognized input dimension. Cutout can only be applied to image and video data.")

-        # generate a random bounding box per image
-        for idx in trange(n, desc="Cutout", disable=not self.verbose):
-            # uniform sampling
-            center_x = torch.randint(0, height, (1,))
-            center_y = torch.randint(0, width, (1,))
-            bby1 = torch.clamp(center_y - self.length // 2, 0, height)
-            bbx1 = torch.clamp(center_x - self.length // 2, 0, width)
-            bby2 = torch.clamp(center_y + self.length // 2, 0, height)
-            bbx2 = torch.clamp(center_x + self.length // 2, 0, width)
+        # apply random cutout
+        x_nchw = self._random_cutout.apply(x_nchw)

-            # zero out the bounding box
+        # NHWC/NCFHW/NFHWC <-- NCHW.
+        if x_ndim == 4:
             if self.channels_first:
-                masks[idx, :, bbx1:bbx2, bby1:bby2] = 0  # type: ignore
+                x_aug = x_nchw
             else:
-                masks[idx, bbx1:bbx2, bby1:bby2, :] = 0  # type: ignore
-
-        x_aug = x * masks
+                # NHWC <-- NCHW
+                x_aug = x_nchw.permute(0, 2, 3, 1)
+        elif x_ndim == 5:  # lgtm [py/redundant-comparison]
+            if self.channels_first:
+                # NCFHW <-- NFCHW <-- NCHW
+                x_nfchw = x_nchw.reshape(nb_clips, clip_size, channels, height, width)
+                x_aug = x_nfchw.permute(0, 2, 1, 3, 4)
+            else:
+                # NFHWC <-- NHWC <-- NCHW
+                x_nhwc = x_nchw.permute(0, 2, 3, 1)
+                x_aug = x_nhwc.reshape(nb_clips, clip_size, height, width, channels)

         return x_aug, y
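
The diff routes the masking through a custom `torch.autograd.Function` whose `backward` returns the incoming gradient unchanged, so gradients flow through the cut-out region as if the mask were the identity. Below is a minimal, self-contained sketch of that pattern outside the ART class; the name `_CutoutFunction`, the explicit `length` argument, and the plain `range` loop are illustrative choices, not part of this commit.

import torch
from torch.autograd import Function


class _CutoutFunction(Function):
    """Zero out one random square per image (NCHW); pass gradients straight through."""

    @staticmethod
    def forward(ctx, x, length):
        n, _, height, width = x.shape
        masks = torch.ones_like(x)
        for idx in range(n):
            # sample the square's center uniformly over the image
            center_h = torch.randint(0, height, (1,))
            center_w = torch.randint(0, width, (1,))
            h1 = torch.clamp(center_h - length // 2, 0, height)
            h2 = torch.clamp(center_h + length // 2, 0, height)
            w1 = torch.clamp(center_w - length // 2, 0, width)
            w2 = torch.clamp(center_w + length // 2, 0, width)
            masks[idx, :, h1:h2, w1:w2] = 0
        return x * masks

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward input: identity for x, None for the int length
        return grad_output, None


x = torch.rand(4, 3, 32, 32, requires_grad=True)
out = _CutoutFunction.apply(x, 8)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 3, 32, 32]); the gradient is all ones here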
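
For end-to-end use, the preprocessor can be called on image batches as before and, with this change, on video batches as well. The sketch below is a hedged usage example: the import path `art.defences.preprocessor.CutoutPyTorch` and the constructor arguments shown are assumptions based on the docstring above, not something this diff confirms.

import torch
from art.defences.preprocessor import CutoutPyTorch  # assumed export location

# assumed constructor arguments; `length` is the side of the square to cut out
cutout = CutoutPyTorch(length=8, channels_first=False, verbose=False)

# image batch, NHWC, values in [0, 1]
images = torch.rand(4, 32, 32, 3)
images_aug, _ = cutout.forward(images)

# video batch, NFHWC: 2 clips of 16 frames each
videos = torch.rand(2, 16, 32, 32, 3)
videos_aug, _ = cutout.forward(videos)

print(images_aug.shape, videos_aug.shape)  # input shapes are preserved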