Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Change eval_semantic_segmentation #217

Merged
merged 17 commits into from
Jun 2, 2017
4 changes: 3 additions & 1 deletion chainercv/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from chainercv.evaluations.eval_detection_voc import eval_detection_voc # NOQA
from chainercv.evaluations.eval_pck import eval_pck # NOQA
from chainercv.evaluations.eval_semantic_segmentation import eval_semantic_segmentation # NOQA
from chainercv.evaluations.eval_semantic_segmentation_iou import calc_semantic_segmentation_confusion # NOQA
from chainercv.evaluations.eval_semantic_segmentation_iou import calc_semantic_segmentation_iou # NOQA
from chainercv.evaluations.eval_semantic_segmentation_iou import eval_semantic_segmentation_iou # NOQA
127 changes: 0 additions & 127 deletions chainercv/evaluations/eval_semantic_segmentation.py

This file was deleted.

155 changes: 155 additions & 0 deletions chainercv/evaluations/eval_semantic_segmentation_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import division

import numpy as np
import six


def calc_semantic_segmentation_confusion(pred_labels, gt_labels):
"""Collect a confusion matrix.

The number of classes :math:`n\_class` is computed as the maximum
class id among :obj:`pred_labels` and :obj:`gt_labels`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds n_class = max(pred_labels, gt_labels). as -> from? Perhaps, an equation like n_class = max(pred_labels, gt_labels) + 1 will be helpful.


Args:
pred_labels (iterable of numpy.ndarray): A collection of predicted
labels. The shape of a label array
is :math:`(H, W)`. :math:`H` and :math:`W`
are height and width of the label.
gt_labels (iterable of numpy.ndarray): A collection of ground
truth label. The shape of a ground truth label array is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

label -> labels

:math:`(H, W)`. The corresponding prediction label should
have the same shape.
A pixel with value :obj:`-1` will be ignored during evaluation.

Returns:
numpy.ndarray:
A confusion matrix. Its shape is :math:`(n\_class, n\_class)`.
The :math:`(i, j)` th element corresponds to the number of pixels
that are labeled as class :math:`i` by the ground truth and
class :math:`j` by the prediction.

"""
pred_labels = iter(pred_labels)
gt_labels = iter(gt_labels)

n_class = 0
confusion = np.zeros((n_class, n_class), dtype=np.int64)
for pred_label, gt_label in six.moves.zip(pred_labels, gt_labels):
if pred_label.ndim != 2 or gt_label.ndim != 2:
raise ValueError('ndim of inputs should be two.')
if pred_label.shape != gt_label.shape:
raise ValueError('Shapes of inputs should be same.')
pred_label = pred_label.flatten()
gt_label = gt_label.flatten()

# Dynamically expand the confusion matrix if necessary.
lb_max = np.max((pred_label, gt_label))
if lb_max >= n_class:
expanded_confusion = np.zeros
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this line?

