Add SemanticSegmentationEvaluator #238
Changes from 17 commits
@@ -1,2 +1,3 @@
-from chainercv.extensions.detection.detection_vis_report import DetectionVisReport  # NOQA
-from chainercv.extensions.detection.detection_voc_evaluator import DetectionVOCEvaluator  # NOQA
+from chainercv.extensions.evaluator.detection_voc_evaluator import DetectionVOCEvaluator  # NOQA
+from chainercv.extensions.evaluator.semantic_segmentation_evaluator import SemanticSegmentationEvaluator  # NOQA
+from chainercv.extensions.vis_report.detection_vis_report import DetectionVisReport  # NOQA
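After this reshuffle, both evaluators remain importable from the top-level chainercv.extensions namespace. A quick sanity-check sketch (assuming this branch is installed; the print line is only illustrative):

from chainercv.extensions import DetectionVOCEvaluator
from chainercv.extensions import SemanticSegmentationEvaluator

# default_name is 'validation'; it is used as the report prefix when the
# extension is attached to a trainer.
print(SemanticSegmentationEvaluator.default_name)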
@@ -0,0 +1,109 @@
import copy
import numpy as np

from chainer import reporter
import chainer.training.extensions

from chainercv.evaluations import eval_semantic_segmentation
from chainercv.utils import apply_prediction_to_iterator


class SemanticSegmentationEvaluator(chainer.training.extensions.Evaluator):

"""An extension that evaluates a semantic segmentation model. | ||
|
||
This extension iterates over an iterator and evaluates the prediction | ||
results of the model by common evaluation metrics for semantic | ||
segmentation. | ||
This extension reports values with keys below. | ||
Please note that :obj:`'iou/<label_names[l]>'` and | ||
:obj:`'class_accuracy/<label_names[l]>'` are reported only if | ||
:obj:`label_names` is specified. | ||
|
||
* :obj:`'miou'`: Mean of IoUs (mIoU). | ||
* :obj:`'iou/<label_names[l]>'`: IoU for class \ | ||
:obj:`label_names[l]`, where :math:`l` is the index of the class. \ | ||
For example, if :obj:`label_names` is \ | ||
:obj:`~chainercv.datasets.camvid_label_names`, \ | ||
this evaluator reports :obj:`'iou/Sky'`, \ | ||
:obj:`'ap/Building'`, etc. | ||
* :obj:`'mean_class_accuracy'`: Mean of class accuracies. | ||
* :obj:`'class_accuracy/<label_names[l]>'`: Class accuracy for class \ | ||
:obj:`label_names[l]`, where :math:`l` is the index of the class. | ||
* :obj:`'pixel_accuracy'`: Pixel accuracy. | ||
|
||
If there is no label assigned to class :obj:`label_names[l]` | ||
in the ground truth, values corresponding to keys | ||
:obj:`'iou/<label_names[l]>'` and :obj:`'class_accuracy/<label_names[l]>'` | ||
are :obj:`numpy.nan`. | ||
In that case, the means of them are calculated by excluding them from | ||
calculation. | ||
|
||
For details on the evaluation metrics, please see the documentation | ||
for :func:`chainercv.evaluations.eval_semantic_segmentation`. | ||
|
||
.. seealso:: | ||
:func:`chainercv.evaluations.eval_semantic_segmentation`. | ||
|
||
Args: | ||
iterator (chainer.Iterator): An iterator. Each sample should be | ||
following tuple :obj:`img, label`. | ||
:obj:`img` is an image, :obj:`label` is pixel-wise label. | ||
target (chainer.Link): A semantic segmentation link. This link should | ||
have :meth:`predict` method which takes a list of images and | ||
returns :obj:`labels`. | ||
label_names (iterable of strings): An iterable of names of classes. | ||
If this value is specified, IoU and class accuracy for each class | ||
is also reported with the keys | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
:obj:`'iou/<label_names[l]>'` and | ||
:obj:`'class_accuracy/<label_names[l]>'`. | ||
|
||
""" | ||
|
||
    trigger = 1, 'epoch'
    default_name = 'validation'
    priority = chainer.training.PRIORITY_WRITER

    def __init__(self, iterator, target, label_names=None):
        super(SemanticSegmentationEvaluator, self).__init__(
            iterator, target)
        self.label_names = label_names

    def evaluate(self):
        iterator = self._iterators['main']
        target = self._targets['main']

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        imgs, pred_values, gt_values = apply_prediction_to_iterator(
            target.predict, it)
        # delete unused iterator explicitly
        del imgs

        pred_labels, = pred_values
        gt_labels, = gt_values

        result = eval_semantic_segmentation(pred_labels, gt_labels)

        report = {'miou': result['miou'],
                  'pixel_accuracy': result['pixel_accuracy'],
                  'mean_class_accuracy': result['mean_class_accuracy']}

        if self.label_names is not None:
            for l, label_name in enumerate(self.label_names):
                try:
                    report['iou/{:s}'.format(label_name)] = result['iou'][l]
                    report['class_accuracy/{:s}'.format(label_name)] =\
                        result['class_accuracy'][l]
                except IndexError:
                    report['iou/{:s}'.format(label_name)] = np.nan
                    report['class_accuracy/{:s}'.format(label_name)] = np.nan

        observation = {}
        with reporter.report_scope(observation):
            reporter.report(report, target)
        return observation
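Since the docstring above describes the reported keys only abstractly, here is a small, self-contained sketch of how the new extension can be exercised on its own. DummySegmentationModel, the random images, and the single 'background' class are illustrative stand-ins and not part of this PR; only SemanticSegmentationEvaluator, its (iterator, target, label_names) signature, and the reported keys come from the diff above.

# Hedged sketch (not part of the PR): exercising the new evaluator directly.
import numpy as np

import chainer
from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator

from chainercv.extensions import SemanticSegmentationEvaluator


class DummySegmentationModel(chainer.Link):
    """Stand-in for a real segmentation link exposing a predict method."""

    def predict(self, imgs):
        # Predict class 0 for every pixel of every image.
        return [np.zeros(img.shape[1:], dtype=np.int32) for img in imgs]


imgs = np.random.uniform(size=(4, 3, 5, 5)).astype(np.float32)
labels = np.zeros((4, 5, 5), dtype=np.int32)
it = SerialIterator(
    TupleDataset(imgs, labels), 2, repeat=False, shuffle=False)

evaluator = SemanticSegmentationEvaluator(
    it, DummySegmentationModel(), label_names=('background',))
result = evaluator()
print(result['main/miou'], result['main/iou/background'])

In a training script, the same object would normally be registered with trainer.extend(evaluator), so that it runs at the (1, 'epoch') trigger defined above and its keys appear under the 'validation' prefix.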
@@ -4,13 +4,21 @@ Extensions
.. module:: chainercv.extensions


-Detection
+Evaluator
---------

-DetectionVisReport
-~~~~~~~~~~~~~~~~~~
-.. autofunction:: DetectionVisReport
-
DetectionVOCEvaluator
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: DetectionVOCEvaluator

Review comment: I think it is a class (not a function)
Review comment: We've been using
Review comment: They should be fixed too.

+SemanticSegmentationEvaluator
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: SemanticSegmentationEvaluator
+
+
+Visualization Report
+--------------------
+
+DetectionVisReport
+~~~~~~~~~~~~~~~~~~
+.. autofunction:: DetectionVisReport
@@ -0,0 +1,97 @@
import numpy as np
import unittest

import chainer
from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator
from chainer import testing

from chainercv.extensions import SemanticSegmentationEvaluator


class _SemanticSegmentationStubLink(chainer.Link):

    def __init__(self, labels):
        super(_SemanticSegmentationStubLink, self).__init__()
        self.count = 0
        self.labels = labels

    def predict(self, imgs):
        n_img = len(imgs)
        labels = self.labels[self.count:self.count + n_img]

        self.count += n_img
        return labels


class TestSemanticSegmentationEvaluator(unittest.TestCase):

    def setUp(self):
        self.label_names = ('a', 'b', 'c')
        imgs = np.random.uniform(size=(10, 3, 5, 5))
        # There are labels for 'a' and 'b', but none for 'c'.
        labels = np.random.randint(
            low=0, high=2, size=(10, 5, 5), dtype=np.int32)
        self.dataset = TupleDataset(imgs, labels)
        self.link = _SemanticSegmentationStubLink(labels)
        self.iterator = SerialIterator(
            self.dataset, 5, repeat=False, shuffle=False)
        self.evaluator = SemanticSegmentationEvaluator(
            self.iterator, self.link, self.label_names)

    def test_evaluate(self):
        reporter = chainer.Reporter()
        reporter.add_observer('target', self.link)
        with reporter:
            eval_ = self.evaluator.evaluate()

        # No observation is reported to the current reporter. Instead the
        # evaluator collects results in order to calculate their mean.
        np.testing.assert_equal(len(reporter.observation), 0)

        np.testing.assert_equal(eval_['target/miou'], 1.)
        np.testing.assert_equal(eval_['target/pixel_accuracy'], 1.)
        np.testing.assert_equal(eval_['target/mean_class_accuracy'], 1.)
        np.testing.assert_equal(eval_['target/iou/a'], 1.)
        np.testing.assert_equal(eval_['target/iou/b'], 1.)

Review comment: Could you make

        np.testing.assert_equal(eval_['target/iou/c'], np.nan)
        np.testing.assert_equal(eval_['target/class_accuracy/a'], 1.)
        np.testing.assert_equal(eval_['target/class_accuracy/b'], 1.)
        np.testing.assert_equal(eval_['target/class_accuracy/c'], np.nan)

    def test_call(self):
        eval_ = self.evaluator()
        # main is used as default
        np.testing.assert_equal(eval_['main/miou'], 1.)
        np.testing.assert_equal(eval_['main/pixel_accuracy'], 1.)
        np.testing.assert_equal(eval_['main/mean_class_accuracy'], 1.)
        np.testing.assert_equal(eval_['main/iou/a'], 1.)
        np.testing.assert_equal(eval_['main/iou/b'], 1.)
        np.testing.assert_equal(eval_['main/iou/c'], np.nan)
        np.testing.assert_equal(eval_['main/class_accuracy/a'], 1.)
        np.testing.assert_equal(eval_['main/class_accuracy/b'], 1.)
        np.testing.assert_equal(eval_['main/class_accuracy/c'], np.nan)

    def test_evaluator_name(self):
        self.evaluator.name = 'eval'
        eval_ = self.evaluator()
        # name is used as a prefix
        np.testing.assert_equal(eval_['eval/main/miou'], 1.)
        np.testing.assert_equal(eval_['eval/main/pixel_accuracy'], 1.)
        np.testing.assert_equal(eval_['eval/main/mean_class_accuracy'], 1.)
        np.testing.assert_equal(eval_['eval/main/iou/a'], 1.)
        np.testing.assert_equal(eval_['eval/main/iou/b'], 1.)
        np.testing.assert_equal(eval_['eval/main/iou/c'], np.nan)
        np.testing.assert_equal(eval_['eval/main/class_accuracy/a'], 1.)
        np.testing.assert_equal(eval_['eval/main/class_accuracy/b'], 1.)
        np.testing.assert_equal(eval_['eval/main/class_accuracy/c'], np.nan)

    def test_current_report(self):
        reporter = chainer.Reporter()
        with reporter:
            eval_ = self.evaluator()
        # The result is reported to the current reporter.
        np.testing.assert_equal(reporter.observation, eval_)


testing.run_module(__name__, __file__)
Review comment: which -> that? #229 (comment)