Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Commit b8280c2

Browse files
authored
Merge pull request #541 from knorth55/add-vis-instance-segmentation
Add vis instance segmentation func
2 parents ff7f705 + d163005 commit b8280c2

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

chainercv/visualizations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from chainercv.visualizations.vis_bbox import vis_bbox # NOQA
22
from chainercv.visualizations.vis_image import vis_image # NOQA
3+
from chainercv.visualizations.vis_instance_segmentation import vis_instance_segmentation # NOQA
34
from chainercv.visualizations.vis_point import vis_point # NOQA
45
from chainercv.visualizations.vis_semantic_segmentation import vis_semantic_segmentation # NOQA
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from __future__ import division
2+
3+
import numpy as np
4+
5+
from chainercv.visualizations.vis_image import vis_image
6+
from chainercv.visualizations.vis_semantic_segmentation import _default_cmap
7+
8+
9+
def vis_instance_segmentation(
10+
img, bbox, mask, label=None, score=None, label_names=None,
11+
alpha=0.7, ax=None):
12+
"""Visualize instance segmentation.
13+
14+
Example:
15+
16+
>>> from chainercv.datasets import SBDInstanceSegmentationDataset
17+
>>> from chainercv.datasets \
18+
... import sbd_instance_segmentation_label_names
19+
>>> from chainercv.visualizations import vis_instance_segmentation
20+
>>> import matplotlib.pyplot as plot
21+
>>> dataset = SBDSegmentationDataset()
22+
>>> img, bbox, mask, label = dataset[0]
23+
>>> vis_instance_segmentation(
24+
... img, bbox, mask, label,
25+
... label_names=sbd_instance_segmentation_label_names)
26+
>>> plot.show()
27+
28+
Args:
29+
img (~numpy.ndarray): An array of shape :math:`(3, H, W)`.
30+
This is in RGB format and the range of its value is
31+
:math:`[0, 255]`.
32+
bbox (~numpy.ndarray): A float array of shape :math:`(R, 4)`.
33+
:math:`R` is the number of objects in the image, and each
34+
vector represents a bounding box of an object.
35+
The bounding box is :math:`(y_min, x_min, y_max, x_max)`.
36+
mask (~numpy.ndarray): A bool array of shape
37+
:math`(R, H, W)`.
38+
If there is an object, the value of the pixel is :obj:`True`,
39+
and otherwise, it is :obj:`False`.
40+
label (~numpy.ndarray): An integer array of shape :math:`(R, )`.
41+
The values correspond to id for label names stored in
42+
:obj:`label_names`.
43+
label_names (iterable of strings): Name of labels ordered according
44+
to label ids.
45+
alpha (float): The value which determines transparency of the figure.
46+
The range of this value is :math:`[0, 1]`. If this
47+
value is :obj:`0`, the figure will be completely transparent.
48+
The default value is :obj:`0.7`. This option is useful for
49+
overlaying the label on the source image.
50+
ax (matplotlib.axes.Axis): The visualization is displayed on this
51+
axis. If this is :obj:`None` (default), a new axis is created.
52+
53+
Returns:
54+
matploblib.axes.Axes: Returns :obj:`ax`.
55+
:obj:`ax` is an :class:`matploblib.axes.Axes` with the plot.
56+
57+
"""
58+
if len(bbox) != len(mask):
59+
raise ValueError('The length of mask must be same as that of bbox')
60+
if label is not None and len(bbox) != len(label):
61+
raise ValueError('The length of label must be same as that of bbox')
62+
if score is not None and len(bbox) != len(score):
63+
raise ValueError('The length of score must be same as that of bbox')
64+
65+
n_inst = len(bbox)
66+
colors = np.array([_default_cmap(l) for l in range(1, n_inst + 1)])
67+
68+
# Returns newly instantiated matplotlib.axes.Axes object if ax is None
69+
ax = vis_image(img, ax=ax)
70+
71+
canvas_img = np.zeros((mask.shape[1], mask.shape[2], 4), dtype=np.uint8)
72+
for i, (color, bb, msk) in enumerate(zip(colors, bbox, mask)):
73+
rgba = np.append(color, alpha * 255)
74+
bb = np.round(bb).astype(np.int32)
75+
y_min, x_min, y_max, x_max = bb
76+
if y_max > y_min and x_max > x_min:
77+
canvas_img[msk] = rgba
78+
79+
caption = []
80+
if label is not None and label_names is not None:
81+
lb = label[i]
82+
if not (0 <= lb < len(label_names)):
83+
raise ValueError('No corresponding name is given')
84+
caption.append(label_names[lb])
85+
if score is not None:
86+
sc = score[i]
87+
caption.append('{:.2f}'.format(sc))
88+
89+
if len(caption) > 0:
90+
ax.text((x_max + x_min) / 2, y_min,
91+
': '.join(caption),
92+
style='italic',
93+
bbox={'facecolor': color / 255, 'alpha': alpha},
94+
fontsize=8, color='white')
95+
96+
ax.imshow(canvas_img)
97+
return ax