expanded_confusion = np.zeros((lb_max + 1, lb_max + 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it is better to use dtype=np.int64 option. You are using this option at L.36 but it will be overwritten by expansion.

expanded_confusion[0:n_class, 0:n_class] = confusion

n_class = lb_max + 1
confusion = expanded_confusion

# Count statistics from valid pixels.
mask = (gt_label >= 0) & (gt_label < n_class)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(gt_label < n_class) is always true because n_class > lb_max = np.max((pred_label, gt_label)).

confusion += np.bincount(
n_class * gt_label[mask].astype(int) +
pred_label[mask], minlength=n_class**2).reshape(n_class, n_class)
Copy link
Member

@Hakuyume Hakuyume Jun 2, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can expand the confusion matrix dynamically. n_class can be removed.

n_class = 0
...
lb_max = gt_label.max()
if  lb_max >= n_class:
    expanded_confusion = np.zeros((lb_max + 1, lb_max + 1))
    expanded_confusion[0:n_class, 0:n_class] = confusion

    n_class = lb_max + 1
    confusion = expanded_confusion

This expansion may affect performance. But I think it will be ignorable because neural networks are much slower.
How do you think of it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That code seems simple enough.
I left it as is because I thought the code gets too complex.

I will try it.


for iter_ in (pred_labels, gt_labels):
# This code assumes any iterator does not contain None as its items.
if next(iter_, None) is not None:
raise ValueError('Length of input iterables need to be same')
return confusion


def calc_semantic_segmentation_iou(confusion):
"""Calculate Intersection over Union with a given confusion matrix.

The definition of Intersection over Union (IoU) is as follows,
where :math:`N_{ij}` is the number of pixels
that are labeled as class :math:`i` by the ground truth and
class :math:`j` by the prediction.

* :math:`\\text{IoU of the i-th class} = \
\\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, we need denote the definition of N_{ij}.


Args:
confusion (numpy.ndarray): A confusion matrix. Its shape is
:math:`(n\_class, n\_class)`.
The :math:`(i, j)` th element corresponds to the number of pixels
that are labeled as class :math:`i` by the ground truth and
class :math:`j` by the prediction.

Returns:
numpy.ndarray:
An array of IoUs for the :math:`n\_class` classes. Its shape is
:math:`(n\_class,)`.

"""
iou_denominator = (confusion.sum(axis=1) + confusion.sum(axis=0)
- np.diag(confusion))
iou = np.diag(confusion) / iou_denominator
return iou


def eval_semantic_segmentation_iou(pred_labels, gt_labels):
"""Evaluate Intersection over Union from labels.

This function calculates Intersection over Union (IoU)
for the task of semantic segmentation.

The definition of IoU and a related metric, mean Intersection
over Union (mIoU), are as follow,
where :math:`N_{ij}` is the number of pixels
that are labeled as class :math:`i` by the ground truth and
class :math:`j` by the prediction.

* :math:`\\text{IoU of the i-th class} = \
\\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}`
* :math:`\\text{mIoU} = \\frac{1}{k} \
\\sum_{i=1}^k \
\\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}`

mIoU can be computed by taking :obj:`numpy.nanmean` of the IoUs returned
by this function.
The more detailed descriptions of the above metric can be found in a
review on semantic segmentation [#]_.

The number of classes :math:`n\_class` is computed as the maximum
class id among :obj:`pred_labels` and :obj:`gt_labels`.

.. [#] Alberto Garcia-Garcia, Sergio Orts-Escolano, Sergiu Oprea, \
Victor Villena-Martinez, Jose Garcia-Rodriguez. \
`A Review on Deep Learning Techniques Applied to Semantic Segmentation \
<https://arxiv.org/abs/1704.06857>`_. arXiv 2017.

Args:
pred_labels (iterable of numpy.ndarray): A collection of predicted
labels. The shape of a label array
is :math:`(H, W)`. :math:`H` and :math:`W`
are height and width of the label.
For example, this is a list of labels
:obj:`[label_0, label_1, ...]`, where
:obj:`label_i.shape = (H_i, W_i)`.
gt_labels (iterable of numpy.ndarray): A collection of ground
truth labels. The shape of a ground truth label array is
:math:`(H, W)`. The corresponding prediction label should
have the same shape.
A pixel with value :obj:`-1` will be ignored during evaluation.

Returns:
numpy.ndarray:
An array of IoUs for the :math:`n\_class` classes. Its shape is
:math:`(n\_class,)`.

"""
# Evaluation code is based on
# https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/
# score.py#L37
confusion = calc_semantic_segmentation_confusion(
pred_labels, gt_labels)
iou = calc_semantic_segmentation_iou(confusion)
return iou
24 changes: 2 additions & 22 deletions chainercv/links/model/pixelwise_softmax_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

import numpy as np

from chainercv.evaluations import eval_semantic_segmentation


class PixelwiseSoftmaxClassifier(chainer.Chain):

"""A pixel-wise classifier.

It computes the loss and accuracy based on a given input/label pair for
It computes the loss based on a given input/label pair for
semantic segmentation.

Args:
Expand All @@ -23,21 +21,17 @@ class PixelwiseSoftmaxClassifier(chainer.Chain):
that contains constant weights that will be multiplied with the
loss values along with the channel dimension. This will be
used in :func:`chainer.functions.softmax_cross_entropy`.
compute_accuracy (bool): If :obj:`True`, compute accuracy on the
forward computation. The default value is :obj:`True`.

"""

def __init__(self, predictor, ignore_label=-1, class_weight=None,
compute_accuracy=True):
def __init__(self, predictor, ignore_label=-1, class_weight=None):
super(PixelwiseSoftmaxClassifier, self).__init__(predictor=predictor)
self.n_class = predictor.n_class
self.ignore_label = ignore_label
if class_weight is not None:
self.class_weight = np.asarray(class_weight, dtype=np.float32)
else:
self.class_weight = class_weight
self.compute_accuracy = compute_accuracy

def to_cpu(self):
super(PixelwiseSoftmaxClassifier, self).to_cpu()
Expand All @@ -52,8 +46,6 @@ def to_gpu(self, device=None):
def __call__(self, x, t):
"""Computes the loss value for an image and label pair.

It also computes accuracy and stores it to the attribute.

Args:
x (~chainer.Variable): A variable with a batch of images.
t (~chainer.Variable): A variable with the ground truth
Expand All @@ -69,16 +61,4 @@ def __call__(self, x, t):
ignore_label=self.ignore_label)

reporter.report({'loss': self.loss}, self)

self.accuracy = None
if self.compute_accuracy:
label = self.xp.argmax(self.y.data, axis=1)
self.accuracy = eval_semantic_segmentation(
label, t.data, self.n_class)
reporter.report({
'pixel_accuracy': self.xp.mean(self.accuracy[0]),
'mean_accuracy': self.xp.mean(self.accuracy[1]),
'mean_iou': self.xp.mean(self.accuracy[2]),
'fw_iou': self.xp.mean(self.accuracy[3])
}, self)
return self.loss
15 changes: 12 additions & 3 deletions docs/source/reference/evaluations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ eval_pck
~~~~~~~~
.. autofunction:: eval_pck

eval_semantic_segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: eval_semantic_segmentation

eval_semantic_segmentation_iou
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: eval_semantic_segmentation_iou

calc_semantic_segmentation_confusion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not an alphabetical order. I guess you put these functions here because they are helper functions of eval_semantic_segmentation_iou. How about nesting these items to show the hierarchy more explicitly?

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: calc_semantic_segmentation_confusion

calc_semantic_segmentation_iou
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: calc_semantic_segmentation_iou
Loading