Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Add VOC and SBD instance segmentation dataset #540

Merged
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ For additional features

+ Matplotlib
+ OpenCV
+ SciPy

Environments under Python 2.7.12 and 3.6.0 are tested.

Expand Down
4 changes: 4 additions & 0 deletions chainercv/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA
from chainercv.datasets.online_products.online_products_dataset import online_products_super_label_names # NOQA
from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA
from chainercv.datasets.sbd.sbd_instance_segmentation_dataset import SBDInstanceSegmentationDataset # NOQA
from chainercv.datasets.sbd.sbd_utils import sbd_instance_segmentation_label_names # NOQA
from chainercv.datasets.siamese_dataset import SiameseDataset # NOQA
from chainercv.datasets.transform_dataset import TransformDataset # NOQA
from chainercv.datasets.voc.voc_bbox_dataset import VOCBboxDataset # NOQA
from chainercv.datasets.voc.voc_instance_segmentation_dataset import VOCInstanceSegmentationDataset # NOQA
from chainercv.datasets.voc.voc_semantic_segmentation_dataset import VOCSemanticSegmentationDataset # NOQA
from chainercv.datasets.voc.voc_utils import voc_bbox_label_names # NOQA
from chainercv.datasets.voc.voc_utils import voc_instance_segmentation_label_names # NOQA
from chainercv.datasets.voc.voc_utils import voc_semantic_segmentation_ignore_label_color # NOQA
from chainercv.datasets.voc.voc_utils import voc_semantic_segmentation_label_colors # NOQA
from chainercv.datasets.voc.voc_utils import voc_semantic_segmentation_label_names # NOQA
Empty file.
103 changes: 103 additions & 0 deletions chainercv/datasets/sbd/sbd_instance_segmentation_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import os
import warnings

import chainer

from chainercv.datasets.sbd import sbd_utils
from chainercv.datasets.voc import voc_utils
from chainercv.utils import read_image

try:
    # Import the submodule explicitly: the loader in this module calls
    # ``scipy.io.loadmat`` and a bare ``import scipy`` is not guaranteed to
    # make ``scipy.io`` available as an attribute.
    import scipy.io  # NOQA
    _available = True
except ImportError:
    _available = False


def _check_available():
    """Emit a warning when SciPy could not be imported.

    This does not raise; it only warns, so constructing the dataset still
    succeeds and the failure surfaces later when an annotation is loaded.
    """
    if not _available:
        # The whole text must be ONE argument. A comma between the fragments
        # would pass the remainder as ``category`` and make ``warnings.warn``
        # raise ``TypeError`` instead of warning.
        warnings.warn(
            'SciPy is not installed in your environment, '
            'so the dataset cannot be loaded. '
            'Please install SciPy to load dataset.\n\n'
            '$ pip install scipy')


