Skip to content

Commit 0c1bd00

Browse files
committedAug 15, 2023
object seeker unit tests
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
1 parent ebb7341 commit 0c1bd00

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed
 

‎tests/estimators/object_detection/test_object_seeker.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@
3030

3131
@pytest.fixture()
3232
def get_pytorch_yolo(get_default_cifar10_subset):
33-
"""
34-
This class tests the PyTorchYolo object detector.
35-
"""
3633
import cv2
3734
import torch
3835

@@ -53,11 +50,8 @@ def __init__(self, model):
5350
def forward(self, x, targets=None):
5451
if self.training:
5552
outputs = self.model(x)
56-
# loss is averaged over a batch. Thus, for patch generation use batch_size = 1
5753
loss, _ = compute_loss(outputs, targets, self.model)
58-
5954
loss_components = {"loss_total": loss}
60-
6155
return loss_components
6256
else:
6357
return self.model(x)
@@ -157,7 +151,9 @@ def test_pytorch_predict(art_warning, get_pytorch_yolo):
157151
try:
158152
result = object_seeker.predict(x=x_test)
159153

154+
assert len(result) == len(x_test)
160155
assert list(result[0].keys()) == ["boxes", "labels", "scores"]
156+
assert np.all(result[0]["scores"] >= 0.3)
161157

162158
except ARTTestException as e:
163159
art_warning(e)
@@ -184,6 +180,7 @@ def test_pytorch_certify(art_warning, get_pytorch_yolo):
184180
result = object_seeker.certify(x=x_test, patch_size=0.01, offset=0.1)
185181

186182
assert len(result) == len(x_test)
183+
assert np.any(result[0])
187184

188185
except ARTTestException as e:
189186
art_warning(e)

0 commit comments

Comments
 (0)