
Change eval_semantic_segmentation #217

Merged: 17 commits, merged on Jun 2, 2017
change interface to accept iterator
yuyu2172 committed May 30, 2017
commit 8b92dc89576f8b0505c018a277d5159a4c7f75a6
71 changes: 34 additions & 37 deletions chainercv/evaluations/eval_semantic_segmentation.py
@@ -1,23 +1,19 @@
 from __future__ import division
 
 import numpy as np
-import six
 
 
-def calc_confusion_matrix(pred_labels, gt_labels, n_class):
+def calc_confusion_matrix(pred_label, gt_label, n_class):
     """Collect confusion matrix.
 
     Args:
-        pred_labels (iterable of numpy.ndarray): A collection of predicted
-            labels. This is a batch of labels whose shape is :math:`(N, H, W)`
-            or a list containing :math:`N` labels. The shape of a label array
+        pred_label (numpy.ndarray): A predicted label.
+            The shape of a label array
             is :math:`(H, W)`. :math:`H` and :math:`W`
-            are height and width of the label. We assume that there are
-            :math:`N` labels.
-        gt_labels (iterable of numpy.ndarray): A collection of the ground
-            truth labels.
-            It is organized similarly to :obj:`pred_labels`. A pixel with value
-            "-1" will be ignored during evaluation.
+            are height and width of the label.
+        gt_label (numpy.ndarray): The ground truth label.
+            Its shape is :math:`(H, W)`.
+            A pixel with value "-1" will be ignored during evaluation.
         n_class (int): The number of classes.
 
     Returns:
@@ -30,32 +26,21 @@ def calc_confusion_matrix(pred_labels, gt_labels, n_class):
     # Evaluation code is based on
     # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/
     # score.py#L37
-    if (isinstance(pred_labels, np.ndarray) and pred_labels.ndim != 3
-            or isinstance(gt_labels, np.ndarray) and gt_labels.ndim != 3):
-        raise ValueError('If batch of arrays are given, they have '
-                         'to have dimension 3')
-    N = len(pred_labels)
-
-    if len(pred_labels) != len(gt_labels):
-        raise ValueError('Number of the predicted labels and the'
-                         'ground truth labels are different')
-    for i in six.moves.range(N):
-        if pred_labels[i].shape != gt_labels[i].shape:
-            raise ValueError('Shape of the prediction and'
-                             'the ground truth should match')
-
-    confusion = np.zeros((n_class, n_class), dtype=np.int64)
-    for i in six.moves.range(N):
-        pred_label = pred_labels[i].flatten()
-        gt_label = gt_labels[i].flatten()
-        mask = (gt_label >= 0) & (gt_label < n_class)
-        confusion += np.bincount(
-            n_class * gt_label[mask].astype(int) +
-            pred_label[mask], minlength=n_class**2).reshape(n_class, n_class)
+    if pred_label.ndim != 2 or gt_label.ndim != 2:
+        raise ValueError('ndim of inputs should be two.')
+    if pred_label.shape != gt_label.shape:
+        raise ValueError('Shapes of inputs should be same.')
+
+    pred_label = pred_label.flatten()
+    gt_label = gt_label.flatten()
+    mask = (gt_label >= 0) & (gt_label < n_class)
+    confusion = np.bincount(
+        n_class * gt_label[mask].astype(int) +
+        pred_label[mask], minlength=n_class**2).reshape(n_class, n_class)
     return confusion
 
 
-def eval_semantic_segmentation(confusion):
+def eval_semantic_segmentation(pred_labels, gt_labels, n_class):
     """Evaluate results of semantic segmentation.
 
     This function calculates Intersection over Union (IoU).
@@ -82,16 +67,28 @@ def eval_semantic_segmentation(confusion):
     <https://arxiv.org/abs/1704.06857>`_. arXiv 2017.
 
     Args:
-        confusion (numpy.ndarray): Confusion matrix calculated by
-            :func:`chainercv.evaluations.calc_confusion_matrix`.
-            Its shape is :math:`(n\_class, n\_class)`.
+        pred_labels (iterator or iterable of numpy.ndarray):
+        gt_labels (iterator or iterable of numpy.ndarray):
+        n_class (int): The number of classes.
 
     Returns:
         numpy.ndarray:
         IoUs computed from the given confusion matrix.
         Its shape is :math:`(n\_class,)`.
 
     """
+    pred_labels = iter(pred_labels)
+    gt_labels = iter(gt_labels)
+
+    confusion = np.zeros((n_class, n_class), dtype=np.int64)
+    while True:
+        try:
+            pred_label = next(pred_labels)
+            gt_label = next(gt_labels)
+        except StopIteration:
+            break
+        confusion += calc_confusion_matrix(pred_label, gt_label, n_class)
+
     iou_denominator = (confusion.sum(axis=1) + confusion.sum(axis=0)
                        - np.diag(confusion))
     iou = np.diag(confusion) / iou_denominator
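
The snippet below (not part of the diff) is a self-contained sketch of the two routines as they stand after this commit, plus a made-up usage example, to make the np.bincount trick and the new iterator-based interface concrete. The trailing return in eval_semantic_segmentation is assumed, since the hunk above ends just before it, and the sample arrays are borrowed from the tests.

from __future__ import division

import numpy as np


def calc_confusion_matrix(pred_label, gt_label, n_class):
    # Flatten both (H, W) label maps and drop pixels whose ground truth
    # lies outside [0, n_class), which ignores the -1 "don't care" label.
    pred_label = pred_label.flatten()
    gt_label = gt_label.flatten()
    mask = (gt_label >= 0) & (gt_label < n_class)
    # Encode each (gt, pred) pair as one integer in [0, n_class ** 2),
    # count occurrences with bincount, and reshape into an
    # (n_class, n_class) matrix: rows are ground truth, columns predictions.
    return np.bincount(
        n_class * gt_label[mask].astype(int) + pred_label[mask],
        minlength=n_class ** 2).reshape(n_class, n_class)


def eval_semantic_segmentation(pred_labels, gt_labels, n_class):
    # iter() lets the function accept lists, arrays, or lazy iterators;
    # the loop stops as soon as either sequence is exhausted.
    pred_labels = iter(pred_labels)
    gt_labels = iter(gt_labels)
    confusion = np.zeros((n_class, n_class), dtype=np.int64)
    while True:
        try:
            pred_label = next(pred_labels)
            gt_label = next(gt_labels)
        except StopIteration:
            break
        confusion += calc_confusion_matrix(pred_label, gt_label, n_class)
    # IoU for class c is TP / (TP + FP + FN); the row sum plus the column
    # sum counts the diagonal (TP) twice, hence the subtraction.
    iou_denominator = (confusion.sum(axis=1) + confusion.sum(axis=0)
                       - np.diag(confusion))
    return np.diag(confusion) / iou_denominator


# Generators work too, so predictions never have to be materialized at once.
pred_labels = (np.array([[1, 1, 0], [0, 0, 1]]) for _ in range(2))
gt_labels = (np.array([[1, 0, 0], [0, -1, 1]]) for _ in range(2))
print(eval_semantic_segmentation(pred_labels, gt_labels, n_class=2))
# -> [0.66666667 0.66666667], matching the 'iou' values in the tests below
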
36 changes: 20 additions & 16 deletions tests/evaluations_tests/test_eval_semantic_segmentation.py
@@ -11,44 +11,48 @@
 @testing.parameterize(
     {'pred_labels': np.repeat([[[1, 1, 0], [0, 0, 1]]], 2, axis=0),
      'gt_labels': np.repeat([[[1, 0, 0], [0, -1, 1]]], 2, axis=0),
-     'confusion': np.array([[4, 2], [0, 4]])
+     'iou': np.array([4. / 6., 4. / 6.])
      },
     {'pred_labels': [np.array([[1, 1, 0], [0, 0, 1]]),
                      np.array([[1, 1, 0], [0, 0, 1]])],
      'gt_labels': [np.array([[1, 0, 0], [0, -1, 1]]),
                    np.array([[1, 0, 0], [0, -1, 1]])],
-     'confusion': np.array([[4, 2], [0, 4]])
+     'iou': np.array([4. / 6., 4. / 6.])
      },
     {'pred_labels': np.array([[[0, 0, 0], [0, 0, 0]]]),
      'gt_labels': np.array([[[1, 1, 1], [1, 1, 1]]]),
-     'confusion': np.array([[0., 0], [6, 0]]),
+     'iou': np.array([0, 0]),
      }
 )
-class TestCalcConfusionMatrix(unittest.TestCase):
+class TestEvalSemanticSegmentation(unittest.TestCase):
 
     n_class = 2
 
-    def test_calc_confusion_matrix(self):
-        confusion = calc_confusion_matrix(
-            self.pred_labels, self.gt_labels, self.n_class)
-
-        self.assertIsInstance(confusion, np.ndarray)
-        np.testing.assert_equal(confusion, self.confusion)
+    def test_eval_semantic_segmentation(self):
+        iou = eval_semantic_segmentation(
+            self.pred_labels, self.gt_labels, self.n_class)
+        np.testing.assert_equal(iou, self.iou)
 
 
 @testing.parameterize(
-    {'confusion': np.array([[4, 2], [0, 4]]),
-     'iou': np.array([4. / 6., 4. / 6.])
+    {'pred_label': np.array([[1, 1, 0], [0, 0, 1]]),
+     'gt_label': np.array([[1, 0, 0], [0, -1, 1]]),
+     'confusion': np.array([[2, 1], [0, 2]])
      },
-    {'confusion': np.array([[0, 0], [6, 0]]),
-     'iou': np.array([0, 0]),
+    {'pred_label': np.array([[0, 0, 0], [0, 0, 0]]),
+     'gt_label': np.array([[1, 1, 1], [1, 1, -1]]),
+     'confusion': np.array([[0, 0], [5, 0]])
      }
 )
-class TestEvalSemanticSegmentation(unittest.TestCase):
+class TestCalcConfusionMatrix(unittest.TestCase):
 
-    def test_eval_semantic_segmentation(self):
-        iou = eval_semantic_segmentation(self.confusion)
-        np.testing.assert_equal(iou, self.iou)
+    n_class = 2
+
+    def test_calc_confusion_matrix(self):
+        confusion = calc_confusion_matrix(
+            self.pred_label, self.gt_label, self.n_class)
+        np.testing.assert_equal(confusion, self.confusion)
 
 
 testing.run_module(__name__, __file__)
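
As a quick sanity check of the second TestCalcConfusionMatrix case above, here is a standalone recomputation (a sketch, not part of the PR). The ground truth contains five valid pixels, all of class 1; the sixth pixel is -1 and is masked out; every prediction is class 0, so all five counts fall in row 1 (ground truth), column 0 (prediction).

import numpy as np

n_class = 2
pred = np.array([[0, 0, 0], [0, 0, 0]]).flatten()
gt = np.array([[1, 1, 1], [1, 1, -1]]).flatten()
mask = (gt >= 0) & (gt < n_class)  # drops the single -1 pixel
confusion = np.bincount(
    n_class * gt[mask] + pred[mask],
    minlength=n_class ** 2).reshape(n_class, n_class)
print(confusion)
# [[0 0]
#  [5 0]]  -- the expected 'confusion' value in the test
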