# chainercv/visualizations/vis_instance_segmentation.py
#
# ``vis_instance_segmentation`` is re-exported from
# ``chainercv/visualizations/__init__.py`` and listed in
# ``docs/source/reference/visualizations.rst``.
from __future__ import division

import numpy as np

from chainercv.visualizations.vis_image import vis_image
from chainercv.visualizations.vis_semantic_segmentation import _default_cmap


def vis_instance_segmentation(
        img, bbox, mask, label=None, score=None, label_names=None,
        alpha=0.7, ax=None):
    """Visualize instance segmentation.

    The image is drawn first, then the mask of every instance is painted
    over it in a per-instance color. A mask is painted only when the
    rounded bounding box of the instance has positive height and width.
    When a label name and/or a score is available, a caption is placed at
    the top edge of the bounding box, horizontally centered.

    Example:

        >>> from chainercv.datasets import (
        ...     SBDInstanceSegmentationDataset,
        ...     sbd_instance_segmentation_label_names)
        >>> from chainercv.visualizations import vis_instance_segmentation
        >>> import matplotlib.pyplot as plot
        >>> dataset = SBDInstanceSegmentationDataset()
        >>> img, bbox, mask, label = dataset[0]
        >>> vis_instance_segmentation(
        ...     img, bbox, mask, label,
        ...     label_names=sbd_instance_segmentation_label_names)
        >>> plot.show()

    Args:
        img (~numpy.ndarray): An array of shape :math:`(3, H, W)`.
            This is in RGB format and the range of its value is
            :math:`[0, 255]`.
        bbox (~numpy.ndarray): A float array of shape :math:`(R, 4)`.
            :math:`R` is the number of objects in the image, and each
            vector represents a bounding box of an object.
            The bounding box is :math:`(y_{min}, x_{min}, y_{max}, x_{max})`.
        mask (~numpy.ndarray): A bool array of shape
            :math:`(R, H, W)`.
            If there is an object, the value of the pixel is :obj:`True`,
            and otherwise, it is :obj:`False`. Arrays of other dtypes are
            converted to bool before being used.
        label (~numpy.ndarray): An integer array of shape :math:`(R, )`.
            The values correspond to id for label names stored in
            :obj:`label_names`. This is optional.
        score (~numpy.ndarray): An array of shape :math:`(R, )`.
            Each value is shown in the caption of the corresponding
            instance with two decimal places. This is optional.
        label_names (iterable of strings): Name of labels ordered according
            to label ids. Label names are shown only when both
            :obj:`label` and :obj:`label_names` are given.
        alpha (float): The value which determines transparency of the
            mask overlay and of the caption boxes.
            The range of this value is :math:`[0, 1]`. If this
            value is :obj:`0`, the overlay will be completely transparent.
            The default value is :obj:`0.7`.
        ax (matplotlib.axes.Axes): The visualization is displayed on this
            axis. If this is :obj:`None` (default), a new axis is created.

    Returns:
        matplotlib.axes.Axes: Returns :obj:`ax`.
        :obj:`ax` is a :class:`matplotlib.axes.Axes` with the plot.

    Raises:
        ValueError: If :obj:`mask`, :obj:`label` or :obj:`score` does not
            have the same length as :obj:`bbox`, or if a value of
            :obj:`label` has no corresponding entry in :obj:`label_names`.

    """
    n_inst = len(bbox)
    if len(mask) != n_inst:
        raise ValueError('The length of mask must be same as that of bbox')
    if label is not None and len(label) != n_inst:
        raise ValueError('The length of label must be same as that of bbox')
    if score is not None and len(score) != n_inst:
        raise ValueError('The length of score must be same as that of bbox')

    # Check every label before anything is drawn so that invalid input
    # never leaves a half-rendered axes behind.
    show_names = label is not None and label_names is not None
    if show_names:
        for lb in label:
            if not (0 <= lb < len(label_names)):
                raise ValueError('No corresponding name is given')

    # An integer 0/1 mask would otherwise be interpreted as a fancy index
    # and silently paint the wrong pixels.
    mask = np.asarray(mask, dtype=bool)

    # Instance ``i`` uses the color of id ``i + 1``.
    colors = np.array(
        [_default_cmap(inst_id) for inst_id in range(1, n_inst + 1)])

    # Returns newly instantiated matplotlib.axes.Axes object if ax is None
    ax = vis_image(img, ax=ax)

    # RGBA overlay; pixels not covered by any mask keep alpha 0.
    canvas_img = np.zeros(mask.shape[1:3] + (4,), dtype=np.uint8)
    for i, (color, bb, msk) in enumerate(zip(colors, bbox, mask)):
        y_min, x_min, y_max, x_max = np.round(bb).astype(np.int32)
        if y_max > y_min and x_max > x_min:
            canvas_img[msk] = np.append(color, alpha * 255)

        caption = []
        if show_names:
            caption.append(label_names[label[i]])
        if score is not None:
            caption.append('{:.2f}'.format(score[i]))

        if caption:
            ax.text((x_max + x_min) / 2, y_min,
                    ': '.join(caption),
                    style='italic',
                    # matplotlib expects color components in [0, 1]
                    bbox={'facecolor': color / 255, 'alpha': alpha},
                    fontsize=8, color='white')

    ax.imshow(canvas_img)
    return ax