docs/source/reference/visualizations.rst

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ vis_image
1212
~~~~~~~~~
1313
.. autofunction:: vis_image
1414

15+
vis_instance_segmentation
16+
~~~~~~~~~~~~~~~~~~~~~~~~~
17+
.. autofunction:: vis_instance_segmentation
18+
1519
vis_point
1620
~~~~~~~~~
1721
.. autofunction:: vis_point
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import numpy as np
2+
import unittest
3+
4+
from chainer import testing
5+
6+
from chainercv.utils import generate_random_bbox
7+
from chainercv.visualizations import vis_instance_segmentation
8+
9+
try:
10+
import matplotlib # NOQA
11+
optional_modules = True
12+
except ImportError:
13+
optional_modules = False
14+
15+
16+
@testing.parameterize(
17+
{
18+
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
19+
'label_names': ('c0', 'c1', 'c2')},
20+
{
21+
'n_bbox': 3, 'label': (0, 1, 2), 'score': None,
22+
'label_names': ('c0', 'c1', 'c2')},
23+
{
24+
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1),
25+
'label_names': None},
26+
{
27+
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
28+
'label_names': ('c0', 'c1', 'c2')},
29+
{
30+
'n_bbox': 3, 'label': None, 'score': (0, 0.5, 1),
31+
'label_names': None},
32+
{
33+
'n_bbox': 3, 'label': None, 'score': None,
34+
'label_names': None},
35+
{
36+
'n_bbox': 3, 'label': (0, 1, 1), 'score': (0, 0.5, 1),
37+
'label_names': ('c0', 'c1', 'c2')},
38+
{
39+
'n_bbox': 0, 'label': (), 'score': (),
40+
'label_names': ('c0', 'c1', 'c2')},
41+
)
42+
class TestVisInstanceSegmentation(unittest.TestCase):
43+
44+
def setUp(self):
45+
self.img = np.random.randint(0, 255, size=(3, 32, 48))
46+
self.bbox = generate_random_bbox(
47+
self.n_bbox, (48, 32), 8, 16)
48+
self.mask = np.random.randint(
49+
0, 1, size=(self.n_bbox, 32, 48), dtype=bool)
50+
if self.label is not None:
51+
self.label = np.array(self.label, dtype=np.int32)
52+
if self.score is not None:
53+
self.score = np.array(self.score)
54+
55+
def test_vis_instance_segmentation(self):
56+
if not optional_modules:
57+
return
58+
59+
ax = vis_instance_segmentation(
60+
self.img, self.bbox, self.mask, self.label, self.score,
61+
label_names=self.label_names)
62+
63+
self.assertIsInstance(ax, matplotlib.axes.Axes)
64+
65+
66+
@testing.parameterize(
67+
{
68+
'n_bbox': 3, 'label': (0, 1), 'score': (0, 0.5, 1),
69+
'label_names': ('c0', 'c1', 'c2')},
70+
{
71+
'n_bbox': 3, 'label': (0, 1, 2, 1), 'score': (0, 0.5, 1),
72+
'label_names': ('c0', 'c1', 'c2')},
73+
74+
{
75+
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5),
76+
'label_names': ('c0', 'c1', 'c2')},
77+
{
78+
'n_bbox': 3, 'label': (0, 1, 2), 'score': (0, 0.5, 1, 0.75),
79+
'label_names': ('c0', 'c1', 'c2')},
80+
81+
{
82+
'n_bbox': 3, 'label': (0, 1, 3), 'score': (0, 0.5, 1),
83+
'label_names': ('c0', 'c1', 'c2')},
84+
{
85+
'n_bbox': 3, 'label': (-1, 1, 2), 'score': (0, 0.5, 1),
86+
'label_names': ('c0', 'c1', 'c2')},
87+
88+
)
89+
class TestVisInstanceSegmentationInvalidInputs(unittest.TestCase):
90+
91+
def setUp(self):
92+
self.img = np.random.randint(0, 255, size=(3, 32, 48))
93+
self.bbox = np.random.uniform(size=(self.n_bbox, 4))
94+
self.mask = np.random.randint(
95+
0, 1, size=(self.n_bbox, 32, 48), dtype=bool)
96+
if self.label is not None:
97+
self.label = np.array(self.label, dtype=int)
98+
if self.score is not None:
99+
self.score = np.array(self.score)
100+
101+
def test_vis_instance_segmentation_invalid_inputs(self):
102+
if not optional_modules:
103+
return
104+
105+
with self.assertRaises(ValueError):
106+
vis_instance_segmentation(
107+
self.img, self.bbox, self.mask, self.label, self.score,
108+
label_names=self.label_names)
109+
110+
111+
testing.run_module(__name__, __file__)

0 commit comments

Comments
 (0)