class SBDInstanceSegmentationDataset(chainer.dataset.DatasetMixin):

    """Instance segmentation dataset for Semantic Boundaries Dataset `SBD`_.

    The class name of the label :math:`l` is :math:`l` th element of
    :obj:`chainercv.datasets.sbd_instance_segmentation_label_names`.

    .. _`SBD`: http://home.bharathh.info/pubs/codes/SBD/download.html

    Args:
        data_dir (string): Path to the root of the training data. If this is
            :obj:`auto`, this class will automatically download data for you
            under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/sbd`.
        split ({'train', 'val', 'trainval'}): Select a split of the dataset.

    Raises:
        ValueError: If :obj:`split` is not one of the supported names.

    """

    def __init__(self, data_dir='auto', split='train'):
        _check_available()

        if split not in ['train', 'trainval', 'val']:
            raise ValueError(
                'please pick split from \'train\', \'trainval\', \'val\'')

        if data_dir == 'auto':
            data_dir = sbd_utils.get_sbd()

        id_list_file = os.path.join(
            data_dir, '{}_voc2012.txt'.format(split))
        # Close the list file deterministically instead of leaving the open
        # handle to the garbage collector.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]

        self.data_dir = data_dir

    def __len__(self):
        return len(self.ids)

    def get_example(self, i):
        """Returns the i-th example.

        Returns a color image, bounding boxes, masks and labels. The color
        image is in CHW format.

        Args:
            i (int): The index of the example.

        Returns:
            A tuple of color image, bounding boxes, masks and labels whose
            shapes are :math:`(3, H, W), (R, 4), (R, H, W), (R, )`
            respectively.
            :math:`H` and :math:`W` are height and width of the images,
            and :math:`R` is the number of objects in the image.
            The dtype of the color image and the bounding boxes are
            :obj:`numpy.float32`, that of the masks is :obj:`bool`,
            and that of the labels is :obj:`numpy.int32`.

        """
        data_id = self.ids[i]
        img_file = os.path.join(
            self.data_dir, 'img', data_id + '.jpg')
        img = read_image(img_file, color=True)
        label_img, inst_img = self._load_label_inst(data_id)
        bbox, mask, label = voc_utils.image_wise_to_instance_wise(
            label_img, inst_img)
        return img, bbox, mask, label

    def _load_label_inst(self, data_id):
        """Load the class-label and instance-id maps of one example.

        Both maps are read from MATLAB ``.mat`` files under the ``cls`` and
        ``inst`` directories and converted to :obj:`numpy.int32`.

        Args:
            data_id (string): ID of the example, i.e. the file name without
                its extension.

        Returns:
            A tuple :obj:`(label_img, inst_img)`. In :obj:`inst_img`, the
            values ``0`` and ``255`` are replaced by ``-1``.

        """
        # Import the submodule here so that ``scipy.io`` is guaranteed to be
        # loaded; a bare ``import scipy`` does not necessarily expose it.
        # This raises ImportError if SciPy is not installed.
        import scipy.io

        label_file = os.path.join(
            self.data_dir, 'cls', data_id + '.mat')
        inst_file = os.path.join(
            self.data_dir, 'inst', data_id + '.mat')
        label_anno = scipy.io.loadmat(label_file)
        label_img = label_anno['GTcls']['Segmentation'][0][0].astype(np.int32)
        inst_anno = scipy.io.loadmat(inst_file)
        inst_img = inst_anno['GTinst']['Segmentation'][0][0].astype(np.int32)
        # -1 is the "no instance" marker that
        # ``voc_utils.image_wise_to_instance_wise`` skips.
        inst_img[inst_img == 0] = -1
        inst_img[inst_img == 255] = -1
        return label_img, inst_img
50 changes: 50 additions & 0 deletions chainercv/datasets/sbd/sbd_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import six

from chainer.dataset import download

from chainercv.datasets.voc.voc_utils \
import voc_instance_segmentation_label_names
from chainercv import utils

# Dataset directory name that ``get_sbd`` passes to
# ``download.get_dataset_directory``.
root = 'pfnet/chainercv/sbd'
# Archive that ``get_sbd`` downloads and extracts.
url = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz' # NOQA
# ID list that ``get_sbd`` downloads and stores as ``train_voc2012.txt``.
train_voc2012_url = 'http://home.bharathh.info/pubs/codes/SBD/train_noval.txt'


