From bce64c36ffb65789fa8c3faa7de73f8524e7eade Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 15:27:58 +0900 Subject: [PATCH 01/10] add assert_is_detection_link --- .../assertions/assert_is_detection_link.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 chainercv/utils/testing/assertions/assert_is_detection_link.py diff --git a/chainercv/utils/testing/assertions/assert_is_detection_link.py b/chainercv/utils/testing/assertions/assert_is_detection_link.py new file mode 100644 index 0000000000..e34a1eb104 --- /dev/null +++ b/chainercv/utils/testing/assertions/assert_is_detection_link.py @@ -0,0 +1,65 @@ +import numpy as np +import six + +from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox + + +def assert_is_detection_link(link, n_fg_class, max_n_bbox=None): + """Checks if a dataset satisfies detection dataset APIs. + + This function checks if a given dataset satisfies detection dataset APIs + or not. + If the dataset does not satifiy the APIs, this function raises an + :class:`AssertionError`. + + Args: + dataset: A dataset to be checked. + n_fg_class (int): The number of foreground classes. + n_example (int): The number of examples to be checked. + If this argument is specified, this function picks + examples ramdomly and checks them. Otherwise, + this function checks all examples. + + """ + + imgs = [ + np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32), + np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)] + + result = link.predict(imgs) + assert len(result) == 3, \ + 'Link must return three elements: bboxes, labels and scores.' + bboxes, labels, scores = result + + assert len(bboxes) == len(imgs), \ + 'The length of bboxes must be same as that of imgs.' + assert len(labels) == len(imgs), \ + 'The length of labels must be same as that of imgs.' + assert len(scores) == len(imgs), \ + 'The length of scores must be same as that of imgs.' + + for bbox, label, score in six.moves.zip(bboxes, labels, scores): + assert_is_bbox(bbox) + if max_n_bbox: + assert len(bbox) <= max_n_bbox, \ + 'The length of bbox must not exceed max_n_bbox.' + + assert isinstance(label, np.ndarray), \ + 'label must be a numpy.ndarray.' + assert label.dtype == np.int32, \ + 'The type of label must be numpy.int32.' + assert label.shape[1:] == (), \ + 'The shape of label must be (*,).' + assert len(label) == len(bbox), \ + 'The length of label must be same as that of bbox.' + assert label.min() >= 0 and label.max() < n_fg_class, \ + 'The value of label must be in [0, n_fg_class - 1].' + + assert isinstance(score, np.ndarray), \ + 'score must be a numpy.ndarray.' + assert score.dtype == np.float32, \ + 'The type of score must be numpy.float32.' + assert score.shape[1:] == (), \ + 'The shape of score must be (*,).' + assert len(score) == len(bbox), \ + 'The length of score must be same as that of bbox.' From 2c767d735b2ad54b89c7ab9c8597e57d776cb1ac Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 15:31:52 +0900 Subject: [PATCH 02/10] update docs --- .../assertions/assert_is_detection_link.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/chainercv/utils/testing/assertions/assert_is_detection_link.py b/chainercv/utils/testing/assertions/assert_is_detection_link.py index e34a1eb104..0c48eff772 100644 --- a/chainercv/utils/testing/assertions/assert_is_detection_link.py +++ b/chainercv/utils/testing/assertions/assert_is_detection_link.py @@ -4,21 +4,17 @@ from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox -def assert_is_detection_link(link, n_fg_class, max_n_bbox=None): - """Checks if a dataset satisfies detection dataset APIs. +def assert_is_detection_link(link, n_fg_class): + """Checks if a link satisfies detection link APIs. - This function checks if a given dataset satisfies detection dataset APIs + This function checks if a given link satisfies detection link APIs or not. - If the dataset does not satifiy the APIs, this function raises an + If the link does not satifiy the APIs, this function raises an :class:`AssertionError`. Args: - dataset: A dataset to be checked. + link: A link to be checked. n_fg_class (int): The number of foreground classes. - n_example (int): The number of examples to be checked. - If this argument is specified, this function picks - examples ramdomly and checks them. Otherwise, - this function checks all examples. """ @@ -40,9 +36,6 @@ def assert_is_detection_link(link, n_fg_class, max_n_bbox=None): for bbox, label, score in six.moves.zip(bboxes, labels, scores): assert_is_bbox(bbox) - if max_n_bbox: - assert len(bbox) <= max_n_bbox, \ - 'The length of bbox must not exceed max_n_bbox.' assert isinstance(label, np.ndarray), \ 'label must be a numpy.ndarray.' From 4687d7d16cec017355d651c472a889fa3c4652b1 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 15:55:12 +0900 Subject: [PATCH 03/10] add docs and tests --- chainercv/utils/__init__.py | 1 + chainercv/utils/testing/__init__.py | 1 + .../utils/testing/assertions/__init__.py | 1 + docs/source/reference/utils.rst | 4 + .../test_assert_is_detection_link.py | 79 +++++++++++++++++++ 5 files changed, 86 insertions(+) create mode 100644 tests/utils_tests/testing_tests/assertions_tests/test_assert_is_detection_link.py diff --git a/chainercv/utils/__init__.py b/chainercv/utils/__init__.py index 28ca7e5e93..7ca3b72e57 100644 --- a/chainercv/utils/__init__.py +++ b/chainercv/utils/__init__.py @@ -8,6 +8,7 @@ from chainercv.utils.iterator import unzip # NOQA from chainercv.utils.testing import assert_is_bbox # NOQA from chainercv.utils.testing import assert_is_detection_dataset # NOQA +from chainercv.utils.testing import assert_is_detection_link # NOQA from chainercv.utils.testing import assert_is_image # NOQA from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA from chainercv.utils.testing import ConstantStubLink # NOQA diff --git a/chainercv/utils/testing/__init__.py b/chainercv/utils/testing/__init__.py index 426fd608b8..69aeb858f6 100644 --- a/chainercv/utils/testing/__init__.py +++ b/chainercv/utils/testing/__init__.py @@ -1,5 +1,6 @@ from chainercv.utils.testing.assertions import assert_is_bbox # NOQA from chainercv.utils.testing.assertions import assert_is_detection_dataset # NOQA +from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions import assert_is_image # NOQA from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA diff --git a/chainercv/utils/testing/assertions/__init__.py b/chainercv/utils/testing/assertions/__init__.py index 0f72a4830b..6b96acf1a6 100644 --- a/chainercv/utils/testing/assertions/__init__.py +++ b/chainercv/utils/testing/assertions/__init__.py @@ -1,4 +1,5 @@ from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox # NOQA from chainercv.utils.testing.assertions.assert_is_detection_dataset import assert_is_detection_dataset # NOQA +from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst index 097f7a1a92..14d36e4703 100644 --- a/docs/source/reference/utils.rst +++ b/docs/source/reference/utils.rst @@ -63,6 +63,10 @@ assert_is_detection_dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunctions:: assert_is_detection_dataset +assert_is_detection_link +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunctions:: assert_is_detection_link + assert_is_image ~~~~~~~~~~~~~~~ .. autofunctions:: assert_is_image diff --git a/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_detection_link.py b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_detection_link.py new file mode 100644 index 0000000000..013ad3c7f0 --- /dev/null +++ b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_detection_link.py @@ -0,0 +1,79 @@ +import numpy as np +import unittest + +import chainer +from chainer import testing + +from chainercv.utils import assert_is_detection_link +from chainercv.utils import generate_random_bbox + + +class DetectionLink(chainer.Link): + + def predict(self, imgs): + bboxes = list() + labels = list() + scores = list() + + for img in imgs: + n_bbox = np.random.randint(0, 10) + bboxes.append(generate_random_bbox( + n_bbox, img.shape[1:], 4, 12)) + labels.append(np.random.randint( + 0, 20, size=n_bbox).astype(np.int32)) + scores.append(np.random.uniform( + 0, 1, size=n_bbox).astype(np.float32)) + + return bboxes, labels, scores + + +class InvalidPredictionSizeLink(DetectionLink): + + def predict(self, imgs): + bboxes, labels, scores = super( + InvalidPredictionSizeLink, self).predict(imgs) + return bboxes[1:], labels[1:], scores[1:] + + +class InvalidLabelSizeLink(DetectionLink): + + def predict(self, imgs): + bboxes, labels, scores = super( + InvalidLabelSizeLink, self).predict(imgs) + return bboxes, [label[1:] for label in labels], scores + + +class InvalidLabelValueLink(DetectionLink): + + def predict(self, imgs): + bboxes, labels, scores = super( + InvalidLabelValueLink, self).predict(imgs) + return bboxes, [label + 1000 for label in labels], scores + + +class InvalidScoreSizeLink(DetectionLink): + + def predict(self, imgs): + bboxes, labels, scores = super( + InvalidScoreSizeLink, self).predict(imgs) + return bboxes, labels, [score[1:] for score in scores] + + +@testing.parameterize( + {'link': DetectionLink(), 'valid': True}, + {'link': InvalidPredictionSizeLink(), 'valid': False}, + {'link': InvalidLabelSizeLink(), 'valid': False}, + {'link': InvalidLabelValueLink(), 'valid': False}, + {'link': InvalidScoreSizeLink(), 'valid': False}, +) +class TestAssertIsDetectionLink(unittest.TestCase): + + def test_assert_is_detection_link(self): + if self.valid: + assert_is_detection_link(self.link, 20) + else: + with self.assertRaises(AssertionError): + assert_is_detection_link(self.link, 20) + + +testing.run_module(__name__, __file__) From 4309d2eb3b54de330cf073393c5f2bb120047d76 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:01:27 +0900 Subject: [PATCH 04/10] use assert_is_detection_link --- .../faster_rcnn_tests/test_faster_rcnn.py | 32 +++---------------- .../model_tests/ssd_tests/test_ssd.py | 30 ++--------------- 2 files changed, 7 insertions(+), 55 deletions(-) diff --git a/tests/links_tests/model_tests/faster_rcnn_tests/test_faster_rcnn.py b/tests/links_tests/model_tests/faster_rcnn_tests/test_faster_rcnn.py index 5fcd8ec44a..5ac82043cd 100644 --- a/tests/links_tests/model_tests/faster_rcnn_tests/test_faster_rcnn.py +++ b/tests/links_tests/model_tests/faster_rcnn_tests/test_faster_rcnn.py @@ -5,6 +5,8 @@ from chainer import testing from chainer.testing import attr +from chainercv.utils import assert_is_detection_link + from dummy_faster_rcnn import DummyFasterRCNN @@ -58,39 +60,13 @@ def test_call_gpu(self): self.link.to_gpu() self.check_call() - def check_predict(self): - imgs = [ - _random_array(np, (3, 640, 480)), - _random_array(np, (3, 320, 320))] - - bboxes, labels, scores = self.link.predict(imgs) - - self.assertEqual(len(bboxes), len(imgs)) - self.assertEqual(len(labels), len(imgs)) - self.assertEqual(len(scores), len(imgs)) - - for bbox, label, score in zip(bboxes, labels, scores): - self.assertIsInstance(bbox, np.ndarray) - self.assertEqual(bbox.dtype, np.float32) - self.assertEqual(bbox.ndim, 2) - self.assertLessEqual(bbox.shape[0], self.n_roi) - self.assertEqual(bbox.shape[1], 4) - - self.assertIsInstance(label, np.ndarray) - self.assertEqual(label.dtype, np.int32) - self.assertEqual(label.shape, (bbox.shape[0],)) - - self.assertIsInstance(score, np.ndarray) - self.assertEqual(score.dtype, np.float32) - self.assertEqual(score.shape, (bbox.shape[0],)) - def test_predict_cpu(self): - self.check_predict() + assert_is_detection_link(self.link, self.n_class - 1) @attr.gpu def test_predict_gpu(self): self.link.to_gpu() - self.check_predict() + assert_is_detection_link(self.link, self.n_class - 1) @testing.parameterize( diff --git a/tests/links_tests/model_tests/ssd_tests/test_ssd.py b/tests/links_tests/model_tests/ssd_tests/test_ssd.py index 3960fa190d..5f37f8b998 100644 --- a/tests/links_tests/model_tests/ssd_tests/test_ssd.py +++ b/tests/links_tests/model_tests/ssd_tests/test_ssd.py @@ -7,6 +7,7 @@ from chainercv.links.model.ssd import Multibox from chainercv.links.model.ssd import SSD +from chainercv.utils import assert_is_detection_link def _random_array(xp, shape): @@ -154,38 +155,13 @@ def test_use_preset(self): with self.assertRaises(ValueError): self.link.use_preset('unknown') - def _check_predict(self): - imgs = [ - _random_array(np, (3, 640, 480)), - _random_array(np, (3, 320, 320))] - - bboxes, labels, scores = self.link.predict(imgs) - - self.assertEqual(len(bboxes), len(imgs)) - self.assertEqual(len(labels), len(imgs)) - self.assertEqual(len(scores), len(imgs)) - - for bbox, label, score in zip(bboxes, labels, scores): - self.assertIsInstance(bbox, np.ndarray) - self.assertEqual(bbox.ndim, 2) - self.assertLessEqual(bbox.shape[0], self.n_bbox * self.n_fg_class) - self.assertEqual(bbox.shape[1], 4) - - self.assertIsInstance(label, np.ndarray) - self.assertEqual(label.ndim, 1) - self.assertEqual(label.shape[0], bbox.shape[0]) - - self.assertIsInstance(score, np.ndarray) - self.assertEqual(score.ndim, 1) - self.assertEqual(score.shape[0], bbox.shape[0]) - def test_predict_cpu(self): - self._check_predict() + assert_is_detection_link(self.link, self.n_fg_class) @attr.gpu def test_predict_gpu(self): self.link.to_gpu() - self._check_predict() + assert_is_detection_link(self.link, self.n_fg_class) testing.run_module(__name__, __file__) From d6295383c1ec36a47b3307d81e0ce1d0127b2b60 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:09:22 +0900 Subject: [PATCH 05/10] fix bug when bbox is empty --- .../utils/testing/assertions/assert_is_detection_link.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chainercv/utils/testing/assertions/assert_is_detection_link.py b/chainercv/utils/testing/assertions/assert_is_detection_link.py index 0c48eff772..8673b210e0 100644 --- a/chainercv/utils/testing/assertions/assert_is_detection_link.py +++ b/chainercv/utils/testing/assertions/assert_is_detection_link.py @@ -45,8 +45,9 @@ def assert_is_detection_link(link, n_fg_class): 'The shape of label must be (*,).' assert len(label) == len(bbox), \ 'The length of label must be same as that of bbox.' - assert label.min() >= 0 and label.max() < n_fg_class, \ - 'The value of label must be in [0, n_fg_class - 1].' + if len(label) > 0: + assert label.min() >= 0 and label.max() < n_fg_class, \ + 'The value of label must be in [0, n_fg_class - 1].' assert isinstance(score, np.ndarray), \ 'score must be a numpy.ndarray.' From 36065baf506d46a1026a9422c7bdbe64af197c8f Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:09:37 +0900 Subject: [PATCH 06/10] fix return type --- chainercv/links/model/ssd/ssd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chainercv/links/model/ssd/ssd.py b/chainercv/links/model/ssd/ssd.py index bdbf2023e1..46a79bc58d 100644 --- a/chainercv/links/model/ssd/ssd.py +++ b/chainercv/links/model/ssd/ssd.py @@ -182,9 +182,9 @@ def _suppress(self, raw_bbox, raw_score): label.append(xp.array((l,) * len(bbox_l))) score.append(score_l) - bbox = xp.vstack(bbox) - label = xp.hstack(label).astype(int) - score = xp.hstack(score) + bbox = xp.vstack(bbox).astype(np.float32) + label = xp.hstack(label).astype(np.int32) + score = xp.hstack(score).astype(np.float32) return bbox, label, score From b3f0bd2075c32e4917374325abc20adff44b36f1 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:34:03 +0900 Subject: [PATCH 07/10] add assert_is_semantic_segmentation_link --- chainercv/utils/__init__.py | 1 + chainercv/utils/testing/__init__.py | 1 + .../utils/testing/assertions/__init__.py | 1 + .../assert_is_semantic_segmentation_link.py | 35 +++++++++++ docs/source/reference/utils.rst | 4 ++ ...st_assert_is_semantic_segmentation_link.py | 62 +++++++++++++++++++ 6 files changed, 104 insertions(+) create mode 100644 chainercv/utils/testing/assertions/assert_is_semantic_segmentation_link.py create mode 100644 tests/utils_tests/testing_tests/assertions_tests/test_assert_is_semantic_segmentation_link.py diff --git a/chainercv/utils/__init__.py b/chainercv/utils/__init__.py index 7ca3b72e57..9706f92390 100644 --- a/chainercv/utils/__init__.py +++ b/chainercv/utils/__init__.py @@ -11,5 +11,6 @@ from chainercv.utils.testing import assert_is_detection_link # NOQA from chainercv.utils.testing import assert_is_image # NOQA from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA +from chainercv.utils.testing import assert_is_semantic_segmentation_link # NOQA from chainercv.utils.testing import ConstantStubLink # NOQA from chainercv.utils.testing import generate_random_bbox # NOQA diff --git a/chainercv/utils/testing/__init__.py b/chainercv/utils/testing/__init__.py index 69aeb858f6..d43dba70c5 100644 --- a/chainercv/utils/testing/__init__.py +++ b/chainercv/utils/testing/__init__.py @@ -3,5 +3,6 @@ from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions import assert_is_image # NOQA from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA +from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_link # NOQA from chainercv.utils.testing.constant_stub_link import ConstantStubLink # NOQA from chainercv.utils.testing.generate_random_bbox import generate_random_bbox # NOQA diff --git a/chainercv/utils/testing/assertions/__init__.py b/chainercv/utils/testing/assertions/__init__.py index 6b96acf1a6..35782a71de 100644 --- a/chainercv/utils/testing/assertions/__init__.py +++ b/chainercv/utils/testing/assertions/__init__.py @@ -3,3 +3,4 @@ from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA +from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_link import assert_is_semantic_segmentation_link # NOQA diff --git a/chainercv/utils/testing/assertions/assert_is_semantic_segmentation_link.py b/chainercv/utils/testing/assertions/assert_is_semantic_segmentation_link.py new file mode 100644 index 0000000000..b49fc9e3b0 --- /dev/null +++ b/chainercv/utils/testing/assertions/assert_is_semantic_segmentation_link.py @@ -0,0 +1,35 @@ +import numpy as np +import six + + +def assert_is_semantic_segmentation_link(link, n_class): + """Checks if a link satisfies semantic segmentation link APIs. + + This function checks if a given link satisfies semantic segmentation link + APIs or not. + If the link does not satifiy the APIs, this function raises an + :class:`AssertionError`. + + Args: + link: A link to be checked. + n_class (int): The number of classes including background. + + """ + + imgs = [ + np.random.randint(0, 256, size=(3, 480, 640)).astype(np.float32), + np.random.randint(0, 256, size=(3, 480, 320)).astype(np.float32)] + + labels = link.predict(imgs) + assert len(labels) == len(imgs), \ + 'The length of labels must be same as that of imgs.' + + for img, label in six.moves.zip(imgs, labels): + assert isinstance(label, np.ndarray), \ + 'label must be a numpy.ndarray.' + assert label.dtype == np.int32, \ + 'The type of label must be numpy.int32.' + assert label.shape == img.shape[1:], \ + 'The shape of label must be (H, W).' + assert label.min() >= 0 and label.max() < n_class, \ + 'The value of label must be in [0, n_class - 1].' diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst index 14d36e4703..4339bb46e2 100644 --- a/docs/source/reference/utils.rst +++ b/docs/source/reference/utils.rst @@ -75,6 +75,10 @@ assert_is_semantic_segmentation_dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunctions:: assert_is_semantic_segmentation_dataset +assert_is_semantic_segmentation_link +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunctions:: assert_is_semantic_segmentation_link + ConstantStubLink ~~~~~~~~~~~~~~~~ .. autoclass:: ConstantStubLink diff --git a/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_semantic_segmentation_link.py b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_semantic_segmentation_link.py new file mode 100644 index 0000000000..bb5f6cf37e --- /dev/null +++ b/tests/utils_tests/testing_tests/assertions_tests/test_assert_is_semantic_segmentation_link.py @@ -0,0 +1,62 @@ +import numpy as np +import unittest + +import chainer +from chainer import testing + +from chainercv.utils import assert_is_semantic_segmentation_link + + +class SemanticSegmentationLink(chainer.Link): + + def predict(self, imgs): + labels = list() + + for img in imgs: + labels.append(np.random.randint( + 0, 21, size=img.shape[1:]).astype(np.int32)) + + return labels + + +class InvalidPredictionSizeLink(SemanticSegmentationLink): + + def predict(self, imgs): + labels = super( + InvalidPredictionSizeLink, self).predict(imgs) + return labels[1:] + + +class InvalidLabelSizeLink(SemanticSegmentationLink): + + def predict(self, imgs): + labels = super( + InvalidLabelSizeLink, self).predict(imgs) + return [label[1:] for label in labels] + + +class InvalidLabelValueLink(SemanticSegmentationLink): + + def predict(self, imgs): + labels = super( + InvalidLabelValueLink, self).predict(imgs) + return [label + 1000 for label in labels] + + +@testing.parameterize( + {'link': SemanticSegmentationLink(), 'valid': True}, + {'link': InvalidPredictionSizeLink(), 'valid': False}, + {'link': InvalidLabelSizeLink(), 'valid': False}, + {'link': InvalidLabelValueLink(), 'valid': False}, +) +class TestAssertIsSemanticSegmentationLink(unittest.TestCase): + + def test_assert_is_semantic_segmentation_link(self): + if self.valid: + assert_is_semantic_segmentation_link(self.link, 21) + else: + with self.assertRaises(AssertionError): + assert_is_semantic_segmentation_link(self.link, 21) + + +testing.run_module(__name__, __file__) From ea54d89cd4fa4c94e0945795857e20c0490824a4 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:44:32 +0900 Subject: [PATCH 08/10] use assert_is_semantic_segmentation_link --- .../segnet_tests/test_segnet_basic.py | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py b/tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py index b3bf0120f9..506e7e760c 100644 --- a/tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py +++ b/tests/links_tests/model_tests/segnet_tests/test_segnet_basic.py @@ -6,6 +6,7 @@ from chainer.testing import attr from chainercv.links import SegNetBasic +from chainercv.utils import assert_is_semantic_segmentation_link @testing.parameterize( @@ -37,29 +38,13 @@ def test_call_gpu(self): self.link.to_gpu() self.check_call() - def check_predict(self): - hs = np.random.randint(128, 160, size=(2,)) - ws = np.random.randint(128, 160, size=(2,)) - imgs = [ - np.random.uniform(size=(3, hs[0], ws[0])).astype(np.float32), - np.random.uniform(size=(3, hs[1], ws[1])).astype(np.float32), - ] - - labels = self.link.predict(imgs) - - self.assertEqual(len(labels), 2) - for i in range(2): - self.assertIsInstance(labels[i], np.ndarray) - self.assertEqual(labels[i].shape, (hs[i], ws[i])) - self.assertEqual(labels[i].dtype, np.int64) - def test_predict_cpu(self): - self.check_predict() + assert_is_semantic_segmentation_link(self.link, self.n_class) @attr.gpu def test_predict_gpu(self): self.link.to_gpu() - self.check_predict() + assert_is_semantic_segmentation_link(self.link, self.n_class) testing.run_module(__name__, __file__) From 1af165c8d2aab7f72c333ffe0e5e26eb06fe1e06 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 16 Jun 2017 16:44:58 +0900 Subject: [PATCH 09/10] fix return type --- chainercv/links/model/segnet/segnet_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainercv/links/model/segnet/segnet_basic.py b/chainercv/links/model/segnet/segnet_basic.py index c30e866d7a..dd86419e1c 100644 --- a/chainercv/links/model/segnet/segnet_basic.py +++ b/chainercv/links/model/segnet/segnet_basic.py @@ -176,6 +176,6 @@ def predict(self, imgs): dtype = score.dtype score = resize(score, (H, W)).astype(dtype) - label = np.argmax(score, axis=0) + label = np.argmax(score, axis=0).astype(np.int32) labels.append(label) return labels From eaac7f5bb3db3db991fb89901e2f7b1ba0661f7b Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Fri, 23 Jun 2017 17:04:20 +0900 Subject: [PATCH 10/10] autofunctions -> autofunction --- docs/source/reference/utils.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst index 4339bb46e2..9471ca72e3 100644 --- a/docs/source/reference/utils.rst +++ b/docs/source/reference/utils.rst @@ -57,27 +57,27 @@ Testing Utilities assert_is_bbox ~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_bbox +.. autofunction:: assert_is_bbox assert_is_detection_dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_detection_dataset +.. autofunction:: assert_is_detection_dataset assert_is_detection_link ~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_detection_link +.. autofunction:: assert_is_detection_link assert_is_image ~~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_image +.. autofunction:: assert_is_image assert_is_semantic_segmentation_dataset ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_semantic_segmentation_dataset +.. autofunction:: assert_is_semantic_segmentation_dataset assert_is_semantic_segmentation_link ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunctions:: assert_is_semantic_segmentation_link +.. autofunction:: assert_is_semantic_segmentation_link ConstantStubLink ~~~~~~~~~~~~~~~~