Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Add vis instance segmentation func #541

Merged
merged 3 commits into from
Mar 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainercv/visualizations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from chainercv.visualizations.vis_bbox import vis_bbox # NOQA
from chainercv.visualizations.vis_image import vis_image # NOQA
from chainercv.visualizations.vis_instance_segmentation import vis_instance_segmentation # NOQA
from chainercv.visualizations.vis_point import vis_point # NOQA
from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA
97 changes: 97 additions & 0 deletions chainercv/visualizations/vis_instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import division

import numpy as np

from chainercv.visualizations.vis_image import vis_image
from chainercv.visualizations.vis_semantic_segmentation import _default_cmap


def vis_instance_segmentation(
img, bbox, mask, label=None, score=None, label_names=None,
alpha=0.7, ax=None):
"""Visualize instance segmentation.

Example:

>>> from chainercv.datasets import SBDInstanceSegmentationDataset
>>> from chainercv.datasets \
... import sbd_instance_segmentation_label_names
>>> from chainercv.visualizations import vis_instance_segmentation
>>> import matplotlib.pyplot as plot
>>> dataset = SBDSegmentationDataset()
>>> img, bbox, mask, label = dataset[0]
>>> vis_instance_segmentation(
... img, bbox, mask, label,
... label_names=sbd_instance_segmentation_label_names)
>>> plot.show()

Args:
img (~numpy.ndarray): An array of shape :math:`(3, H, W)`.
This is in RGB format and the range of its value is
:math:`[0, 255]`.
bbox (~numpy.ndarray): A float array of shape :math:`(R, 4)`.
:math:`R` is the number of objects in the image, and each
vector represents a bounding box of an object.
The bounding box is :math:`(y_min, x_min, y_max, x_max)`.
mask (~numpy.ndarray): A bool array of shape
:math`(R, H, W)`.
If there is an object, the value of the pixel is :obj:`True`,
and otherwise, it is :obj:`False`.
label (~numpy.ndarray): An integer array of shape :math:`(R, )`.
The values correspond to id for label names stored in
:obj:`label_names`.
label_names (iterable of strings): Name of labels ordered according
to label ids.
alpha (float): The value which determines transparency of the figure.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is alpha=1 a reasonable default value?
To me, alpha=0.7 looks better.
Also, the doc needs to be changed if we are changing this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer alpha=0.7, so I will change it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated.

The range of this value is :math:`[0, 1]`. If this
value is :obj:`0`, the figure will be completely transparent.
The default value is :obj:`0.7`. This option is useful for
overlaying the label on the source image.
ax (matplotlib.axes.Axis): The visualization is displayed on this
axis. If this is :obj:`None` (default), a new axis is created.

Returns:
matploblib.axes.Axes: Returns :obj:`ax`.
:obj:`ax` is an :class:`matploblib.axes.Axes` with the plot.

"""
if len(bbox) != len(mask):
raise ValueError('The length of mask must be same as that of bbox')
if label is not None and len(bbox) != len(label):
raise ValueError('The length of label must be same as that of bbox')
if score is not None and len(bbox) != len(score):
raise ValueError('The length of score must be same as that of bbox')

n_inst = len(bbox)
colors = np.array([_default_cmap(l) for l in range(1, n_inst + 1)])

# Returns newly instantiated matplotlib.axes.Axes object if ax is None
ax = vis_image(img, ax=ax)

canvas_img = np.zeros((mask.shape[1], mask.shape[2], 4), dtype=np.uint8)
for i, (color, bb, msk) in enumerate(zip(colors, bbox, mask)):
rgba = np.append(color, alpha * 255)
bb = np.round(bb).astype(np.int32)
y_min, x_min, y_max, x_max = bb
if y_max > y_min and x_max > x_min:
canvas_img[msk] = rgba

caption = []
if label is not None and label_names is not None:
lb = label[i]
if not (0 <= lb < len(label_names)):
raise ValueError('No corresponding name is given')
caption.append(label_names[lb])
if score is not None:
sc = score[i]
caption.append('{:.2f}'.format(sc))

if len(caption) > 0:
ax.text((x_max + x_min) / 2, y_min,
': '.join(caption),
style='italic',
bbox={'facecolor': color / 255, 'alpha': alpha},
fontsize=8, color='white')

ax.imshow(canvas_img)
return ax
4 changes: 4 additions & 0 deletions docs/source/reference/visualizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ vis_image
~~~~~~~~~
.. autofunction:: vis_image

vis_instance_segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: vis_instance_segmentation

vis_point
~~~~~~~~~
.. autofunction:: vis_point
Expand Down
111 changes: 111 additions & 0 deletions tests/visualizations_tests/test_vis_instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
import unittest

from chainer import testing

from chainercv.utils import generate_random_bbox
from chainercv.visualizations import vis_instance_segmentation

try:
import matplotlib # NOQA
optional_modules = True
except ImportError:
optional_modules = False


@testing.parameterize(
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': None,
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
'label_names': None},
{
'n_bbox': 3, 'label': None, 'score': None,
'label_names': None},
{
'n_bbox': 3, 'label': (0, 1, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 0, 'label': (), 'score': (),
'label_names': ('c0', 'c1', 'c2')},
)
class TestVisInstanceSegmentation(unittest.TestCase):

def setUp(self):
self.img = np.random.randint(0, 255, size=(3, 32, 48))
self.bbox = generate_random_bbox(
self.n_bbox, (48, 32), 8, 16)
self.mask = np.random.randint(
0, 1, size=(self.n_bbox, 32, 48), dtype=bool)
if self.label is not None:
self.label = np.array(self.label, dtype=np.int32)
if self.score is not None:
self.score = np.array(self.score)

def test_vis_instance_segmentation(self):
if not optional_modules:
return

ax = vis_instance_segmentation(
self.img, self.bbox, self.mask, self.label, self.score,
label_names=self.label_names)

self.assertIsInstance(ax, matplotlib.axes.Axes)


@testing.parameterize(
{
'n_bbox': 3, 'label': (0, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2, 1), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},

{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1, 0.75),
'label_names': ('c0', 'c1', 'c2')},

{
'n_bbox': 3, 'label': (0, 1, 3), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},
{
'n_bbox': 3, 'label': (-1, 1, 2), 'score': (0, 0.5, 1),
'label_names': ('c0', 'c1', 'c2')},

)
class TestVisInstanceSegmentationInvalidInputs(unittest.TestCase):

def setUp(self):
self.img = np.random.randint(0, 255, size=(3, 32, 48))
self.bbox = np.random.uniform(size=(self.n_bbox, 4))
self.mask = np.random.randint(
0, 1, size=(self.n_bbox, 32, 48), dtype=bool)
if self.label is not None:
self.label = np.array(self.label, dtype=int)
if self.score is not None:
self.score = np.array(self.score)

def test_vis_instance_segmentation_invalid_inputs(self):
if not optional_modules:
return

with self.assertRaises(ValueError):
vis_instance_segmentation(
self.img, self.bbox, self.mask, self.label, self.score,
label_names=self.label_names)


testing.run_module(__name__, __file__)