Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Commit 2d337bd

Browse files
authored
Merge pull request #674 from knorth55/coco-ins-eval-ext
add InstanceSegmentationCOCOEvaluator
2 parents d3be047 + 5fdf86e commit 2d337bd

File tree

5 files changed

+284
-1
lines changed

5 files changed

+284
-1
lines changed

chainercv/extensions/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from chainercv.extensions.evaluator.detection_voc_evaluator import DetectionVOCEvaluator # NOQA
2+
from chainercv.extensions.evaluator.instance_segmentation_coco_evaluator import InstanceSegmentationCOCOEvaluator # NOQA
23
from chainercv.extensions.evaluator.instance_segmentation_voc_evaluator import InstanceSegmentationVOCEvaluator # NOQA
34
from chainercv.extensions.evaluator.semantic_segmentation_evaluator import SemanticSegmentationEvaluator # NOQA
45
from chainercv.extensions.vis_report.detection_vis_report import DetectionVisReport # NOQA
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import copy
2+
import numpy as np
3+
4+
from chainer import reporter
5+
import chainer.training.extensions
6+
7+
from chainercv.evaluations import eval_instance_segmentation_coco
8+
from chainercv.utils import apply_to_iterator
9+
10+
try:
11+
import pycocotools.coco # NOQA
12+
_available = True
13+
except ImportError:
14+
_available = False
15+
16+
17+
class InstanceSegmentationCOCOEvaluator(chainer.training.extensions.Evaluator):
18+
19+
"""An extension that evaluates a instance segmentation model by MS COCO metric.
20+
21+
This extension iterates over an iterator and evaluates the prediction
22+
results.
23+
The results consist of average precisions (APs) and average
24+
recalls (ARs) as well as the mean of each (mean average precision and mean
25+
average recall).
26+
This extension reports the following values with keys.
27+
Please note that if
28+
:obj:`label_names` is not specified, only the mAPs and mARs are reported.
29+
30+
The underlying dataset of the iterator is assumed to return
31+
:obj:`img, mask, label` or :obj:`img, mask, label, area, crowded`.
32+
33+
.. csv-table::
34+
:header: key, description
35+
36+
ap/iou=0.50:0.95/area=all/max_dets=100/<label_names[l]>, \
37+
[#coco_ins_ext_1]_
38+
ap/iou=0.50/area=all/max_dets=100/<label_names[l]>, \
39+
[#coco_ins_ext_1]_
40+
ap/iou=0.75/area=all/max_dets=100/<label_names[l]>, \
41+
[#coco_ins_ext_1]_
42+
ap/iou=0.50:0.95/area=small/max_dets=100/<label_names[l]>, \
43+
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
44+
ap/iou=0.50:0.95/area=medium/max_dets=100/<label_names[l]>, \
45+
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
46+
ap/iou=0.50:0.95/area=large/max_dets=100/<label_names[l]>, \
47+
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
48+
ar/iou=0.50:0.95/area=all/max_dets=1/<label_names[l]>, \
49+
[#coco_ins_ext_2]_
50+
ar/iou=0.50/area=all/max_dets=10/<label_names[l]>, \
51+
[#coco_ins_ext_2]_
52+
ar/iou=0.75/area=all/max_dets=100/<label_names[l]>, \
53+
[#coco_ins_ext_2]_
54+
ar/iou=0.50:0.95/area=small/max_dets=100/<label_names[l]>, \
55+
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
56+
ar/iou=0.50:0.95/area=medium/max_dets=100/<label_names[l]>, \
57+
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
58+
ar/iou=0.50:0.95/area=large/max_dets=100/<label_names[l]>, \
59+
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
60+
map/iou=0.50:0.95/area=all/max_dets=100, \
61+
[#coco_ins_ext_3]_
62+
map/iou=0.50/area=all/max_dets=100, \
63+
[#coco_ins_ext_3]_
64+
map/iou=0.75/area=all/max_dets=100, \
65+
[#coco_ins_ext_3]_
66+
map/iou=0.50:0.95/area=small/max_dets=100, \
67+
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
68+
map/iou=0.50:0.95/area=medium/max_dets=100, \
69+
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
70+
map/iou=0.50:0.95/area=large/max_dets=100, \
71+
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
72+
ar/iou=0.50:0.95/area=all/max_dets=1, \
73+
[#coco_ins_ext_4]_
74+
ar/iou=0.50/area=all/max_dets=10, \
75+
[#coco_ins_ext_4]_
76+
ar/iou=0.75/area=all/max_dets=100, \
77+
[#coco_ins_ext_4]_
78+
ar/iou=0.50:0.95/area=small/max_dets=100, \
79+
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_
80+
ar/iou=0.50:0.95/area=medium/max_dets=100, \
81+
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_
82+
ar/iou=0.50:0.95/area=large/max_dets=100, \
83+
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_
84+
85+
.. [#coco_ins_ext_1] Average precision for class \
86+
:obj:`label_names[l]`, where :math:`l` is the index of the class. \
87+
If class :math:`l` does not exist in either :obj:`pred_labels` or \
88+
:obj:`gt_labels`, the corresponding value is set to :obj:`numpy.nan`.
89+
.. [#coco_ins_ext_2] Average recall for class \
90+
:obj:`label_names[l]`, where :math:`l` is the index of the class. \
91+
If class :math:`l` does not exist in either :obj:`pred_labels` or \
92+
:obj:`gt_labels`, the corresponding value is set to :obj:`numpy.nan`.
93+
.. [#coco_ins_ext_3] The average of average precisions over classes.
94+
.. [#coco_ins_ext_4] The average of average recalls over classes.
95+
.. [#coco_ins_ext_5] Skip if :obj:`gt_areas` is :obj:`None`.
96+
97+
Args:
98+
iterator (chainer.Iterator): An iterator. Each sample should be
99+
following tuple :obj:`img, mask, label, area, crowded`.
100+
target (chainer.Link): A detection link. This link must have
101+
:meth:`predict` method that takes a list of images and returns
102+
:obj:`masks`, :obj:`labels` and :obj:`scores`.
103+
label_names (iterable of strings): An iterable of names of classes.
104+
If this value is specified, average precision and average
105+
recalls for each class are reported.
106+
107+
"""
108+
109+
trigger = 1, 'epoch'
110+
default_name = 'validation'
111+
priority = chainer.training.PRIORITY_WRITER
112+
113+
def __init__(
114+
self, iterator, target,
115+
label_names=None):
116+
if not _available:
117+
raise ValueError(
118+
'Please install pycocotools \n'
119+
'pip install -e \'git+https://github.com/pdollar/coco.git'
120+
'#egg=pycocotools&subdirectory=PythonAPI\'')
121+
super(InstanceSegmentationCOCOEvaluator, self).__init__(
122+
iterator, target)
123+
self.label_names = label_names
124+
125+
def evaluate(self):
126+
iterator = self._iterators['main']
127+
target = self._targets['main']
128+
129+
if hasattr(iterator, 'reset'):
130+
iterator.reset()
131+
it = iterator
132+
else:
133+
it = copy.copy(iterator)
134+
135+
in_values, out_values, rest_values = apply_to_iterator(
136+
target.predict, it)
137+
# delete unused iterators explicitly
138+
del in_values
139+
140+
pred_masks, pred_labels, pred_scores = out_values
141+
142+
if len(rest_values) == 2:
143+
gt_masks, gt_labels = rest_values
144+
gt_areas = None
145+
gt_crowdeds = None
146+
elif len(rest_values) == 4:
147+
gt_masks, gt_labels, gt_areas, gt_crowdeds =\
148+
rest_values
149+
else:
150+
raise ValueError('the dataset should return '
151+
'sets of (img, mask, label) or sets of '
152+
'(img, mask, label, area, crowded).')
153+
154+
result = eval_instance_segmentation_coco(
155+
pred_masks, pred_labels, pred_scores,
156+
gt_masks, gt_labels, gt_areas, gt_crowdeds)
157+
158+
report = {}
159+
for key in result.keys():
160+
if key.startswith('map') or key.startswith('mar'):
161+
report[key] = result[key]
162+
163+
if self.label_names is not None:
164+
for key in result.keys():
165+
if key.startswith('ap') or key.startswith('ar'):
166+
for l, label_name in enumerate(self.label_names):
167+
report_key = '{}/{:s}'.format(key, label_name)
168+
try:
169+
report[report_key] = result[key][l]
170+
except IndexError:
171+
report[report_key] = np.nan
172+
173+
observation = {}
174+
with reporter.report_scope(observation):
175+
reporter.report(report, target)
176+
return observation

