This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Add VOC and SBD instance segmentation dataset #540

Merged

Contributor

@knorth55 knorth55 commented Mar 20, 2018

[EDIT] This PR has been split into #540 and #541.

  • add VOCInstanceSegmentationDataset
  • add SBDInstanceSegmentationDataset
  • add voc_instance_segmentation_label_names
  • add sbd_instance_segmentation_label_names
  • add assert_is_instance_segmentation_dataset
  • add test_voc_instance_segmentation_dataset.py
  • add test_sbd_instance_segmentation_dataset.py
  • add test_assert_is_instance_segmentation_dataset.py

@knorth55 knorth55 requested a review from yuyu2172 March 20, 2018 12:34
@yuyu2172 yuyu2172 mentioned this pull request Mar 20, 2018
21 tasks
@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 71d5e27 to 800f71e Compare March 20, 2018 12:41
Contributor Author

knorth55 commented Mar 20, 2018

[EDIT] This PR has been split into #540 and #541.

from chainercv.datasets import SBDInstanceSegmentationDataset
from chainercv.datasets import sbd_instance_segmentation_label_names
from chainercv.visualizations import vis_image
from chainercv.visualizations import vis_instance_segmentation
import matplotlib.pyplot as plt

dataset = SBDInstanceSegmentationDataset()
img, bbox, mask, label = dataset[0]
vis_instance_segmentation(
    img, bbox, mask, label, label_names=sbd_instance_segmentation_label_names,
    alpha=0.7)
plt.show()

[Image: ins_dataset_example]

@knorth55 knorth55 changed the title Add voc instance segmentation dataset Add VOC and SBD instance segmentation dataset and visualize function Mar 20, 2018
@knorth55 knorth55 self-assigned this Mar 20, 2018


sbd_instance_segmentation_label_names = (
'background',
Member

The data representation looks okay except for the range of label, which I want to be [0, n_fg_class - 1]. n_fg_class is the number of classes excluding the background. This means that the label would never be "background". The same convention is adopted for detection.

I think that consistency between detection and instance segmentation really matters because the two have very similar interfaces.
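The convention requested above can be illustrated with a minimal sketch. The shortened label tuples and the helper `semantic_to_fg_label` are hypothetical and are not part of ChainerCV; they only show how a label that counts background as 0 maps onto the range [0, n_fg_class - 1].

```python
# Hypothetical, shortened label tuples; the real tuples are longer.
bbox_label_names = ('aeroplane', 'bicycle', 'bird')
semantic_label_names = ('background',) + bbox_label_names

# Number of classes excluding the background.
n_fg_class = len(bbox_label_names)


def semantic_to_fg_label(semantic_label):
    """Map a semantic label (0 = background) to a foreground-only label."""
    if semantic_label == 0:
        raise ValueError('background has no foreground label')
    return semantic_label - 1


# 'bird' is index 3 when background is included and index 2 without it.
fg_label = semantic_to_fg_label(semantic_label_names.index('bird'))
```

With this convention the same index works for both the detection and the instance segmentation label names.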

Contributor Author

I added background, as in semantic segmentation, because instance segmentation normally has a background label.
The task is pixel-wise segmentation, which means that the labels cover every pixel, including the background.

Contributor Author

Mask R-CNN and FCIS currently use a two-stage structure with an RPN, but the task itself is to predict pixel-wise probabilities, not to detect objects.

Contributor Author

These methods work with label names that exclude background, so I removed the background label.
However, we need to think of a better way to unify semantic segmentation, instance segmentation and object detection.

@@ -62,6 +62,8 @@ def get_voc(year, split):
voc_semantic_segmentation_label_names = (('background',) +
voc_bbox_label_names)

voc_instance_segmentation_label_names = voc_semantic_segmentation_label_names
Member

ditto

from chainercv.visualizations.vis_semantic_segmentation import _default_cmap


def vis_instance_segmentation(
Member

I want this to be consistent with vis_bbox.
This means it should take img together with annotations.

Contributor Author

I got confused with vis_semantic_segmentation. Why isn't vis_image used in vis_semantic_segmentation?

Contributor Author

I updated.

Member

@yuyu2172 yuyu2172 Mar 21, 2018

Good.

Here is the rationale behind the design.
In semantic segmentation, an annotation is displayed in a different plot from the image.
In the case of detection and instance segmentation, annotations are displayed on top of the image.
That is why taking an image is mandatory for vis_instance_segmentation.

Member

yuyu2172 commented Mar 20, 2018

In [2]: chainercv.datasets.SBDInstanceSegmentationDataset()
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-2-c6ccaee382d6> in <module>()
----> 1 chainercv.datasets.SBDInstanceSegmentationDataset()

~/projects/chainercv/chainercv/datasets/sbd/sbd_instance_segmentation_dataset.py in __init__(self, data_dir, split)
     32
     33         if data_dir == 'auto':
---> 34             data_dir = sbd_utils.get_sbd()
     35
     36         id_list_file = os.path.join(

~/projects/chainercv/chainercv/datasets/sbd/sbd_utils.py in get_sbd()
     17     if not os.path.exists(train_voc2012_file):
     18         six.moves.urllib.request.urlretrieve(
---> 19             train_voc2012_url, train_voc2012_file)
     20
     21     if os.path.exists(os.path.join(base_path, 'train.txt')):

~/miniconda2/envs/general3/lib/python3.6/urllib/request.py in urlretrieve(url, filename, reporthook, data)
    256         # Handle temporary file setup.
    257         if filename:
--> 258             tfp = open(filename, 'wb')
    259         else:
    260             tfp = tempfile.NamedTemporaryFile(delete=False)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/yuyu/.chainer/dataset/pfnet/chainercv/sbd/benchmark_RELEASE/dataset/train_voc2012.txt'

Does the downloader work?

EDIT:
You need to make the directory before calling urlretrieve if it does not exist.
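A minimal sketch of that fix follows. The helper name `ensure_parent_dir` is hypothetical, and the temporary directory stands in for the real dataset cache so that nothing is downloaded; in the actual code the `urlretrieve` call would come after the directory has been created.

```python
import os
import tempfile


def ensure_parent_dir(file_path):
    """Create the directory that will contain file_path if it is missing."""
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)
    return parent


# Demonstrate with a temporary location instead of a real download.
base = tempfile.mkdtemp()
target = os.path.join(
    base, 'benchmark_RELEASE', 'dataset', 'train_voc2012.txt')
ensure_parent_dir(target)

# Opening the file for writing is the step that failed in the traceback.
with open(target, 'wb') as f:
    f.write(b'')
```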

self.data_dir, 'cls', data_id + '.mat')
ins_file = os.path.join(
self.data_dir, 'inst', data_id + '.mat')
seg_mat = scipy.io.loadmat(seg_file)
Member

Contributor Author

OK, I will add warnings too.

setup.py Outdated
@@ -23,7 +23,8 @@
 setup_requires = ['numpy']
 install_requires = [
     'chainer>=3.2',
-    'Pillow'
+    'Pillow',
+    'scipy'
Member

Could you not add this to the requirement?
I want to leave scipy as an optional dependency like matplotlib.

Please edit README.md and environment.yml instead.

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 2f22f84 to f2f3f9e Compare March 20, 2018 16:33
Contributor Author

"You need to make the directory before calling urlretrieve if it does not exist."

Thanks, I updated it.

@knorth55 knorth55 changed the title Add VOC and SBD instance segmentation dataset and visualize function [WIP] Add VOC and SBD instance segmentation dataset and visualize function Mar 20, 2018
@knorth55 knorth55 changed the title [WIP] Add VOC and SBD instance segmentation dataset and visualize function Add VOC and SBD instance segmentation dataset and visualize function Mar 20, 2018
Contributor Author

@yuyu2172 I updated the PR to follow your advice and added tests.

Member

@yuyu2172 yuyu2172 left a comment

Please add the following to docs (docs/source/reference)

  • assert_is_instance_segmentation_dataset
  • SBDInstanceSegmentationDataset
  • VOCInstanceSegmentationDataset
  • vis_instance_segmentation

'The shape of mask must be (N, H, W).'
assert label.shape == (N,), \
'The shape of label must be (N, ).'
assert label.min() >= -1 and label.max() < n_class, \
Member

Will label.min() ever be ==-1?
If not, can you change the minimum allowed value to 0?
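A sketch of the stricter check being asked for, assuming labels are foreground-only so that -1 is never valid. The function name `assert_label_range` is hypothetical; it is not the assertion helper added in this PR.

```python
import numpy as np


def assert_label_range(label, n_fg_class):
    """Check that label is a 1-D int32 array within [0, n_fg_class - 1]."""
    assert isinstance(label, np.ndarray), 'label must be a numpy.ndarray.'
    assert label.dtype == np.int32, 'The type of label must be numpy.int32.'
    assert label.ndim == 1, 'The shape of label must be (R,).'
    if len(label) > 0:
        # min()/max() fail on empty arrays, so only check non-empty ones.
        assert label.min() >= 0 and label.max() < n_fg_class, \
            'The value of label must be in [0, n_fg_class - 1].'


# Passes silently: every value lies in [0, 19].
assert_label_range(np.array([0, 5, 19], dtype=np.int32), n_fg_class=20)
```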


assert_is_image(img, color=True)
_, H, W = img.shape
N = bbox.shape[0]
Member

Could you use R instead of N so that the notation is consistent with detection?

return base_path


sbd_instance_segmentation_label_names = (
Member

Since this is the same as voc_bbox_label_names, how about sbd_instance_segmentation_label_names = voc_instance_segmentation_label_names?

This way, we know that the two are the same.


assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert bbox.dtype == np.float32, \
Member

How about using assert_is_bbox?

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def assert_is_instance_segmentation_dataset(dataset, n_class, n_example=None):
Member

Can you rename this to n_fg_class?


Args:
dataset: A dataset to be checked.
n_class (int): The number of classes including background.
Member

Change this to n_fg_class (int): The number of foreground classes.

"""Instance segmentation dataset for Semantic Boundaries Dataset `SBD`_.

The class name of the label :math:`l` is :math:`l` th element of
:obj:`chainercv.datasets.voc_semantic_segmentation_label_names`.
Member

wrong.

"""Instance segmentation dataset for PASCAL `VOC2012`_.

The class name of the label :math:`l` is :math:`l` th element of
:obj:`chainercv.datasets.voc_semantic_segmentation_label_names`.
Member

wrong


Returns:
tuple of color image, bounding boxes, masks and labels whose
shapes are (3, H, W), (N, 4), (N, H, W), (N, ) respectively.
Member

Can you use R instead of N?
We have been using math notation for shapes.

:math:`R`

H and W are height and width of the images, and N is the number
of objects in the image. The dtype of the color image and
the bounding boxes are :obj:`numpy.float32`, and that of the
masks and the labels are :obj:`numpy.int32`.
Member

mask is not numpy.int32

>>> plot.show()

Args:
img (~numpy.ndarray): An array of shape :math:`(3, height, width)`.
Member

height, width --> H, W

represents a bounding box of an object.
The bounding box is :math:`(y_min, x_min, y_max, x_max)`.
mask (~numpy.ndarray): A bool array of shape
:math`(N, height, width)`.
Member

ditto


def setUp(self):
self.img = np.random.randint(0, 255, size=(3, 32, 48))
self.mask = np.random.randint(0, 1, size=(self.n_bbox, 32, 48))
Member

This is always 0.
np.random.randint(0, 2, size= ..., dtype=np.bool) is better.
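The point about the exclusive upper bound can be checked with a short sketch. The shape is arbitrary, and the cast uses `np.bool_` via `astype` rather than the `dtype` argument, which is a choice made here for portability across NumPy versions and not what the review prescribes.

```python
import numpy as np

shape = (4, 32, 48)

# The upper bound is exclusive, so randint(0, 1) can only ever draw 0.
always_zero = np.random.randint(0, 1, size=shape)

# Drawing from {0, 1} and casting gives a mask where both values can occur.
mask = np.random.randint(0, 2, size=shape).astype(np.bool_)
```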


def setUp(self):
self.img = np.random.randint(0, 255, size=(3, 32, 48))
self.mask = np.random.randint(0, 1, size=(self.n_bbox, 32, 48))
Member

ditto

Member

Please add tests for assert_is_instance_segmentation_dataset.

Member

I think it is better to put vis_instance_segmentation into a different PR.

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch 2 times, most recently from 5c0bcb1 to 69c2402 Compare March 21, 2018 03:57
@knorth55 knorth55 changed the title Add VOC and SBD instance segmentation dataset and visualize function Add VOC and SBD instance segmentation dataset Mar 21, 2018
@knorth55 knorth55 mentioned this pull request Mar 21, 2018
2 tasks
@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 69c2402 to 8bf9cc2 Compare March 21, 2018 04:07
Contributor Author

knorth55 commented Mar 21, 2018

@yuyu2172 Updated and split this PR into two (this and #541).

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch 2 times, most recently from 2020492 to 484e239 Compare March 21, 2018 04:35
img, bbox, mask, label = example[:4]

assert_is_image(img, color=True)
assert_is_bbox(bbox)
Member

Can you pass the size argument?

'The type of bbox must be bool'
assert label.dtype == np.int32, \
'The type of label must be numpy.int32.'
assert (bbox[:, 0] < bbox[:, 2]).all(), \
Member

These assertions for bbox are not necessary because assert_is_bbox is used.

R = bbox.shape[0]

assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
Member

Please add a similar assertion for mask.


assert isinstance(label, np.ndarray), \
'label must be a numpy.ndarray.'
assert mask.dtype == bool, \
Member

bool -> np.bool

Also, the message should be "The type of mask must be np.bool."

def get_example(self, i):
img = np.random.randint(0, 256, size=(3, 48, 64))
n_bbox = np.random.randint(10, 20)
mask = np.random.randint(0, 1, size=(n_bbox, 48, 64), dtype=bool)
Member

randint(0, 2, ...)

mask.append(mask_inst)
label.append(instance_class)
bbox = np.array(bbox).astype(np.float32)
mask = np.array(mask).astype(bool)
Member

Please use np.bool instead of bool.
I prefer not to use Python's built-in types for casting numpy arrays because it is inconsistent.
We do not do array.astype(float), but do array.astype(np.float64).

Contributor Author

I think np.bool is the same as bool.
The NumPy docs do not officially list an np.bool dtype.
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html

In [1]: import numpy as np

In [2]: np.bool
Out[2]: bool

In [3]: bool
Out[3]: bool

In [4]: bool == np.bool
Out[4]: True

mask.append(mask_inst)
label.append(instance_class)
bbox = np.array(bbox).astype(np.float32)
mask = np.array(mask).astype(bool)
Member

ditto

assert inst_id != -1
assert instance_class != -1

where = np.argwhere(mask_inst)
Member

inst, ins_*, instances are used.
Please be consistent.

I would like to make a suggestion.

  • instances --> inst_ids
  • ins_img --> inst_img
  • ins_file --> inst_file

Also, it is weird that instance_class is an element of label. Same goes for mask_inst.
How about

  • instance_class --> lbl or just label.append(np.unique(seg_img[mask_inst])[0] - 1)
  • mask_inst --> m

Contributor Author

@knorth55 knorth55 Mar 21, 2018

I prefer inst_*, so I will update.

assert instance_class != -1

where = np.argwhere(mask_inst)
(y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
Member

y1,x1,y2,x2 -> y_min, x_min, y_max, x_max
This is the notation we use.
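For reference, the bounding box computation in the snippet under review can be written with the requested names as a small standalone sketch. The function name `mask_to_bbox` and the toy mask are hypothetical.

```python
import numpy as np


def mask_to_bbox(m):
    """Return (y_min, x_min, y_max, x_max) enclosing the True pixels of m."""
    where = np.argwhere(m)
    # max is inclusive, so add 1 to make y_max and x_max exclusive bounds.
    (y_min, x_min), (y_max, x_max) = where.min(0), where.max(0) + 1
    return np.array([y_min, x_min, y_max, x_max], dtype=np.float32)


m = np.zeros((6, 8), dtype=np.bool_)
m[2:4, 3:7] = True
bbox = mask_to_bbox(m)  # rows 2..3 and columns 3..6 are set
```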

assert instance_class != -1

where = np.argwhere(mask_inst)
(y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
Member

ditto

ins_img = ins_mat['GTinst']['Segmentation'][0][0].astype(np.int32)
ins_img[ins_img == 0] = -1
ins_img[ins_img == 255] = -1
return seg_img, ins_img
Member

seg_img is a short notation for "segmentation image", but ins_img is also an image with segmentations.
How about label_img?

Contributor Author

@knorth55 knorth55 Mar 21, 2018

The value of ins_img is inst_id, not label.
I don't think label_img is a good name.
I prefer inst_img.

Member

Sorry. I was talking about changing seg_img.
My proposal is to use label_img instead of seg_img. inst_img looks good.

Something like this

  • seg_file --> label_file.
  • seg_anno --> label_anno.
  • seg_img --> label_img

Contributor Author

OK, I will update

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 484e239 to bd11c7a Compare March 22, 2018 00:20
Contributor Author

knorth55 commented Mar 22, 2018

@yuyu2172
I looked into NumPy's boolean data type a little.
I think numpy.bool_ is the correct dtype.
https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.dtype.html

In [1]: import numpy as np

In [2]: a = np.array([True, False], dtype=bool)

In [3]: a.dtype.type
Out[3]: numpy.bool_

In [4]: a = np.array([True, False], dtype=np.bool)

In [5]: a.dtype.type
Out[5]: numpy.bool_

In [6]: a = np.array([True, False], dtype=np.bool_)

In [7]: a.dtype.type
Out[7]: numpy.bool_

In [8]: np.bool_ == bool
Out[8]: False

In [9]: np.bool_ == np.bool
Out[9]: False

In [10]: bool == np.bool
Out[10]: True
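The practical consequence of the session above can be condensed into a short sketch. It deliberately avoids the name `np.bool`, since the availability of that name has differed across NumPy releases, and uses only `bool` and `np.bool_`.

```python
import numpy as np

a = np.array([True, False], dtype=bool)
b = np.array([True, False], dtype=np.bool_)

# Both spellings produce the same array dtype.
same_dtype = a.dtype == b.dtype

# The scalar type behind that dtype is numpy.bool_.
scalar_type_is_np_bool = a.dtype.type is np.bool_

# Comparing the dtype against the built-in type also succeeds,
# so a check such as `mask.dtype == bool` behaves as expected.
matches_builtin = a.dtype == bool
```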

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch 2 times, most recently from d180138 to a05cc2e Compare March 22, 2018 00:40
assert_is_image(img, color=True)
assert_is_bbox(bbox, size=(H, W))

assert isinstance(bbox, np.ndarray), \
Member

This is redundant.

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch 3 times, most recently from 9529629 to 59c0756 Compare March 23, 2018 05:31
Contributor Author

@yuyu2172 Updated to pass the test.

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 59c0756 to 0c04ffe Compare March 23, 2018 14:37
Member

@yuyu2172 yuyu2172 left a comment

You do not need to rebase commits.
(IMO rebasing is not recommended because the change from the previously reviewed version is not clear.)

inst_img[inst_img == 255] = -1
return label_img, inst_img

def _prepare_data(self, label_img, inst_img):
Member

If you want to reuse this method, how about moving it to voc_utils?
I did not find the inheritance relationship between VOC and SBD effective.

The name of the method should be something like
convert_to_instance_wise_anno

Member

@yuyu2172 yuyu2172 Mar 25, 2018

Maybe convert_to_instance_segmentation is better

Member

Or image_wise_to_instance_wise

This name is nice because the name of the opposite operation is obvious.

If these operations are common, we can support them under utilities similar to bbox2loc and loc2bbox

Member

For now, you do not need to make the functions fully supported in utilities. This is just a possibility.
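To make the proposed operation concrete, here is a minimal sketch of an image-wise to instance-wise conversion. It borrows the suggested name `image_wise_to_instance_wise` but is not the merged implementation. It assumes that `inst_img` marks ignored pixels with -1 and that `label_img` uses 0 for background, so foreground labels are shifted down by one, as in the snippets reviewed earlier in this thread.

```python
import numpy as np


def image_wise_to_instance_wise(label_img, inst_img):
    """Split per-pixel label and instance images into per-instance arrays."""
    mask, label = [], []
    inst_ids = np.unique(inst_img)
    for inst_id in inst_ids[inst_ids != -1]:
        m = inst_img == inst_id
        # All pixels of one instance are assumed to share one class.
        lbl = np.unique(label_img[m])[0] - 1
        assert lbl >= 0
        mask.append(m)
        label.append(lbl)
    mask = np.array(mask).astype(np.bool_)
    label = np.array(label).astype(np.int32)
    return mask, label


label_img = np.array([[0, 0, 3, 3],
                      [5, 5, 3, 3]], dtype=np.int32)
inst_img = np.array([[-1, -1, 1, 1],
                     [2, 2, 1, 1]], dtype=np.int32)
mask, label = image_wise_to_instance_wise(label_img, inst_img)
```

Here `mask` has shape (R, H, W) with R = 2 and `label` has shape (R,). An image without any instance would need extra handling to keep those shapes, which this sketch leaves out.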

bbox, mask, label = self._prepare_data(label_img, inst_img)
return img, bbox, mask, label

def _load_label_ins(self, data_id):
Member

This should be _load_label_inst

Contributor Author

I rebased the commits because that makes the git log easier to understand.
Small "update" or "fix" commits are pretty annoying, and I did not expect this kind of endless review.

Member

"I rebased the commits because that makes the git log easier to understand."

OK. I am just pointing out that rebasing is not necessary, in case you did not know. Please feel free to do whatever you need. And yes, this PR is big.

Member

yuyu2172 commented Mar 26, 2018

Please resolve the conflicts with the master branch.

@knorth55 knorth55 force-pushed the add-voc-instance-segmentation-dataset branch from 0c04ffe to 08258bc Compare March 26, 2018 04:48
Contributor Author

I resolved the conflicts and updated the PR to follow the reviews.

@yuyu2172 yuyu2172 merged commit ff7f705 into chainer:master Mar 26, 2018
@knorth55 knorth55 deleted the add-voc-instance-segmentation-dataset branch March 26, 2018 05:22
@yuyu2172 yuyu2172 added this to the v0.9 milestone Apr 17, 2018