def _read_ids(path):
    """Return the non-empty, whitespace-stripped lines of ``path``."""
    with open(path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _generate_voc2012_txt(base_path):
    """Write the ``val`` and ``trainval`` ID lists of the VOC2012 split.

    Reads ``train.txt``, ``val.txt`` and ``train_voc2012.txt`` from
    :obj:`base_path` and writes two files next to them:

    * ``trainval_voc2012.txt``: every ID found in ``train.txt`` or
      ``val.txt``.
    * ``val_voc2012.txt``: those IDs that are not listed in
      ``train_voc2012.txt``.

    Both outputs are sorted, one ID per line, without a trailing newline.

    Args:
        base_path (string): Directory that contains the input lists and
            receives the generated ones.

    """
    # Read line by line rather than ``split('\n')[:-1]``, which silently
    # drops the last ID when a file does not end with a newline.
    train_ids = _read_ids(os.path.join(base_path, 'train.txt'))
    val_ids = _read_ids(os.path.join(base_path, 'val.txt'))
    # Sets make the difference below linear instead of quadratic.
    train_voc2012_ids = set(
        _read_ids(os.path.join(base_path, 'train_voc2012.txt')))
    all_ids = set(train_ids + val_ids)
    val_voc2012_ids = all_ids - train_voc2012_ids

    with open(os.path.join(base_path, 'val_voc2012.txt'), 'w') as f:
        f.write('\n'.join(sorted(val_voc2012_ids)))
    with open(os.path.join(base_path, 'trainval_voc2012.txt'), 'w') as f:
        f.write('\n'.join(sorted(all_ids)))


def get_sbd():
    """Return the SBD dataset directory, downloading it if necessary.

    The presence of ``train_voc2012.txt`` inside the dataset directory is
    used as the marker that a previous call already prepared the data. When
    it is missing, the archive at ``url`` is downloaded and extracted, the
    ID list at ``train_voc2012_url`` is saved as ``train_voc2012.txt``, and
    the remaining VOC2012 split files are generated.

    Returns:
        string: Path to ``benchmark_RELEASE/dataset`` under the dataset root.

    """
    dataset_root = download.get_dataset_directory(root)
    base_path = os.path.join(dataset_root, 'benchmark_RELEASE/dataset')
    train_voc2012_file = os.path.join(base_path, 'train_voc2012.txt')

    if not os.path.exists(train_voc2012_file):
        archive_path = utils.cached_download(url)
        _, ext = os.path.splitext(url)
        utils.extractall(archive_path, dataset_root, ext)

        six.moves.urllib.request.urlretrieve(
            train_voc2012_url, train_voc2012_file)
        _generate_voc2012_txt(base_path)

    return base_path


# Alias of the VOC instance segmentation label names (same object).
sbd_instance_segmentation_label_names = voc_instance_segmentation_label_names
85 changes: 85 additions & 0 deletions chainercv/datasets/voc/voc_instance_segmentation_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import os

import chainer

from chainercv.datasets.voc import voc_utils
from chainercv.utils import read_image


class VOCInstanceSegmentationDataset(chainer.dataset.DatasetMixin):

    """Instance segmentation dataset for PASCAL `VOC2012`_.

    The class name of the label :math:`l` is :math:`l` th element of
    :obj:`chainercv.datasets.voc_instance_segmentation_label_names`.

    .. _`VOC2012`: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/

    Args:
        data_dir (string): Path to the root of the training data. If this is
            :obj:`auto`, this class will automatically download data for you
            under :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/voc`.
        split ({'train', 'val', 'trainval'}): Select a split of the dataset.

    Raises:
        ValueError: If :obj:`split` is not one of the supported names.

    """

    def __init__(self, data_dir='auto', split='train'):
        if split not in ['train', 'trainval', 'val']:
            raise ValueError(
                'please pick split from \'train\', \'trainval\', \'val\'')

        if data_dir == 'auto':
            data_dir = voc_utils.get_voc('2012', split)

        id_list_file = os.path.join(
            data_dir, 'ImageSets/Segmentation/{0}.txt'.format(split))
        # Close the list file deterministically instead of leaving the open
        # handle to the garbage collector.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]

        self.data_dir = data_dir

    def __len__(self):
        return len(self.ids)

    def get_example(self, i):
        """Returns the i-th example.

        Returns a color image, bounding boxes, masks and labels. The color
        image is in CHW format.

        Args:
            i (int): The index of the example.

        Returns:
            A tuple of color image, bounding boxes, masks and labels whose
            shapes are :math:`(3, H, W), (R, 4), (R, H, W), (R, )`
            respectively.
            :math:`H` and :math:`W` are height and width of the images,
            and :math:`R` is the number of objects in the image.
            The dtype of the color image and the bounding boxes are
            :obj:`numpy.float32`, that of the masks is :obj:`bool`,
            and that of the labels is :obj:`numpy.int32`.

        """
        data_id = self.ids[i]
        img_file = os.path.join(
            self.data_dir, 'JPEGImages', data_id + '.jpg')
        img = read_image(img_file, color=True)
        label_img, inst_img = self._load_label_inst(data_id)
        bbox, mask, label = voc_utils.image_wise_to_instance_wise(
            label_img, inst_img)
        return img, bbox, mask, label

    def _load_label_inst(self, data_id):
        """Load the class-label and instance-id maps of one example.

        Both maps are read as single-channel :obj:`numpy.int32` images from
        the ``SegmentationClass`` and ``SegmentationObject`` directories.

        Args:
            data_id (string): ID of the example, i.e. the file name without
                its extension.

        Returns:
            A tuple :obj:`(label_img, inst_img)`. The value ``255`` is
            replaced by ``-1`` in both maps, and ``0`` is additionally
            replaced by ``-1`` in :obj:`inst_img`.

        """
        label_file = os.path.join(
            self.data_dir, 'SegmentationClass', data_id + '.png')
        inst_file = os.path.join(
            self.data_dir, 'SegmentationObject', data_id + '.png')
        label_img = read_image(label_file, dtype=np.int32, color=False)
        # Drop the leading channel axis of the single-channel image.
        label_img = label_img[0]
        label_img[label_img == 255] = -1
        inst_img = read_image(inst_file, dtype=np.int32, color=False)
        inst_img = inst_img[0]
        # -1 is the "no instance" marker that
        # ``voc_utils.image_wise_to_instance_wise`` skips.
        inst_img[inst_img == 0] = -1
        inst_img[inst_img == 255] = -1
        return label_img, inst_img
27 changes: 27 additions & 0 deletions chainercv/datasets/voc/voc_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import os

from chainer.dataset import download
Expand Down Expand Up @@ -37,6 +38,30 @@ def get_voc(year, split):
return base_path


def image_wise_to_instance_wise(label_img, inst_img):
    """Convert per-pixel label/instance maps into per-instance arrays.

    Every distinct value of :obj:`inst_img` other than ``-1`` is treated as
    one instance. For each instance this computes its boolean mask, the
    tight bounding box of that mask and its class label. The label is the
    smallest value of :obj:`label_img` inside the mask, minus one.

    Args:
        label_img (numpy.ndarray): Class ID of each pixel, with shape
            :math:`(H, W)`.
        inst_img (numpy.ndarray): Instance ID of each pixel, with shape
            :math:`(H, W)`. Pixels with the value ``-1`` belong to no
            instance.

    Returns:
        A tuple :obj:`(bbox, mask, label)` whose shapes are
        :math:`(R, 4), (R, H, W), (R, )` and whose dtypes are
        :obj:`numpy.float32`, :obj:`bool` and :obj:`numpy.int32`, where
        :math:`R` is the number of instances. Each bounding box is
        :obj:`(y_min, x_min, y_max, x_max)` with exclusive maxima. The shapes
        hold for :math:`R = 0` as well.

    Raises:
        AssertionError: If the computed label of an instance is ``-1``.

    """
    bboxes = []
    masks = []
    labels = []
    inst_ids = np.unique(inst_img)
    for inst_id in inst_ids[inst_ids != -1]:
        msk = inst_img == inst_id
        # ``np.unique`` returns sorted values, so index 0 is the smallest
        # class ID inside the mask.
        lbl = np.unique(label_img[msk])[0] - 1
        assert lbl != -1

        where = np.argwhere(msk)
        # ``+ 1`` makes the maximum corner exclusive.
        (y_min, x_min), (y_max, x_max) = where.min(0), where.max(0) + 1

        bboxes.append((y_min, x_min, y_max, x_max))
        masks.append(msk)
        labels.append(lbl)

    n_inst = len(labels)
    # Reshape explicitly so the results keep their documented rank when no
    # instance is present. Use the builtin ``bool``: the ``np.bool`` alias
    # no longer exists in current NumPy releases.
    bbox = np.array(bboxes, dtype=np.float32).reshape((n_inst, 4))
    mask = np.array(masks, dtype=bool).reshape(
        (n_inst,) + tuple(inst_img.shape))
    label = np.array(labels, dtype=np.int32)
    return bbox, mask, label


voc_bbox_label_names = (
'aeroplane',
'bicycle',
Expand All @@ -62,6 +87,8 @@ def get_voc(year, split):
voc_semantic_segmentation_label_names = (('background',) +
voc_bbox_label_names)

# Alias of the bounding box label names (same object).
voc_instance_segmentation_label_names = voc_bbox_label_names

# these colors are used in the original MATLAB tools
voc_semantic_segmentation_label_colors = (
(0, 0, 0),
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from chainercv.utils.testing import assert_is_bbox_dataset # NOQA
from chainercv.utils.testing import assert_is_detection_link # NOQA
from chainercv.utils.testing import assert_is_image # NOQA
from chainercv.utils.testing import assert_is_instance_segmentation_dataset # NOQA
from chainercv.utils.testing import assert_is_label_dataset # NOQA
from chainercv.utils.testing import assert_is_point # NOQA
from chainercv.utils.testing import assert_is_semantic_segmentation_dataset # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from chainercv.utils.testing.assertions import assert_is_bbox_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions import assert_is_image # NOQA
from chainercv.utils.testing.assertions import assert_is_instance_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions import assert_is_point # NOQA
from chainercv.utils.testing.assertions import assert_is_semantic_segmentation_dataset # NOQA
Expand Down
1 change: 1 addition & 0 deletions chainercv/utils/testing/assertions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from chainercv.utils.testing.assertions.assert_is_bbox_dataset import assert_is_bbox_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_detection_link import assert_is_detection_link # NOQA
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image # NOQA
from chainercv.utils.testing.assertions.assert_is_instance_segmentation_dataset import assert_is_instance_segmentation_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_label_dataset import assert_is_label_dataset # NOQA
from chainercv.utils.testing.assertions.assert_is_point import assert_is_point # NOQA
from chainercv.utils.testing.assertions.assert_is_semantic_segmentation_dataset import assert_is_semantic_segmentation_dataset # NOQA
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import six

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def assert_is_instance_segmentation_dataset(
        dataset, n_fg_class, n_example=None
):
    """Checks if a dataset satisfies instance segmentation dataset APIs.

    This function checks if a given dataset satisfies instance segmentation
    dataset APIs or not.
    If the dataset does not satisfy the APIs, this function raises an
    :class:`AssertionError`.

    Args:
        dataset: A dataset to be checked.
        n_fg_class (int): The number of foreground classes.
        n_example (int): The number of examples to be checked.
            If this argument is specified, this function picks
            examples randomly and checks them. Otherwise,
            this function checks all examples.

    """

    assert len(dataset) > 0, 'The length of dataset must be greater than zero.'

    if n_example:
        # Draw each index lazily, right before its example is checked.
        indices = (
            np.random.randint(0, len(dataset))
            for _ in six.moves.range(n_example))
    else:
        indices = six.moves.range(len(dataset))

    for index in indices:
        _check_example(dataset[index], n_fg_class)


def _check_example(example, n_fg_class):
    """Assert that one example follows the instance segmentation format.

    The first four elements of :obj:`example` are taken as
    :obj:`(img, bbox, mask, label)`. With :math:`(H, W)` read from the image
    shape and :math:`R` from the number of bounding boxes, :obj:`mask` must
    be a bool array of shape :math:`(R, H, W)` and :obj:`label` an int32
    array of shape :math:`(R, )` with values in
    :math:`[0, n\\_fg\\_class - 1]`.

    Args:
        example (tuple): An example taken from the dataset.
        n_fg_class (int): The number of foreground classes.

    Raises:
        AssertionError: If any of the checks fails.

    """
    assert len(example) >= 4, \
        'Each example must have at least four elements: ' \
        'img, bbox, mask and label.'

    img, bbox, mask, label = example[:4]

    assert_is_image(img, color=True)
    _, H, W = img.shape
    assert_is_bbox(bbox, size=(H, W))
    R = bbox.shape[0]

    assert isinstance(mask, np.ndarray), \
        'mask must be a numpy.ndarray.'
    assert isinstance(label, np.ndarray), \
        'label must be a numpy.ndarray.'
    # Compare against the builtin ``bool``: the ``np.bool`` alias no longer
    # exists in current NumPy releases.
    assert mask.dtype == bool, \
        'The type of mask must be bool'
    assert label.dtype == np.int32, \
        'The type of label must be numpy.int32.'
    assert mask.shape == (R, H, W), \
        'The shape of mask must be (R, H, W).'
    assert label.shape == (R,), \
        'The shape of label must be (R, ).'
    # ``min``/``max`` raise ValueError on an empty array, so only check the
    # value range when there is at least one label.
    if label.size > 0:
        assert label.min() >= 0 and label.max() < n_fg_class, \
            'The value of label must be in [0, n_fg_class - 1].'
Loading