Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Add assert_is_*_link #285

Merged
merged 10 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chainercv/links/model/segnet/segnet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,6 @@ def predict(self, imgs):
dtype = score.dtype
score = resize(score, (H, W)).astype(dtype)

label = np.argmax(score, axis=0)
label = np.argmax(score, axis=0).astype(np.int32)
labels.append(label)
return labels
6 changes: 3 additions & 3 deletions chainercv/links/model/ssd/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def _suppress(self, raw_bbox, raw_score):
label.append(xp.array((l,) * len(bbox_l)))
score.append(score_l)

bbox = xp.vstack(bbox)
label = xp.hstack(label).astype(int)
score = xp.hstack(score)
bbox = xp.vstack(bbox).astype(np.float32)
label = xp.hstack(label).astype(np.int32)
score = xp.hstack(score).astype(np.float32)

return bbox, label, score

Expand Down
2 changes: 2 additions & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from chainercv.utils.iterator import unzip # NOQA
from chainercv.utils.testing import assert_is_bbox # NOQA
from chainercv.utils.testing import assert_is_detection_dataset # NOQA
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing import ConstantStubLink # NOQA
from chainercv.utils.testing import generate_random_bbox # NOQA
2 changes: 2 additions & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from chainercv.utils.testing.assertions import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA
from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox # NOQA
2 changes: 2 additions & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_dataset import assert_is_detection_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA
59 changes: 59 additions & 0 deletions chainercv/utils/testing/assertions/assert_is_detection_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def assert_is_detection_link(link, n_fg_class):
"""Checks if a link satisfies detection link APIs.

This function checks if a given link satisfies detection link APIs
or not.
If the link does not satifiy the APIs, this function raises an
:class:`AssertionError`.

Args:
link: A link to be checked.
n_fg_class (int): The number of foreground classes.

"""

imgs = [
np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32),
np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)]

result = link.predict(imgs)
assert len(result) == 3, \
'Link must return three elements: bboxes, labels and scores.'
bboxes, labels, scores = result

assert len(bboxes) == len(imgs), \
'The length of bboxes must be same as that of imgs.'
assert len(labels) == len(imgs), \
'The length of labels must be same as that of imgs.'
assert len(scores) == len(imgs), \
'The length of scores must be same as that of imgs.'

for bbox, label, score in six.moves.zip(bboxes, labels, scores):
assert_is_bbox(bbox)

assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert label.dtype == np.int32, \
'The type of label must be numpy.int32.'
assert label.shape[1:] == (), \
'The shape of label must be (*,).'
assert len(label) == len(bbox), \
'The length of label must be same as that of bbox.'
if len(label) > 0:
assert label.min() >= 0 and label.max() < n_fg_class, \
'The value of label must be in [0, n_fg_class - 1].'

assert isinstance(score, np.ndarray), \
'score must be a numpy.ndarray.'
assert score.dtype == np.float32, \
'The type of score must be numpy.float32.'
assert score.shape[1:] == (), \
'The shape of score must be (*,).'
assert len(score) == len(bbox), \
'The length of score must be same as that of bbox.'
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import six


def assert_is_semantic_segmentation_link(link, n_class):
"""Checks if a link satisfies semantic segmentation link APIs.

This function checks if a given link satisfies semantic segmentation link
APIs or not.
If the link does not satifiy the APIs, this function raises an
:class:`AssertionError`.

Args:
link: A link to be checked.
n_class (int): The number of classes including background.

"""

imgs = [
np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32),
np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)]

labels = link.predict(imgs)
assert len(labels) == len(imgs), \
'The length of labels must be same as that of imgs.'

for img, label in six.moves.zip(imgs, labels):
assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert label.dtype == np.int32, \
'The type of label must be numpy.int32.'
assert label.shape == img.shape[1:], \
'The shape of label must be (H, W).'
assert label.min() >= 0 and label.max() < n_class, \
'The value of label must be in [0, n_class - 1].'
8 changes: 8 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ assert_is_detection_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_detection_dataset

assert_is_detection_link
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_detection_link
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

autofunctions raises error.
Can you use autofunction?


assert_is_image
~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_image
Expand All @@ -71,6 +75,10 @@ assert_is_semantic_segmentation_dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_semantic_segmentation_dataset

assert_is_semantic_segmentation_link
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunctions:: assert_is_semantic_segmentation_link

ConstantStubLink
~~~~~~~~~~~~~~~~
.. autoclass:: ConstantStubLink
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from chainer import testing
from chainer.testing import attr

from chainercv.utils import assert_is_detection_link

from dummy_faster_rcnn import DummyFasterRCNN


Expand Down Expand Up @@ -58,39 +60,13 @@ def test_call_gpu(self):
self.link.to_gpu()
self.check_call()

def check_predict(self):
imgs = [
_random_array(np, (3, 640, 480)),
_random_array(np, (3, 320, 320))]

bboxes, labels, scores = self.link.predict(imgs)

self.assertEqual(len(bboxes), len(imgs))
self.assertEqual(len(labels), len(imgs))
self.assertEqual(len(scores), len(imgs))

for bbox, label, score in zip(bboxes, labels, scores):
self.assertIsInstance(bbox, np.ndarray)
self.assertEqual(bbox.dtype, np.float32)
self.assertEqual(bbox.ndim, 2)
self.assertLessEqual(bbox.shape[0], self.n_roi)
self.assertEqual(bbox.shape[1], 4)

