Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Commit e4b5ea0

Browse files
committed
Merge branch 'cityscapes-label-variable-names' of https://github.com/yuyu2172/chainercv into add-pspnet-infer
2 parents 07c8594 + 74fab33 commit e4b5ea0

23 files changed

+288
-279
lines changed

.travis.yml

+4-3
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ notifications:
1313
email: false
1414

1515
install:
16+
# We use the older version of conda because there is a bug in version 4.3.27.
1617
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
17-
wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;
18+
wget https://repo.continuum.io/miniconda/Miniconda2-4.3.21-Linux-x86_64.sh -O miniconda.sh;
1819
else
19-
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
20+
wget https://repo.continuum.io/miniconda/Miniconda3-4.3.21-Linux-x86_64.sh -O miniconda.sh;
2021
fi
2122
- bash miniconda.sh -b -p $HOME/miniconda
2223
- export PATH="$HOME/miniconda/bin:$PATH"
2324
- hash -r
2425
- conda config --set always_yes yes --set changeps1 no
25-
- conda update -q conda
26+
# - conda update -q conda
2627
# Useful for debugging any issues with conda
2728
- conda info -a
2829

chainercv/datasets/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from chainercv.datasets.camvid.camvid_dataset import camvid_label_names # NOQA
44
from chainercv.datasets.camvid.camvid_dataset import CamVidDataset # NOQA
55
from chainercv.datasets.cityscapes.cityscapes_semantic_segmentation_dataset import CityscapesSemanticSegmentationDataset # NOQA
6-
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_label_colors # NOQA
7-
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_label_names # NOQA
6+
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_colors # NOQA
7+
from chainercv.datasets.cityscapes.cityscapes_utils import cityscapes_semantic_segmentation_label_names # NOQA
88
from chainercv.datasets.cub.cub_keypoint_dataset import CUBKeypointDataset # NOQA
99
from chainercv.datasets.cub.cub_label_dataset import CUBLabelDataset # NOQA
1010
from chainercv.datasets.cub.cub_utils import cub_label_names # NOQA
11-
from chainercv.datasets.directory_parsing_classification_dataset import directory_parsing_label_names # NOQA
12-
from chainercv.datasets.directory_parsing_classification_dataset import DirectoryParsingClassificationDataset # NOQA
11+
from chainercv.datasets.directory_parsing_label_dataset import directory_parsing_label_names # NOQA
12+
from chainercv.datasets.directory_parsing_label_dataset import DirectoryParsingLabelDataset # NOQA
1313
from chainercv.datasets.online_products.online_products_dataset import OnlineProductsDataset # NOQA
1414
from chainercv.datasets.transform_dataset import TransformDataset # NOQA
1515
from chainercv.datasets.voc.voc_detection_dataset import VOCDetectionDataset # NOQA

chainercv/datasets/cityscapes/cityscapes_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
Label('licenseplate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
4848
])
4949

