diff --git a/chainercv/extensions/__init__.py b/chainercv/extensions/__init__.py
index 2abdb09a67..5f864ab80b 100644
--- a/chainercv/extensions/__init__.py
+++ b/chainercv/extensions/__init__.py
@@ -1,2 +1,3 @@
-from chainercv.extensions.detection.detection_vis_report import DetectionVisReport  # NOQA
-from chainercv.extensions.detection.detection_voc_evaluator import DetectionVOCEvaluator  # NOQA
+from chainercv.extensions.evaluator.detection_voc_evaluator import DetectionVOCEvaluator  # NOQA
+from chainercv.extensions.evaluator.semantic_segmentation_evaluator import SemanticSegmentationEvaluator  # NOQA
+from chainercv.extensions.vis_report.detection_vis_report import DetectionVisReport  # NOQA
diff --git a/chainercv/extensions/detection/__init__.py b/chainercv/extensions/evaluator/__init__.py
similarity index 100%
rename from chainercv/extensions/detection/__init__.py
rename to chainercv/extensions/evaluator/__init__.py
diff --git a/chainercv/extensions/detection/detection_voc_evaluator.py b/chainercv/extensions/evaluator/detection_voc_evaluator.py
similarity index 98%
rename from chainercv/extensions/detection/detection_voc_evaluator.py
rename to chainercv/extensions/evaluator/detection_voc_evaluator.py
index b2255d6d40..8a7936f7bb 100644
--- a/chainercv/extensions/detection/detection_voc_evaluator.py
+++ b/chainercv/extensions/evaluator/detection_voc_evaluator.py
@@ -40,7 +40,7 @@ class DetectionVOCEvaluator(chainer.training.extensions.Evaluator):
             not. If :obj:`difficult` is returned, difficult ground truth
             will be ignored from evaluation.
         target (chainer.Link): A detection link. This link must have
-            :meth:`predict` method which takes a list of images and returns
+            a :meth:`predict` method that takes a list of images and returns
             :obj:`bboxes`, :obj:`labels` and :obj:`scores`.
         use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric
             for calculating average precision. The default value is
diff --git a/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py b/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py
new file mode 100644
index 0000000000..5b39dfa26d
--- /dev/null
+++ b/chainercv/extensions/evaluator/semantic_segmentation_evaluator.py
@@ -0,0 +1,109 @@
+import copy
+import numpy as np
+
+from chainer import reporter
+import chainer.training.extensions
+
+from chainercv.evaluations import eval_semantic_segmentation
+from chainercv.utils import apply_prediction_to_iterator
+
+
+class SemanticSegmentationEvaluator(chainer.training.extensions.Evaluator):
+
+    """An extension that evaluates a semantic segmentation model.
+
+    This extension iterates over an iterator and evaluates the prediction
+    results of the model by common evaluation metrics for semantic
+    segmentation.
+    This extension reports values with the following keys.
+    Please note that :obj:`'iou/<label_names[l]>'` and
+    :obj:`'class_accuracy/<label_names[l]>'` are reported only if
+    :obj:`label_names` is specified.
+
+    * :obj:`'miou'`: Mean of IoUs (mIoU).
+    * :obj:`'iou/<label_names[l]>'`: IoU for class \
+        :obj:`label_names[l]`, where :math:`l` is the index of the class. \
+        For example, if :obj:`label_names` is \
+        :obj:`~chainercv.datasets.camvid_label_names`, \
+        this evaluator reports :obj:`'iou/Sky'`, \
+        :obj:`'iou/Building'`, etc.
+    * :obj:`'mean_class_accuracy'`: Mean of class accuracies.
+    * :obj:`'class_accuracy/<label_names[l]>'`: Class accuracy for class \
+        :obj:`label_names[l]`, where :math:`l` is the index of the class.
+    * :obj:`'pixel_accuracy'`: Pixel accuracy.
+
+    If there is no label assigned to class :obj:`label_names[l]`
+    in the ground truth, values corresponding to keys
+    :obj:`'iou/<label_names[l]>'` and :obj:`'class_accuracy/<label_names[l]>'`
+    are :obj:`numpy.nan`.
+    In that case, those classes are excluded from the calculation of the
+    means.
+
+    For details on the evaluation metrics, please see the documentation
+    for :func:`chainercv.evaluations.eval_semantic_segmentation`.
+
+    .. seealso::
+        :func:`chainercv.evaluations.eval_semantic_segmentation`.
+
+    Args:
+        iterator (chainer.Iterator): An iterator. Each sample should be
+            a tuple :obj:`img, label`.
+            :obj:`img` is an image, and :obj:`label` is a pixel-wise label.
+        target (chainer.Link): A semantic segmentation link. This link should
+            have a :meth:`predict` method that takes a list of images and
+            returns :obj:`labels`.
+        label_names (iterable of strings): An iterable of names of classes.
+            If this value is specified, IoU and class accuracy for each class
+            are also reported with the keys
+            :obj:`'iou/<label_names[l]>'` and
+            :obj:`'class_accuracy/<label_names[l]>'`.
+
+    """
+
+    trigger = 1, 'epoch'
+    default_name = 'validation'
+    priority = chainer.training.PRIORITY_WRITER
+
+    def __init__(self, iterator, target, label_names=None):
+        super(SemanticSegmentationEvaluator, self).__init__(
+            iterator, target)
+        self.label_names = label_names
+
+    def evaluate(self):
+        iterator = self._iterators['main']
+        target = self._targets['main']
+
+        if hasattr(iterator, 'reset'):
+            iterator.reset()
+            it = iterator
+        else:
+            it = copy.copy(iterator)
+
+        imgs, pred_values, gt_values = apply_prediction_to_iterator(
+            target.predict, it)
+        # delete unused iterator explicitly
+        del imgs
+
+        pred_labels, = pred_values
+        gt_labels, = gt_values
+
+        result = eval_semantic_segmentation(pred_labels, gt_labels)
+
+        report = {'miou': result['miou'],
+                  'pixel_accuracy': result['pixel_accuracy'],
+                  'mean_class_accuracy': result['mean_class_accuracy']}
+
+        if self.label_names is not None:
+            for l, label_name in enumerate(self.label_names):
+                try:
+                    report['iou/{:s}'.format(label_name)] = result['iou'][l]
+                    report['class_accuracy/{:s}'.format(label_name)] = \
+                        result['class_accuracy'][l]
+                except IndexError:
+                    report['iou/{:s}'.format(label_name)] = np.nan
+                    report['class_accuracy/{:s}'.format(label_name)] = np.nan
+
+        observation = {}
+        with reporter.report_scope(observation):
+            reporter.report(report, target)
+        return observation
diff --git a/chainercv/extensions/vis_report/__init__.py b/chainercv/extensions/vis_report/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/chainercv/extensions/detection/detection_vis_report.py b/chainercv/extensions/vis_report/detection_vis_report.py
similarity index 100%
rename from chainercv/extensions/detection/detection_vis_report.py
rename to chainercv/extensions/vis_report/detection_vis_report.py
diff --git a/docs/source/reference/extensions.rst b/docs/source/reference/extensions.rst
index 2189eb1842..066f0fcac1 100644
--- a/docs/source/reference/extensions.rst
+++ b/docs/source/reference/extensions.rst
@@ -4,13 +4,21 @@ Extensions
 
 .. module:: chainercv.extensions
 
-Detection
+Evaluator
 ---------
 
-DetectionVisReport
-~~~~~~~~~~~~~~~~~~
-.. autofunction:: DetectionVisReport
-
 DetectionVOCEvaluator
 ~~~~~~~~~~~~~~~~~~~~~
-.. autofunction:: DetectionVOCEvaluator
+.. autoclass:: DetectionVOCEvaluator
+
+SemanticSegmentationEvaluator
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: SemanticSegmentationEvaluator
+
+
+Visualization Report
+--------------------
+
+DetectionVisReport
+~~~~~~~~~~~~~~~~~~
+.. autoclass:: DetectionVisReport
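
Usage sketch (not part of the patch): the new SemanticSegmentationEvaluator is
registered like any other Chainer trainer extension, as the segnet example in
the next file does. A minimal sketch, assuming a built `trainer`, a validation
iterator `val_iter` over (img, label) pairs, and a segmentation link
`model.predictor` with a `predict` method:

    from chainercv.datasets import camvid_label_names
    from chainercv.extensions import SemanticSegmentationEvaluator

    # Under the default 'validation' prefix, this reports keys such as
    # 'validation/main/miou' and 'validation/main/iou/Sky'.
    trainer.extend(
        SemanticSegmentationEvaluator(
            val_iter, model.predictor, camvid_label_names),
        trigger=(1, 'epoch'))
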
diff --git a/examples/segnet/train.py b/examples/segnet/train.py
index 8b21278d7c..249e473641 100644
--- a/examples/segnet/train.py
+++ b/examples/segnet/train.py
@@ -14,22 +14,14 @@
 from chainer import training
 from chainer.training import extensions
 
+from chainercv.datasets import camvid_label_names
 from chainercv.datasets import CamVidDataset
 from chainercv.datasets import TransformDataset
+from chainercv.extensions import SemanticSegmentationEvaluator
 from chainercv.links import PixelwiseSoftmaxClassifier
 from chainercv.links import SegNetBasic
 
 
-class TestModeEvaluator(extensions.Evaluator):
-
-    def evaluate(self):
-        model = self.get_target('main')
-        model.train = False
-        ret = super(TestModeEvaluator, self).evaluate()
-        model.train = True
-        return ret
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--gpu', type=int, default=-1)
@@ -84,24 +76,32 @@ def transform(in_data):
     trainer.extend(extensions.LogReport(trigger=log_trigger))
     trainer.extend(extensions.observe_lr(), trigger=log_trigger)
     trainer.extend(extensions.dump_graph('main/loss'))
-    trainer.extend(TestModeEvaluator(val_iter, model,
-                                     device=args.gpu),
-                   trigger=validation_trigger)
 
     if extensions.PlotReport.available():
         trainer.extend(extensions.PlotReport(
-            ['main/loss', 'validation/main/loss'], x_key='iteration',
+            ['main/loss'], x_key='iteration',
             file_name='loss.png'))
+        trainer.extend(extensions.PlotReport(
+            ['validation/main/miou'], x_key='iteration',
+            file_name='miou.png'))
 
     trainer.extend(extensions.snapshot_object(
         model.predictor, filename='model_iteration-{.updater.iteration}'),
         trigger=end_trigger)
     trainer.extend(extensions.PrintReport(
         ['epoch', 'iteration', 'elapsed_time', 'lr',
-         'main/loss', 'validation/main/loss']),
+         'main/loss', 'validation/main/miou',
+         'validation/main/mean_class_accuracy',
+         'validation/main/pixel_accuracy']),
         trigger=log_trigger)
     trainer.extend(extensions.ProgressBar(update_interval=10))
 
+    trainer.extend(
+        SemanticSegmentationEvaluator(
+            val_iter, model.predictor,
+            camvid_label_names),
+        trigger=validation_trigger)
+
     trainer.run()
 
 
diff --git a/tests/extensions_tests/detection_tests/test_detection_voc_evaluator.py b/tests/extensions_tests/evaluator_tests/test_detection_voc_evaluator.py
similarity index 100%
rename from tests/extensions_tests/detection_tests/test_detection_voc_evaluator.py
rename to tests/extensions_tests/evaluator_tests/test_detection_voc_evaluator.py
diff --git a/tests/extensions_tests/evaluator_tests/test_semantic_segmentation_evaluator.py b/tests/extensions_tests/evaluator_tests/test_semantic_segmentation_evaluator.py
new file mode 100644
index 0000000000..b038e43310
--- /dev/null
+++ b/tests/extensions_tests/evaluator_tests/test_semantic_segmentation_evaluator.py
@@ -0,0 +1,121 @@
+from __future__ import division
+
+import numpy as np
+import unittest
+
+import chainer
+from chainer.datasets import TupleDataset
+from chainer.iterators import SerialIterator
+from chainer import testing
+
+from chainercv.extensions import SemanticSegmentationEvaluator
+
+
+class _SemanticSegmentationStubLink(chainer.Link):
+
+    def __init__(self, labels):
+        super(_SemanticSegmentationStubLink, self).__init__()
+        self.count = 0
+        self.labels = labels
+
+    def predict(self, imgs):
+        n_img = len(imgs)
+        labels = self.labels[self.count:self.count + n_img]
+
+        self.count += n_img
+        return labels
+
+
+class TestSemanticSegmentationEvaluator(unittest.TestCase):
+
+    def setUp(self):
+        self.label_names = ('a', 'b', 'c')
+        imgs = np.random.uniform(size=(1, 3, 2, 3))
+        # There are labels for 'a' and 'b', but none for 'c'.
+        pred_labels = np.array([[[1, 1, 1], [0, 0, 1]]])
+        gt_labels = np.array([[[1, 0, 0], [0, -1, 1]]])
+
+        self.iou_a = 1 / 3
+        self.iou_b = 2 / 4
+        self.pixel_accuracy = 3 / 5
+        self.class_accuracy_a = 1 / 3
+        self.class_accuracy_b = 2 / 2
+        self.miou = np.mean((self.iou_a, self.iou_b))
+        self.mean_class_accuracy = np.mean(
+            (self.class_accuracy_a, self.class_accuracy_b))
+
+        self.dataset = TupleDataset(imgs, gt_labels)
+        self.link = _SemanticSegmentationStubLink(pred_labels)
+        self.iterator = SerialIterator(
+            self.dataset, 5, repeat=False, shuffle=False)
+        self.evaluator = SemanticSegmentationEvaluator(
+            self.iterator, self.link, self.label_names)
+
+    def test_evaluate(self):
+        reporter = chainer.Reporter()
+        reporter.add_observer('main', self.link)
+        with reporter:
+            eval_ = self.evaluator.evaluate()
+
+        # No observation is reported to the current reporter. Instead the
+        # evaluator collects results in order to calculate their mean.
+        np.testing.assert_equal(len(reporter.observation), 0)
+
+        np.testing.assert_equal(eval_['main/miou'], self.miou)
+        np.testing.assert_equal(eval_['main/pixel_accuracy'],
+                                self.pixel_accuracy)
+        np.testing.assert_equal(eval_['main/mean_class_accuracy'],
+                                self.mean_class_accuracy)
+        np.testing.assert_equal(eval_['main/iou/a'], self.iou_a)
+        np.testing.assert_equal(eval_['main/iou/b'], self.iou_b)
+        np.testing.assert_equal(eval_['main/iou/c'], np.nan)
+        np.testing.assert_equal(eval_['main/class_accuracy/a'],
+                                self.class_accuracy_a)
+        np.testing.assert_equal(eval_['main/class_accuracy/b'],
+                                self.class_accuracy_b)
+        np.testing.assert_equal(eval_['main/class_accuracy/c'], np.nan)
+
+    def test_call(self):
+        eval_ = self.evaluator()
+        # 'main' is used as the default name
+        np.testing.assert_equal(eval_['main/miou'], self.miou)
+        np.testing.assert_equal(eval_['main/pixel_accuracy'],
+                                self.pixel_accuracy)
+        np.testing.assert_equal(eval_['main/mean_class_accuracy'],
+                                self.mean_class_accuracy)
+        np.testing.assert_equal(eval_['main/iou/a'], self.iou_a)
+        np.testing.assert_equal(eval_['main/iou/b'], self.iou_b)
+        np.testing.assert_equal(eval_['main/iou/c'], np.nan)
+        np.testing.assert_equal(eval_['main/class_accuracy/a'],
+                                self.class_accuracy_a)
+        np.testing.assert_equal(eval_['main/class_accuracy/b'],
+                                self.class_accuracy_b)
+        np.testing.assert_equal(eval_['main/class_accuracy/c'], np.nan)
+
+    def test_evaluator_name(self):
+        self.evaluator.name = 'eval'
+        eval_ = self.evaluator()
+        # name is used as a prefix
+        np.testing.assert_equal(eval_['eval/main/miou'], self.miou)
+        np.testing.assert_equal(eval_['eval/main/pixel_accuracy'],
+                                self.pixel_accuracy)
+        np.testing.assert_equal(eval_['eval/main/mean_class_accuracy'],
+                                self.mean_class_accuracy)
+        np.testing.assert_equal(eval_['eval/main/iou/a'], self.iou_a)
+        np.testing.assert_equal(eval_['eval/main/iou/b'], self.iou_b)
+        np.testing.assert_equal(eval_['eval/main/iou/c'], np.nan)
+        np.testing.assert_equal(eval_['eval/main/class_accuracy/a'],
+                                self.class_accuracy_a)
+        np.testing.assert_equal(eval_['eval/main/class_accuracy/b'],
+                                self.class_accuracy_b)
+        np.testing.assert_equal(eval_['eval/main/class_accuracy/c'], np.nan)
+
+    def test_current_report(self):
+        reporter = chainer.Reporter()
+        with reporter:
+            eval_ = self.evaluator()
+        # The result is reported to the current reporter.
+        np.testing.assert_equal(reporter.observation, eval_)
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/extensions_tests/detection_tests/test_detection_vis_report.py b/tests/extensions_tests/vis_report_tests/test_detection_vis_report.py
similarity index 100%
rename from tests/extensions_tests/detection_tests/test_detection_vis_report.py
rename to tests/extensions_tests/vis_report_tests/test_detection_vis_report.py
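
Usage sketch (not part of the patch): as the new tests exercise, the evaluator
can also be invoked directly, outside a trainer. A minimal sketch, assuming
`dataset` yields (img, label) tuples and `model` is a segmentation link with a
`predict` method:

    from chainer.iterators import SerialIterator
    from chainercv.extensions import SemanticSegmentationEvaluator

    it = SerialIterator(dataset, 1, repeat=False, shuffle=False)
    evaluator = SemanticSegmentationEvaluator(
        it, model, label_names=('a', 'b', 'c'))
    result = evaluator()  # keys are prefixed with 'main/' by default
    print(result['main/miou'], result['main/pixel_accuracy'])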