self.assertIsInstance(label, np.ndarray)
self.assertEqual(label.dtype, np.int32)
self.assertEqual(label.shape, (bbox.shape[0],))

self.assertIsInstance(score, np.ndarray)
self.assertEqual(score.dtype, np.float32)
self.assertEqual(score.shape, (bbox.shape[0],))

def test_predict_cpu(self):
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self.check_predict()
assert_is_detection_link(self.link, self.n_class - 1)


@testing.parameterize(
Expand Down
21 changes: 3 additions & 18 deletions tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from chainer.testing import attr

from chainercv.links import SegNetBasic
from chainercv.utils import assert_is_semantic_segmentation_link


@testing.parameterize(
Expand Down Expand Up @@ -37,29 +38,13 @@ def test_call_gpu(self):
self.link.to_gpu()
self.check_call()

def check_predict(self):
hs = np.random.randint(128, 160, size=(2,))
ws = np.random.randint(128, 160, size=(2,))
imgs = [
np.random.uniform(size=(3, hs[0], ws[0])).astype(np.float32),
np.random.uniform(size=(3, hs[1], ws[1])).astype(np.float32),
]

labels = self.link.predict(imgs)

self.assertEqual(len(labels), 2)
for i in range(2):
self.assertIsInstance(labels[i], np.ndarray)
self.assertEqual(labels[i].shape, (hs[i], ws[i]))
self.assertEqual(labels[i].dtype, np.int64)

def test_predict_cpu(self):
self.check_predict()
assert_is_semantic_segmentation_link(self.link, self.n_class)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self.check_predict()
assert_is_semantic_segmentation_link(self.link, self.n_class)


testing.run_module(__name__, __file__)
30 changes: 3 additions & 27 deletions tests/links_tests/model_tests/ssd_tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chainercv.links.model.ssd import Multibox
from chainercv.links.model.ssd import SSD
from chainercv.utils import assert_is_detection_link


def _random_array(xp, shape):
Expand Down Expand Up @@ -154,38 +155,13 @@ def test_use_preset(self):
with self.assertRaises(ValueError):
self.link.use_preset('unknown')

def _check_predict(self):
imgs = [
_random_array(np, (3, 640, 480)),
_random_array(np, (3, 320, 320))]

bboxes, labels, scores = self.link.predict(imgs)

self.assertEqual(len(bboxes), len(imgs))
self.assertEqual(len(labels), len(imgs))
self.assertEqual(len(scores), len(imgs))

for bbox, label, score in zip(bboxes, labels, scores):
self.assertIsInstance(bbox, np.ndarray)
self.assertEqual(bbox.ndim, 2)
self.assertLessEqual(bbox.shape[0], self.n_bbox * self.n_fg_class)
self.assertEqual(bbox.shape[1], 4)

self.assertIsInstance(label, np.ndarray)
self.assertEqual(label.ndim, 1)
self.assertEqual(label.shape[0], bbox.shape[0])

self.assertIsInstance(score, np.ndarray)
self.assertEqual(score.ndim, 1)
self.assertEqual(score.shape[0], bbox.shape[0])

def test_predict_cpu(self):
self._check_predict()
assert_is_detection_link(self.link, self.n_fg_class)

@attr.gpu
def test_predict_gpu(self):
self.link.to_gpu()
self._check_predict()
assert_is_detection_link(self.link, self.n_fg_class)


testing.run_module(__name__, __file__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import unittest

import chainer
from chainer import testing

from chainercv.utils import assert_is_detection_link
from chainercv.utils import generate_random_bbox


class DetectionLink(chainer.Link):

def predict(self, imgs):
bboxes = list()
labels = list()
scores = list()

for img in imgs:
n_bbox = np.random.randint(0, 10)
bboxes.append(generate_random_bbox(
n_bbox, img.shape[1:], 4, 12))
labels.append(np.random.randint(
0, 20, size=n_bbox).astype(np.int32))
scores.append(np.random.uniform(
0, 1, size=n_bbox).astype(np.float32))

return bboxes, labels, scores


class InvalidPredictionSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidPredictionSizeLink, self).predict(imgs)
return bboxes[1:], labels[1:], scores[1:]


class InvalidLabelSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidLabelSizeLink, self).predict(imgs)
return bboxes, [label[1:] for label in labels], scores


class InvalidLabelValueLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidLabelValueLink, self).predict(imgs)
return bboxes, [label + 1000 for label in labels], scores


class InvalidScoreSizeLink(DetectionLink):

def predict(self, imgs):
bboxes, labels, scores = super(
InvalidScoreSizeLink, self).predict(imgs)
return bboxes, labels, [score[1:] for score in scores]


@testing.parameterize(
{'link': DetectionLink(), 'valid': True},
{'link': InvalidPredictionSizeLink(), 'valid': False},
{'link': InvalidLabelSizeLink(), 'valid': False},
{'link': InvalidLabelValueLink(), 'valid': False},
{'link': InvalidScoreSizeLink(), 'valid': False},
)
class TestAssertIsDetectionLink(unittest.TestCase):

def test_assert_is_detection_link(self):
if self.valid:
assert_is_detection_link(self.link, 20)
else:
with self.assertRaises(AssertionError):
assert_is_detection_link(self.link, 20)


testing.run_module(__name__, __file__)
Loading