50-
cityscapes_label_names = tuple(
50+
cityscapes_semantic_segmentation_label_names = tuple(
5151
l.name for l in cityscapes_labels if not l.ignoreInEval)
5252

53-
cityscapes_label_colors = tuple(
53+
cityscapes_semantic_segmentation_label_colors = tuple(
5454
l.color for l in cityscapes_labels if not l.ignoreInEval)

chainercv/datasets/directory_parsing_classification_dataset.py chainercv/datasets/directory_parsing_label_dataset.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def directory_parsing_label_names(root, numerical_sort=False):
1212
layer below the root directory.
1313
1414
The label names can be used together with
15-
:class:`chainercv.datasets.DirectoryParsingClassificationDataset`.
15+
:class:`chainercv.datasets.DirectoryParsingLabelDataset`.
1616
The index of a label name corresponds to the label id
1717
that is used by the dataset to refer the label.
1818
@@ -44,8 +44,8 @@ def _check_img_ext(path):
4444
extension in img_extensions)
4545

4646

47-
def _parse_classification_dataset(root, label_names,
48-
check_img_file=_check_img_ext):
47+
def _parse_label_dataset(root, label_names,
48+
check_img_file=_check_img_ext):
4949
img_paths = list()
5050
labels = list()
5151
for label, label_name in enumerate(label_names):
@@ -65,8 +65,8 @@ def _parse_classification_dataset(root, label_names,
6565
return img_paths, np.array(labels, np.int32)
6666

6767

68-
class DirectoryParsingClassificationDataset(chainer.dataset.DatasetMixin):
69-
"""A classification dataset for directories whose names are label names.
68+
class DirectoryParsingLabelDataset(chainer.dataset.DatasetMixin):
69+
"""A label dataset whose label names are the names of the subdirectories.
7070
7171
The label names are the names of the directories that locate a layer below
7272
the root directory.
@@ -91,8 +91,8 @@ class DirectoryParsingClassificationDataset(chainer.dataset.DatasetMixin):
9191
--- class_1
9292
|-- img_0.png
9393
94-
>>> from chainercv.dataset import DirectoryParsingClassificationDataset
95-
>>> dataset = DirectoryParsingClassificationDataset('root')
94+
>>> from chainercv.datasets import DirectoryParsingLabelDataset
95+
>>> dataset = DirectoryParsingLabelDataset('root')
9696
>>> dataset.paths
9797
['root/class_0/img_0.png', 'root/class_0/img_1.png',
9898
'root_class_1/img_0.png']
@@ -123,7 +123,7 @@ def __init__(self, root, check_img_file=None, color=True,
123123
if check_img_file is None:
124124
check_img_file = _check_img_ext
125125

126-
self.img_paths, self.labels = _parse_classification_dataset(
126+
self.img_paths, self.labels = _parse_label_dataset(
127127
root, label_names, check_img_file)
128128

129129
def __len__(self):

chainercv/links/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from chainercv.links.connection.conv_2d_bn_activ import Conv2DBNActiv # NOQA
33

44
from chainercv.links.model.feature_predictor import FeaturePredictor # NOQA
5+
from chainercv.links.model.pickable_sequential_chain import PickableSequentialChain # NOQA
56
from chainercv.links.model.pixelwise_softmax_classifier import PixelwiseSoftmaxClassifier # NOQA
6-
from chainercv.links.model.sequential_feature_extractor import SequentialFeatureExtractor # NOQA
77

88
from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA
99
from chainercv.links.model.pspnet.pspnet import PSPNet # NOQA

chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self,
109109
vgg_initialW = chainer.initializers.constant.Zero()
110110

111111
extractor = VGG16(initialW=vgg_initialW)
112-
extractor.feature_names = 'conv5_3'
112+
extractor.pick = 'conv5_3'
113113
# Delete all layers after conv5_3.
114114
extractor.remove_unused()
115115
rpn = RegionProposalNetwork(

chainercv/links/model/feature_predictor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class FeaturePredictor(chainer.Chain):
1414

15-
"""Wrapper that adds a prediction method to a feature extraction model.
15+
"""Wrapper that adds a prediction method to a feature extraction link.
1616
1717
The :meth:`predict` takes three steps to make a prediction.
1818
@@ -28,7 +28,7 @@ class FeaturePredictor(chainer.Chain):
2828
>>> model = FeaturePredictor(base_model, 224, 256)
2929
>>> prob = model.predict([img])
3030
# Predicting multiple features
31-
>>> model.extractor.feature_names = ['conv5_3', 'fc7']
31+
>>> model.extractor.pick = ['conv5_3', 'fc7']
3232
>>> conv5_3, fc7 = model.predict([img])
3333
3434
When :obj:`self.crop == 'center'`, :meth:`predict` extracts features from
@@ -41,7 +41,7 @@ class FeaturePredictor(chainer.Chain):
4141
crops.
4242
4343
Args:
44-
extractor: A feature extraction model. This is a callable chain
44+
extractor: A feature extraction link. This is a callable chain
4545
that takes a batch of images and returns a variable or a
4646
tuple of variables.
4747
crop_size (int or tuple): The height and the width of an image after
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import chainer
2+
3+
4+
class PickableSequentialChain(chainer.Chain):
5+
6+
"""A sequential chain that can pick intermediate layers.
7+
8+
Callable objects, such as :class:`chainer.Link` and
9+
:class:`chainer.Function`, can be registered to this chain with
10+
:meth:`init_scope`.
11+
This chain keeps the order of registrations and :meth:`__call__`
12+
executes callables in that order.
13+
A :class:`chainer.Link` object in the sequence will be added as
14+
a child link of this link.
15+
16+
:meth:`__call__` returns single or multiple layers that are picked up
17+
through a stream of computation.
18+
These layers can be specified by :obj:`pick`, which contains
19+
the names of the layers that are collected.
20+
When :obj:`pick` is a string, single layer is returned.
21+
When :obj:`pick` is an iterable of strings, a tuple of layers
22+
is returned. The order of the layers is the same as the order of
23+
the strings in :obj:`pick`.
24+
When :obj:`pick` is :obj:`None`, the last layer is returned.
25+
26+
Examples:
27+
28+
>>> import chainer.functions as F
29+
>>> import chainer.links as L
30+
>>> model = PickableSequentialChain()
31+
>>> with model.init_scope():
32+
>>> model.l1 = L.Linear(None, 1000)
33+
>>> model.l1_relu = F.relu
34+
>>> model.l2 = L.Linear(None, 1000)
35+
>>> model.l2_relu = F.relu
36+
>>> model.l3 = L.Linear(None, 10)
37+
>>> # This is layer l3
38+
>>> layer3 = model(x)
39+
>>> # The layers to be collected can be changed.
40+
>>> model.pick = ('l2_relu', 'l1_relu')
41+
>>> # These are layers l2_relu and l1_relu.
42+
>>> layer2, layer1 = model(x)
43+
44+
Parameters:
45+
pick (string or iterable of strings):
46+
Names of layers that are collected during
47+
the forward pass.
48+
layer_names (iterable of strings):
49+
Names of layers that can be collected from
50+
this chain. The names are ordered in the order
51+
of computation.
52+
53+
"""
54+
55+
def __init__(self):
56+
super(PickableSequentialChain, self).__init__()
57+
self.layer_names = list()
58+
# Two attributes are initialized by the setter of pick.
59+
# self._pick -> None
60+
# self._return_tuple -> False
61+
self.pick = None
62+
63+
def __setattr__(self, name, value):
64+
super(PickableSequentialChain, self).__setattr__(name, value)
65+
if self.within_init_scope and callable(value):
66+
self.layer_names.append(name)
67+
68+
def __delattr__(self, name):
69+
if self._pick and name in self._pick:
70+
raise AttributeError(
71+
'layer {:s} is registered to pick.'.format(name))
72+
super(PickableSequentialChain, self).__delattr__(name)
73+
try:
74+
self.layer_names.remove(name)
75+
except ValueError:
76+
pass
77+
78+
@property
79+
def pick(self):
80+
if self._pick is None:
81+
return None
82+
83+
if self._return_tuple:
84+
return self._pick
85+
else:
86+
return self._pick[0]
87+
88+
@pick.setter
89+
def pick(self, pick):
90+
if pick is None:
91+
self._return_tuple = False
92+
self._pick = None
93+
return
94+
95+
if (not isinstance(pick, str) and
96+
all(isinstance(name, str) for name in pick)):
97+
return_tuple = True
98+
else:
99+
return_tuple = False
100+
pick = (pick,)
101+
if any(name not in self.layer_names for name in pick):
102+
raise ValueError('Invalid layer name')
103+
104+
self._return_tuple = return_tuple
105+
self._pick = tuple(pick)
106+
107+
def remove_unused(self):
108+
"""Delete all layers that are not needed for the forward pass.
109+
110+
"""
111+
if self._pick is None:
112+
return
113+
114+
# The biggest index among indices of the layers that are included
115+
# in pick.
116+
last_index = max(self.layer_names.index(name) for name in self._pick)
117+
for name in self.layer_names[last_index + 1:]:
118+
delattr(self, name)
119+
120+
def __call__(self, x):
121+
"""Forward this model.
122+
123+
Args:
124+
x (chainer.Variable or array): Input to the model.
125+
126+
Returns:
127+
chainer.Variable or tuple of chainer.Variable:
128+
The returned layers are determined by :obj:`pick`.
129+
130+
"""
131+
if self._pick is None:
132+
pick = (self.layer_names[-1],)
133+
else:
134+
pick = self._pick
135+
136+
# The biggest index among indices of the layers that are included
137+
# in pick.
138+
last_index = max(self.layer_names.index(name) for name in pick)
139+
140+
layers = dict()
141+
h = x
142+
for name in self.layer_names[:last_index + 1]:
143+
h = self[name](h)
144+
if name in pick:
145+
layers[name] = h
146+
147+
if self._return_tuple:
148+
layers = tuple(layers[name] for name in pick)
149+
else:
150+
layers = list(layers.values())[0]
151+
return layers

0 commit comments

Comments
 (0)