Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Add assert_is_*_link #285

Merged
merged 10 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chainercv/links/model/segnet/segnet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,6 @@ def predict(self, imgs):
dtype = score.dtype
score = resize(score, (H, W)).astype(dtype)

label = np.argmax(score, axis=0)
label = np.argmax(score, axis=0).astype(np.int32)
labels.append(label)
return labels
6 changes: 3 additions & 3 deletions chainercv/links/model/ssd/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def _suppress(self, raw_bbox, raw_score):
label.append(xp.array((l,) * len(bbox_l)))
score.append(score_l)

bbox = xp.vstack(bbox)
label = xp.hstack(label).astype(int)
score = xp.hstack(score)
bbox = xp.vstack(bbox).astype(np.float32)
label = xp.hstack(label).astype(np.int32)
score = xp.hstack(score).astype(np.float32)

return bbox, label, score

Expand Down
2 changes: 2 additions & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from chainercv.utils.iterator import unzip # NOQA
from chainercv.utils.testing import assert_is_bbox # NOQA
from chainercv.utils.testing import assert_is_detection_dataset # NOQA
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
from chainercv.utils.testing import generate_random_bbox # NOQA
2 changes: 2 additions & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from chainercv.utils.testing.assertions import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox # NOQA
2 changes: 2 additions & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_dataset import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
59 changes: 59 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_detection_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def assert_is_detection_link(link, n_fg_class):
"""Checks if a link satisfies detection link APIs.

This function checks if a given link satisfies detection link APIs
or not.
If the link does not satifiy the APIs, this function raises an
:class:`AssertionError`.

Args:
link: A link to be checked.
n_fg_class (int): The number of foreground classes.

"""

imgs = [
np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32),
np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)]

result = link.predict(imgs)
assert len(result) == 3, \
'Link must return three elements: bboxes, labels and scores.'
bboxes, labels, scores = result

assert len(bboxes) == len(imgs), \
'The length of bboxes must be same as that of imgs.'
assert len(labels) == len(imgs), \
'The length of labels must be same as that of imgs.'
assert len(scores) == len(imgs), \
'The length of scores must be same as that of imgs.'

for bbox, label, score in six.moves.zip(bboxes, labels, scores):
assert_is_bbox(bbox)

assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert label.dtype == np.int32, \
'The type of label must be numpy.int32.'
assert label.shape[1:] == (), \
'The shape of label must be (*,).'
assert len(label) == len(bbox), \
'The length of label must be same as that of bbox.'
if len(label) > 0:
assert label.min() >= 0 and label.max() < n_fg_class, \
'The value of label must be in [0, n_fg_class - 1].'

assert isinstance(score, np.ndarray), \
'score must be a numpy.ndarray.'
assert score.dtype == np.float32, \
'The type of score must be numpy.float32.'
assert score.shape[1:] == (), \
'The shape of score must be (*,).'
assert len(score) == len(bbox), \
'The length of score must be same as that of bbox.'
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import six


def assert_is_semantic_segmentation_link(link, n_class):
"""Checks if a link satisfies semantic segmentation link APIs.

This function checks if a given link satisfies semantic segmentation link
APIs or not.
If the link does not satifiy the APIs, this function raises an
:class:`AssertionError`.

Args:
link: A link to be checked.
n_class (int): The number of classes including background.

"""

imgs = [
np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32),
np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)]

labels = link.predict(imgs)
assert len(labels) == len(imgs), \
'The length of labels must be same as that of imgs.'

for img, label in six.moves.zip(imgs, labels):
assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert label.dtype == np.int32, \
'The type of label must be numpy.int32.'
assert label.shape == img.shape[1:], \
'The shape of label must be (H, W).'
assert label.min() >= 0 and label.max() < n_class, \
'The value of label must be in [0, n_class - 1].'
16 changes: 12 additions & 4 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,27 @@ Testing Utilities

assert_is_bbox
~~~~~~~~~~~~~~
.. autofunctions:: assert_is_bbox
.. autofunction:: assert_is_bbox

assert_is_detection_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_detection_dataset
.. autofunction:: assert_is_detection_dataset

assert_is_detection_link
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_detection_link

assert_is_image
~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_image
.. autofunction:: assert_is_image

assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_semantic_segmentation_dataset
.. autofunction:: assert_is_semantic_segmentation_dataset

assert_is_semantic_segmentation_link
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: assert_is_semantic_segmentation_link

ConstantStubLink
~~~~~~~~~~~~~~~~
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from chainer import testing
from chainer.testing import attr

from chainercv.utils import assert_is_detection_link

from dummy_faster_rcnn import DummyFasterRCNN


Expand Down Expand Up @@ -58,39 +60,13 @@ def test_call_gpu(self):
self.link.to_gpu()
self.check_call()

def check_predict(self):
imgs = [
_random_array(np, (3, 640, 480)),
_random_array(np, (3, 320, 320))]

bboxes, labels, scores = self.link.predict(imgs)

self.assertEqual(len(bboxes), len(imgs))
self.assertEqual(len(labels), len(imgs))
self.assertEqual(len(scores), len(imgs))