docs/source/reference/extensions.rst

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ DetectionVOCEvaluator
1111
~~~~~~~~~~~~~~~~~~~~~
1212
.. autoclass:: DetectionVOCEvaluator
1313

14+
InstanceSegmentationCOCOEvaluator
15+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16+
.. autoclass:: InstanceSegmentationCOCOEvaluator
17+
1418
InstanceSegmentationVOCEvaluator
1519
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1620
.. autoclass:: InstanceSegmentationVOCEvaluator
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
import unittest
3+
4+
import chainer
5+
from chainer.datasets import TupleDataset
6+
from chainer.iterators import SerialIterator
7+
from chainer import testing
8+
9+
from chainercv.extensions import InstanceSegmentationCOCOEvaluator
10+
11+
try:
12+
import pycocotools.coco # NOQA
13+
_available = True
14+
except ImportError:
15+
_available = False
16+
17+
18+
class _InstanceSegmentationStubLink(chainer.Link):
19+
20+
def __init__(self, masks, labels):
21+
super(_InstanceSegmentationStubLink, self).__init__()
22+
self.count = 0
23+
self.masks = masks
24+
self.labels = labels
25+
26+
def predict(self, imgs):
27+
n_img = len(imgs)
28+
masks = self.masks[self.count:self.count + n_img]
29+
labels = self.labels[self.count:self.count + n_img]
30+
scores = [np.ones_like(l) for l in labels]
31+
32+
self.count += n_img
33+
34+
return masks, labels, scores
35+
36+
37+
@unittest.skipUnless(_available, 'pycocotools is not installed')
38+
class TestInstanceSegmentationCOCOEvaluator(unittest.TestCase):
39+
40+
def setUp(self):
41+
masks = np.random.uniform(size=(10, 5, 32, 48)) > 0.5
42+
labels = np.ones((10, 5), dtype=np.int32)
43+
self.dataset = TupleDataset(
44+
np.random.uniform(size=(10, 3, 32, 48)),
45+
masks, labels)
46+
self.link = _InstanceSegmentationStubLink(masks, labels)
47+
self.iterator = SerialIterator(
48+
self.dataset, 1, repeat=False, shuffle=False)
49+
self.evaluator = InstanceSegmentationCOCOEvaluator(
50+
self.iterator, self.link, label_names=('cls0', 'cls1', 'cls2'))
51+
self.expected_ap = 1
52+
53+
def test_evaluate(self):
54+
reporter = chainer.Reporter()
55+
reporter.add_observer('target', self.link)
56+
with reporter:
57+
mean = self.evaluator.evaluate()
58+
59+
# No observation is reported to the current reporter. Instead the
60+
# evaluator collect results in order to calculate their mean.
61+
self.assertEqual(len(reporter.observation), 0)
62+
63+
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
64+
np.testing.assert_equal(
65+
mean['target/m{}'.format(key)], self.expected_ap)
66+
np.testing.assert_equal(mean['target/{}/cls0'.format(key)], np.nan)
67+
np.testing.assert_equal(
68+
mean['target/{}/cls1'.format(key)], self.expected_ap)
69+
np.testing.assert_equal(mean['target/{}/cls2'.format(key)], np.nan)
70+
71+
def test_call(self):
72+
mean = self.evaluator()
73+
# main is used as default
74+
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
75+
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
76+
np.testing.assert_equal(mean['main/m{}'.format(key)], self.expected_ap)
77+
np.testing.assert_equal(mean['main/{}/cls0'.format(key)], np.nan)
78+
np.testing.assert_equal(
79+
mean['main/{}/cls1'.format(key)], self.expected_ap)
80+
np.testing.assert_equal(mean['main/{}/cls2'.format(key)], np.nan)
81+
82+
def test_evaluator_name(self):
83+
self.evaluator.name = 'eval'
84+
mean = self.evaluator()
85+
# name is used as a prefix
86+
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
87+
np.testing.assert_equal(
88+
mean['eval/main/m{}'.format(key)], self.expected_ap)
89+
np.testing.assert_equal(mean['eval/main/{}/cls0'.format(key)], np.nan)
90+
np.testing.assert_equal(
91+
mean['eval/main/{}/cls1'.format(key)], self.expected_ap)
92+
np.testing.assert_equal(mean['eval/main/{}/cls2'.format(key)], np.nan)
93+
94+
def test_current_report(self):
95+
reporter = chainer.Reporter()
96+
with reporter:
97+
mean = self.evaluator()
98+
# The result is reported to the current reporter.
99+
self.assertEqual(reporter.observation, mean)
100+
101+
102+
testing.run_module(__name__, __file__)

tests/extensions_tests/evaluator_tests/test_instance_segmentation_voc_evaluator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class TestInstanceSegmentationVOCEvaluator(unittest.TestCase):
3232

3333
def setUp(self):
3434
masks = np.random.uniform(size=(10, 5, 32, 48)) > 0.5
35-
labels = np.ones((10, 5))
35+
labels = np.ones((10, 5), dtype=np.int32)
3636
self.dataset = TupleDataset(
3737
np.random.uniform(size=(10, 3, 32, 48)),
3838
masks, labels)

0 commit comments

Comments
 (0)