30
30
31
31
@pytest .fixture ()
32
32
def get_pytorch_yolo (get_default_cifar10_subset ):
33
- """
34
- This class tests the PyTorchYolo object detector.
35
- """
36
33
import cv2
37
34
import torch
38
35
@@ -53,11 +50,8 @@ def __init__(self, model):
53
50
def forward (self , x , targets = None ):
54
51
if self .training :
55
52
outputs = self .model (x )
56
- # loss is averaged over a batch. Thus, for patch generation use batch_size = 1
57
53
loss , _ = compute_loss (outputs , targets , self .model )
58
-
59
54
loss_components = {"loss_total" : loss }
60
-
61
55
return loss_components
62
56
else :
63
57
return self .model (x )
@@ -157,7 +151,9 @@ def test_pytorch_predict(art_warning, get_pytorch_yolo):
157
151
try :
158
152
result = object_seeker .predict (x = x_test )
159
153
154
+ assert len (result ) == len (x_test )
160
155
assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
156
+ assert np .all (result [0 ]["scores" ] >= 0.3 )
161
157
162
158
except ARTTestException as e :
163
159
art_warning (e )
@@ -184,6 +180,7 @@ def test_pytorch_certify(art_warning, get_pytorch_yolo):
184
180
result = object_seeker .certify (x = x_test , patch_size = 0.01 , offset = 0.1 )
185
181
186
182
assert len (result ) == len (x_test )
183
+ assert np .any (result [0 ])
187
184
188
185
except ARTTestException as e :
189
186
art_warning (e )
0 commit comments