for bbox, label, score in zip(bboxes, labels, scores):
self.assertIsInstance(bbox, np.ndarray)
self.assertEqual(bbox.dtype, np.float32)
self.assertEqual(bbox.ndim, 2)
self.assertLessEqual(bbox.shape[0], self.n_roi)
self.assertEqual(bbox.shape[1], 4)

self.assertIsInstance(label, np.ndarray)
self.assertEqual(label.dtype, np.int32)
self.assertEqual(label.shape, (bbox.shape[0],))

self.assertIsInstance(score, np.ndarray)
self.assertEqual(score.dtype, np.float32)
self.assertEqual(score.shape, (bbox.shape[0],))

def test_predict_cpu(self):
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)


@testing.parameterize(
Expand Down
21 changes: 3 additions & 18 deletions tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from chainer.testing import attr

from chainercv.links import SegNetBasic
from chainercv.utils import assert_is_semantic_segmentation_link


@testing.parameterize(
Expand Down Expand Up @@ -37,29 +38,13 @@ def test_call_gpu(self):
self.link.to_gpu()
self.check_call()

def check_predict(self):
hs = np.random.randint(128, 160, size=(2,))
ws = np.random.randint(128, 160, size=(2,))
imgs = [
np.random.uniform(size=(3, hs[0], ws[0])).astype(np.float32),
np.random.uniform(size=(3, hs[1], ws[1])).astype(np.float32),
]

labels = self.link.predict(imgs)

self.assertEqual(len(labels), 2)
for i in range(2):
self.assertIsInstance(labels[i], np.ndarray)
self.assertEqual(labels[i].shape, (hs[i], ws[i]))
self.assertEqual(labels[i].dtype, np.int64)

def test_predict_cpu(self):
self.check_predict()
assert_is_semantic_segmentation_link(self.link, self.n_class)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self.check_predict()
assert_is_semantic_segmentation_link(self.link, self.n_class)


testing.run_module(__name__, __file__)
30 changes: 3 additions & 27 deletions tests/links_tests/model_tests/ssd_tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chainercv.links.model.ssd import Multibox
from chainercv.links.model.ssd import SSD
from chainercv.utils import assert_is_detection_link


def _random_array(xp, shape):
Expand Down Expand Up @@ -154,38 +155,13 @@ def test_use_preset(self):
with self.assertRaises(ValueError):
self.link.use_preset('unknown')

def _check_predict(self):
imgs = [
_random_array(np, (3, 640, 480)),
_random_array(np, (3, 320, 320))]

bboxes, labels, scores = self.link.predict(imgs)

self.assertEqual(len(bboxes), len(imgs))
self.assertEqual(len(labels), len(imgs))
self.assertEqual(len(scores), len(imgs))

for bbox, label, score in zip(bboxes, labels, scores):
self.assertIsInstance(bbox, np.ndarray)
self.assertEqual(bbox.ndim, 2)
self.assertLessEqual(bbox.shape[0], self.n_bbox * self.n_fg_class)
self.assertEqual(bbox.shape[1], 4)

self.assertIsInstance(label, np.ndarray)
self.assertEqual(label.ndim, 1)
self.assertEqual(label.shape[0], bbox.shape[0])

self.assertIsInstance(score, np.ndarray)
self.assertEqual(score.ndim, 1)
self.assertEqual(score.shape[0], bbox.shape[0])

def test_predict_cpu(self):
self._check_predict()
assert_is_detection_link(self.link, self.n_fg_class)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self._check_predict()
assert_is_detection_link(self.link, self.n_fg_class)


testing.run_module(__name__, __file__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import unittest

import chainer
from chainer import testing

from chainercv.utils import assert_is_detection_link
from chainercv.utils import generate_random_bbox


class DetectionLink(chainer.Link):

def predict(self, imgs):
bboxes = list()
labels = list()
scores = list()

for img in imgs:
n_bbox = np.random.randint(0, 10)
bboxes.append(generate_random_bbox(
n_bbox, img.shape[1:], 4, 12))
labels.append(np.random.randint(
0, 20, size=n_bbox).astype(np.int32))
scores.append(np.random.uniform(
0, 1, size=n_bbox).astype(np.float32))

return bboxes, labels, scores


class InvalidPredictionSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidPredictionSizeLink, self).predict(imgs)
return bboxes[1:], labels[1:], scores[1:]


class InvalidLabelSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidLabelSizeLink, self).predict(imgs)
return bboxes, [label[1:] for label in labels], scores


class InvalidLabelValueLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidLabelValueLink, self).predict(imgs)
return bboxes, [label + 1000 for label in labels], scores


class InvalidScoreSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidScoreSizeLink, self).predict(imgs)
return bboxes, labels, [score[1:] for score in scores]


@testing.parameterize(
{'link': DetectionLink(), 'valid': True},
{'link': InvalidPredictionSizeLink(), 'valid': False},
{'link': InvalidLabelSizeLink(), 'valid': False},
{'link': InvalidLabelValueLink(), 'valid': False},
{'link': InvalidScoreSizeLink(), 'valid': False},
)
class TestAssertIsDetectionLink(unittest.TestCase):

def test_assert_is_detection_link(self):
if self.valid:
assert_is_detection_link(self.link, 20)
else:
with self.assertRaises(AssertionError):
assert_is_detection_link(self.link, 20)


testing.run_module(__name__, __file__)
Loading