
add InstanceSegmentationCOCOEvaluator #674

Merged (5 commits, Jul 10, 2018)
1 change: 1 addition & 0 deletions chainercv/extensions/__init__.py
@@ -1,4 +1,5 @@
from chainercv.extensions.evaluator.detection_voc_evaluator import DetectionVOCEvaluator # NOQA
from chainercv.extensions.evaluator.instance_segmentation_coco_evaluator import InstanceSegmentationCOCOEvaluator # NOQA
from chainercv.extensions.evaluator.instance_segmentation_voc_evaluator import InstanceSegmentationVOCEvaluator # NOQA
from chainercv.extensions.evaluator.semantic_segmentation_evaluator import SemanticSegmentationEvaluator # NOQA
from chainercv.extensions.vis_report.detection_vis_report import DetectionVisReport # NOQA
176 changes: 176 additions & 0 deletions chainercv/extensions/evaluator/instance_segmentation_coco_evaluator.py
@@ -0,0 +1,176 @@
import copy
import numpy as np

from chainer import reporter
import chainer.training.extensions

from chainercv.evaluations import eval_instance_segmentation_coco
from chainercv.utils import apply_to_iterator

try:
import pycocotools.coco # NOQA
_available = True
except ImportError:
_available = False


class InstanceSegmentationCOCOEvaluator(chainer.training.extensions.Evaluator):

"""An extension that evaluates a instance segmentation model by MS COCO metric.

This extension iterates over an iterator and evaluates the prediction
results.
The results consist of average precisions (APs) and average
recalls (ARs) as well as the mean of each (mean average precision and mean
average recall).
This extension reports the values listed below with the corresponding keys.
Note that if :obj:`label_names` is not specified,
only the mAPs and mARs are reported.

The underlying dataset of the iterator is assumed to return
:obj:`img, mask, label` or :obj:`img, mask, label, area, crowded`.

.. csv-table::
:header: key, description

ap/iou=0.50:0.95/area=all/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_
ap/iou=0.50/area=all/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_
ap/iou=0.75/area=all/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_
ap/iou=0.50:0.95/area=small/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
ap/iou=0.50:0.95/area=medium/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
ap/iou=0.50:0.95/area=large/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_1]_ [#coco_ins_ext_5]_
ar/iou=0.50:0.95/area=all/max_dets=1/<label_names[l]>, \
[#coco_ins_ext_2]_
ar/iou=0.50:0.95/area=all/max_dets=10/<label_names[l]>, \
[#coco_ins_ext_2]_
ar/iou=0.50:0.95/area=all/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_2]_
ar/iou=0.50:0.95/area=small/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
ar/iou=0.50:0.95/area=medium/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
ar/iou=0.50:0.95/area=large/max_dets=100/<label_names[l]>, \
[#coco_ins_ext_2]_ [#coco_ins_ext_5]_
map/iou=0.50:0.95/area=all/max_dets=100, \
[#coco_ins_ext_3]_
map/iou=0.50/area=all/max_dets=100, \
[#coco_ins_ext_3]_
map/iou=0.75/area=all/max_dets=100, \
[#coco_ins_ext_3]_
map/iou=0.50:0.95/area=small/max_dets=100, \
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
map/iou=0.50:0.95/area=medium/max_dets=100, \
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
map/iou=0.50:0.95/area=large/max_dets=100, \
[#coco_ins_ext_3]_ [#coco_ins_ext_5]_
mar/iou=0.50:0.95/area=all/max_dets=1, \
[#coco_ins_ext_4]_
mar/iou=0.50:0.95/area=all/max_dets=10, \
[#coco_ins_ext_4]_
mar/iou=0.50:0.95/area=all/max_dets=100, \
[#coco_ins_ext_4]_
mar/iou=0.50:0.95/area=small/max_dets=100, \
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_
mar/iou=0.50:0.95/area=medium/max_dets=100, \
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_
mar/iou=0.50:0.95/area=large/max_dets=100, \
[#coco_ins_ext_4]_ [#coco_ins_ext_5]_

.. [#coco_ins_ext_1] Average precision for class \
:obj:`label_names[l]`, where :math:`l` is the index of the class. \
If class :math:`l` does not exist in either :obj:`pred_labels` or \
:obj:`gt_labels`, the corresponding value is set to :obj:`numpy.nan`.
.. [#coco_ins_ext_2] Average recall for class \
:obj:`label_names[l]`, where :math:`l` is the index of the class. \
If class :math:`l` does not exist in either :obj:`pred_labels` or \
:obj:`gt_labels`, the corresponding value is set to :obj:`numpy.nan`.
.. [#coco_ins_ext_3] The average of average precisions over classes.
.. [#coco_ins_ext_4] The average of average recalls over classes.
.. [#coco_ins_ext_5] Skip if :obj:`gt_areas` is :obj:`None`.

Args:
iterator (chainer.Iterator): An iterator. Each sample should be a
tuple :obj:`img, mask, label` or
:obj:`img, mask, label, area, crowded`.
target (chainer.Link): An instance segmentation link. This link must
have a :meth:`predict` method that takes a list of images and
returns :obj:`masks`, :obj:`labels` and :obj:`scores`.
label_names (iterable of strings): An iterable of names of classes.
If this value is specified, average precisions and average
recalls for each class are reported.

"""

trigger = 1, 'epoch'
default_name = 'validation'
priority = chainer.training.PRIORITY_WRITER

def __init__(
self, iterator, target,
label_names=None):
if not _available:
raise ValueError(
'Please install pycocotools \n'
'pip install -e \'git+https://github.com/pdollar/coco.git'
'#egg=pycocotools&subdirectory=PythonAPI\'')
super(InstanceSegmentationCOCOEvaluator, self).__init__(
iterator, target)
self.label_names = label_names

def evaluate(self):
iterator = self._iterators['main']
target = self._targets['main']

if hasattr(iterator, 'reset'):
iterator.reset()
it = iterator
else:
it = copy.copy(iterator)

in_values, out_values, rest_values = apply_to_iterator(
target.predict, it)
# delete unused iterators explicitly
del in_values

pred_masks, pred_labels, pred_scores = out_values

if len(rest_values) == 2:
gt_masks, gt_labels = rest_values
gt_areas = None
gt_crowdeds = None
elif len(rest_values) == 4:
gt_masks, gt_labels, gt_areas, gt_crowdeds =\
rest_values
else:
raise ValueError('the dataset should return '
'sets of (img, mask, label) or sets of '
'(img, mask, label, area, crowded).')

result = eval_instance_segmentation_coco(
pred_masks, pred_labels, pred_scores,
gt_masks, gt_labels, gt_areas, gt_crowdeds)

report = {}
for key in result.keys():
if key.startswith('map') or key.startswith('mar'):
report[key] = result[key]

if self.label_names is not None:
for key in result.keys():
if key.startswith('ap') or key.startswith('ar'):
for l, label_name in enumerate(self.label_names):
report_key = '{}/{:s}'.format(key, label_name)
try:
report[report_key] = result[key][l]
except IndexError:
report[report_key] = np.nan

observation = {}
with reporter.report_scope(observation):
reporter.report(report, target)
return observation
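
A minimal usage sketch (not part of this diff; `val_dataset`, `model` and `trainer` are hypothetical placeholders): attaching the new evaluator to a training loop so that the COCO metrics are computed and logged once per epoch.

from chainer.iterators import SerialIterator

from chainercv.extensions import InstanceSegmentationCOCOEvaluator

# Placeholders, not part of this PR: `val_dataset` yields (img, mask, label)
# tuples, `model` is a trained link with the required `predict` method, and
# `trainer` is a chainer.training.Trainer.
val_iter = SerialIterator(val_dataset, batch_size=1, repeat=False, shuffle=False)
trainer.extend(
    InstanceSegmentationCOCOEvaluator(
        val_iter, model, label_names=('person', 'car')),
    trigger=(1, 'epoch'))
# Combined with a LogReport extension, keys such as
# 'validation/main/map/iou=0.50:0.95/area=all/max_dets=100' appear in the log.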
4 changes: 4 additions & 0 deletions docs/source/reference/extensions.rst
@@ -11,6 +11,10 @@ DetectionVOCEvaluator
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: DetectionVOCEvaluator

InstanceSegmentationCOCOEvaluator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: InstanceSegmentationCOCOEvaluator

InstanceSegmentationVOCEvaluator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: InstanceSegmentationVOCEvaluator
102 changes: 102 additions & 0 deletions tests/extensions_tests/evaluator_tests/test_instance_segmentation_coco_evaluator.py
@@ -0,0 +1,102 @@
import numpy as np
import unittest

import chainer
from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator
from chainer import testing

from chainercv.extensions import InstanceSegmentationCOCOEvaluator

try:
import pycocotools.coco # NOQA
_available = True
except ImportError:
_available = False


class _InstanceSegmentationStubLink(chainer.Link):

def __init__(self, masks, labels):
super(_InstanceSegmentationStubLink, self).__init__()
self.count = 0
self.masks = masks
self.labels = labels

def predict(self, imgs):
n_img = len(imgs)
masks = self.masks[self.count:self.count + n_img]
labels = self.labels[self.count:self.count + n_img]
scores = [np.ones_like(l) for l in labels]

self.count += n_img

return masks, labels, scores


@unittest.skipUnless(_available, 'pycocotools is not installed')
class TestInstanceSegmentationCOCOEvaluator(unittest.TestCase):

def setUp(self):
masks = np.random.uniform(size=(10, 5, 32, 48)) > 0.5
labels = np.ones((10, 5), dtype=np.int32)
self.dataset = TupleDataset(
np.random.uniform(size=(10, 3, 32, 48)),
masks, labels)
self.link = _InstanceSegmentationStubLink(masks, labels)
self.iterator = SerialIterator(
self.dataset, 1, repeat=False, shuffle=False)
self.evaluator = InstanceSegmentationCOCOEvaluator(
self.iterator, self.link, label_names=('cls0', 'cls1', 'cls2'))
self.expected_ap = 1

def test_evaluate(self):
reporter = chainer.Reporter()
reporter.add_observer('target', self.link)
with reporter:
mean = self.evaluator.evaluate()

# No observation is reported to the current reporter. Instead the
# evaluator collect results in order to calculate their mean.
self.assertEqual(len(reporter.observation), 0)

key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
np.testing.assert_equal(
mean['target/m{}'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['target/{}/cls0'.format(key)], np.nan)
np.testing.assert_equal(
mean['target/{}/cls1'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['target/{}/cls2'.format(key)], np.nan)

def test_call(self):
mean = self.evaluator()
# main is used as default
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
np.testing.assert_equal(mean['main/m{}'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['main/{}/cls0'.format(key)], np.nan)
np.testing.assert_equal(
mean['main/{}/cls1'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['main/{}/cls2'.format(key)], np.nan)

def test_evaluator_name(self):
self.evaluator.name = 'eval'
mean = self.evaluator()
# name is used as a prefix
key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
np.testing.assert_equal(
mean['eval/main/m{}'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['eval/main/{}/cls0'.format(key)], np.nan)
np.testing.assert_equal(
mean['eval/main/{}/cls1'.format(key)], self.expected_ap)
np.testing.assert_equal(mean['eval/main/{}/cls2'.format(key)], np.nan)

def test_current_report(self):
reporter = chainer.Reporter()
with reporter:
mean = self.evaluator()
# The result is reported to the current reporter.
self.assertEqual(reporter.observation, mean)


testing.run_module(__name__, __file__)
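
One more hedged sketch (not part of this diff; all names are illustrative): the area=small/medium/large keys footnoted in the docstring are only reported when the dataset also supplies ground-truth areas, i.e. when each sample is the five-tuple (img, mask, label, area, crowded).

import numpy as np

from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator

# A sample is (img, mask, label, area, crowded), so evaluate() passes
# gt_areas and gt_crowdeds through to eval_instance_segmentation_coco.
n_img, n_inst, H, W = 4, 2, 32, 48
imgs = np.random.uniform(size=(n_img, 3, H, W)).astype(np.float32)
masks = np.random.uniform(size=(n_img, n_inst, H, W)) > 0.5
labels = np.ones((n_img, n_inst), dtype=np.int32)
areas = masks.reshape(n_img, n_inst, -1).sum(axis=-1).astype(np.float32)
crowdeds = np.zeros((n_img, n_inst), dtype=np.bool_)

dataset = TupleDataset(imgs, masks, labels, areas, crowdeds)
it = SerialIterator(dataset, 1, repeat=False, shuffle=False)
# InstanceSegmentationCOCOEvaluator(it, model) now also reports the
# area=small/medium/large metrics instead of skipping them.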
2 changes: 1 addition & 1 deletion tests/extensions_tests/evaluator_tests/test_instance_segmentation_voc_evaluator.py
@@ -32,7 +32,7 @@ class TestInstanceSegmentationVOCEvaluator(unittest.TestCase):

def setUp(self):
masks = np.random.uniform(size=(10, 5, 32, 48)) > 0.5
- labels = np.ones((10, 5))
+ labels = np.ones((10, 5), dtype=np.int32)
self.dataset = TupleDataset(
np.random.uniform(size=(10, 3, 32, 48)),
masks, labels)