# tests/visualizations_tests/test_vis_instance_segmentation.py
import unittest

import numpy as np

from chainer import testing

from chainercv.utils import generate_random_bbox
from chainercv.visualizations import vis_instance_segmentation

try:
    import matplotlib  # NOQA
    optional_modules = True
except ImportError:
    optional_modules = False


@testing.parameterize(
    # label, score and label_names all given
    {
        'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    # each optional input omitted in turn
    {
        'n_bbox': 3, 'label': (0, 1, 2), 'score': None,
        'label_names': ('c0', 'c1', 'c2')},
    {
        'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
        'label_names': None},
    {
        'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    {
        'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
        'label_names': None},
    {
        'n_bbox': 3, 'label': None, 'score': None,
        'label_names': None},
    # repeated label
    {
        'n_bbox': 3, 'label': (0, 1, 1), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    # no instance at all
    {
        'n_bbox': 0, 'label': (), 'score': (),
        'label_names': ('c0', 'c1', 'c2')},
)
class TestVisInstanceSegmentation(unittest.TestCase):
    """Check that every valid input combination returns an ``Axes``."""

    def setUp(self):
        self.img = np.random.randint(0, 255, size=(3, 32, 48))
        # NOTE(review): the image is (3, 32, 48) but the size passed here
        # is (48, 32); confirm the argument order expected by
        # ``generate_random_bbox``.
        self.bbox = generate_random_bbox(
            self.n_bbox, (48, 32), 8, 16)
        # The upper bound of ``randint`` is exclusive: 2 is required to
        # draw both False and True, otherwise every mask is empty and the
        # painting path is never exercised.
        self.mask = np.random.randint(
            0, 2, size=(self.n_bbox, 32, 48), dtype=bool)
        if self.label is not None:
            self.label = np.array(self.label, dtype=np.int32)
        if self.score is not None:
            self.score = np.array(self.score)

    # Report a skip instead of a silent pass when matplotlib is missing.
    @unittest.skipUnless(optional_modules, 'matplotlib is not installed')
    def test_vis_instance_segmentation(self):
        ax = vis_instance_segmentation(
            self.img, self.bbox, self.mask, self.label, self.score,
            label_names=self.label_names)

        self.assertIsInstance(ax, matplotlib.axes.Axes)


@testing.parameterize(
    # label length differs from the number of boxes
    {
        'n_bbox': 3, 'label': (0, 1), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    {
        'n_bbox': 3, 'label': (0, 1, 2, 1), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    # score length differs from the number of boxes
    {
        'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5),
        'label_names': ('c0', 'c1', 'c2')},
    {
        'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1, 0.75),
        'label_names': ('c0', 'c1', 'c2')},
    # label value without a corresponding name
    {
        'n_bbox': 3, 'label': (0, 1, 3), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
    {
        'n_bbox': 3, 'label': (-1, 1, 2), 'score': (0, 0.5, 1),
        'label_names': ('c0', 'c1', 'c2')},
)
class TestVisInstanceSegmentationInvalidInputs(unittest.TestCase):
    """Check that inconsistent inputs are rejected with ``ValueError``."""

    def setUp(self):
        self.img = np.random.randint(0, 255, size=(3, 32, 48))
        self.bbox = np.random.uniform(size=(self.n_bbox, 4))
        # Exclusive upper bound: use 2 so that masks are not all False.
        self.mask = np.random.randint(
            0, 2, size=(self.n_bbox, 32, 48), dtype=bool)
        if self.label is not None:
            self.label = np.array(self.label, dtype=int)
        if self.score is not None:
            self.score = np.array(self.score)

    # Report a skip instead of a silent pass when matplotlib is missing.
    @unittest.skipUnless(optional_modules, 'matplotlib is not installed')
    def test_vis_instance_segmentation_invalid_inputs(self):
        with self.assertRaises(ValueError):
            vis_instance_segmentation(
                self.img, self.bbox, self.mask, self.label, self.score,
                label_names=self.label_names)


testing.run_module(__name__, __file__)