
Add dilate option and MultiNodeBatchNormalization to Conv2DActiv and Conv2DBNActiv #494

Merged Mar 25, 2018 · 37 commits

Changes from 20 commits

Commits
76a529a
Add dilate option to Conv2DActiv and Conv2DBNActiv, and add enable Mu…
mitmul Dec 10, 2017
4706184
Add tests for dilate option
mitmul Dec 10, 2017
f760ebf
Update travis.yml
mitmul Dec 10, 2017
18f0779
Fix .travis.yml
mitmul Dec 10, 2017
3f7f158
Skip a test if ChainerMN is not installed
mitmul Dec 10, 2017
89bb008
Fix .travis.yml
mitmul Dec 10, 2017
a7d8123
Fix .travis.yml
mitmul Dec 10, 2017
875a216
Fix .travis.yml
mitmul Dec 10, 2017
8f9f63f
Fix .travis.yml
mitmul Dec 10, 2017
96857b8
Fix .travis.yml
mitmul Dec 10, 2017
f0a8544
Fix .travis.yml
mitmul Dec 10, 2017
80be876
Update tests
mitmul Dec 10, 2017
b68099d
Update tests
mitmul Dec 10, 2017
14987d2
Update tests
mitmul Dec 10, 2017
416839d
Update tests
mitmul Dec 10, 2017
8fcba84
Update tests and travis settings
mitmul Dec 10, 2017
eec207d
Fix flake8 errors
mitmul Dec 10, 2017
4909c12
Fix tests
mitmul Dec 10, 2017
f878fb1
Fix Conv2DBNActiv
mitmul Dec 10, 2017
9604d3e
Follow reviews
mitmul Dec 11, 2017
bc29554
Update
mitmul Dec 11, 2017
0b0d28f
Follow review comments
mitmul Dec 11, 2017
2fe21b8
Fix flake8 errors
mitmul Dec 11, 2017
4357341
Merge branch 'master' of github.com:chainer/chainercv into add-conv2d…
mitmul Mar 6, 2018
b88d94e
Follow reviews
mitmul Mar 6, 2018
07c8789
Use latest ChainerMN for tests
mitmul Mar 6, 2018
59fe88b
Remove TypeError
mitmul Mar 6, 2018
3585df6
Fix tests and travis.yml
mitmul Mar 7, 2018
6b46c9b
Fix tests
mitmul Mar 7, 2018
22dd5e1
Add chainermn to environment.yml
mitmul Mar 7, 2018
eddfb10
Fix environment.yml
mitmul Mar 7, 2018
1153247
Fix environment.yml
mitmul Mar 15, 2018
024d00e
Merge branch 'master' of github.com:chainer/chainercv into add-conv2d…
mitmul Mar 16, 2018
59ce661
Fix tests in test_conv_2d_activ.py
mitmul Mar 16, 2018
0132e33
Fix a typo
mitmul Mar 16, 2018
aa858ea
Install mpi4py through conda
mitmul Mar 22, 2018
a9e908d
Update .travis.yml to export LD_LIBRARY_PATH
mitmul Mar 22, 2018
13 changes: 12 additions & 1 deletion .travis.yml
@@ -28,8 +28,15 @@ install:
   - conda info -a

   - if [[ "$OPTIONAL_MODULES" == "1" ]]; then
+      export LIBRARY_PATH="$HOME/miniconda/lib:$LIBRARY_PATH";
       conda env create -f environment.yml;
       source activate chainercv;
+      cd $HOME;
+      wget https://github.com/chainer/chainermn/archive/v1.0.0.tar.gz -O chainermn.tar.gz;

Member: Please update to the latest release.

+      tar zxf chainermn.tar.gz;
+      cd chainermn-1.0.0;
+      python setup.py install --no-nccl;
+      cd $TRAVIS_BUILD_DIR;
     else
       conda env create -f environment_minimum.yml;
       source activate chainercv_minimum;
@@ -49,4 +56,8 @@ script:
   - autopep8 -r . | tee check_autopep8
   - test ! -s check_autopep8
   - python style_checker.py .
-  - MPLBACKEND="agg" nosetests -a '!gpu,!slow' tests
+  - if [[ "$OPTIONAL_MODULES" == "1" ]]; then
+      MPLBACKEND="agg" mpiexec -n 2 nosetests -a '!gpu,!slow' tests;
+    else
+      MPLBACKEND="agg" nosetests -a '!gpu,!slow' tests;
+    fi
16 changes: 12 additions & 4 deletions chainercv/links/connection/conv_2d_activ.py
@@ -1,6 +1,7 @@
 import chainer
 from chainer.functions import relu
 from chainer.links import Convolution2D
+from chainer.links import DilatedConvolution2D


class Conv2DActiv(chainer.Chain):
@@ -40,6 +41,8 @@ class Conv2DActiv(chainer.Chain):
             :obj:`stride=s` and :obj:`stride=(s, s)` are equivalent.
         pad (int or pair of ints): Spatial padding width for input arrays.
             :obj:`pad=p` and :obj:`pad=(p, p)` are equivalent.
+        dilate (int or pair of ints): Dilation factor of filter applications.
+            :obj:`dilate=d` and :obj:`dilate=(d, d)` are equivalent.
         nobias (bool): If :obj:`True`,
             then this link does not use the bias term.
         initialW (4-D array): Initial weight value. If :obj:`None`, the default
@@ -56,17 +59,22 @@ class Conv2DActiv(chainer.Chain):
     """

     def __init__(self, in_channels, out_channels, ksize=None,
-                 stride=1, pad=0, nobias=False, initialW=None,
+                 stride=1, pad=0, dilate=1, nobias=False, initialW=None,
                  initial_bias=None, activ=relu):
         if ksize is None:
             out_channels, ksize, in_channels = in_channels, out_channels, None

         self.activ = activ
         super(Conv2DActiv, self).__init__()
         with self.init_scope():
-            self.conv = Convolution2D(
-                in_channels, out_channels, ksize, stride, pad,
-                nobias, initialW, initial_bias)
+            if dilate > 1:
+                self.conv = DilatedConvolution2D(
+                    in_channels, out_channels, ksize, stride, pad, dilate,
+                    nobias, initialW, initial_bias)
+            else:
+                self.conv = Convolution2D(
+                    in_channels, out_channels, ksize, stride, pad,
+                    nobias, initialW, initial_bias)

     def __call__(self, x):
         h = self.conv(x)
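A quick sketch of what the new branch gives you: with dilate > 1 the link wraps chainer.links.DilatedConvolution2D, so a ksize=3 kernel covers an effective 5x5 window when dilate=2. A minimal usage sketch (assuming Conv2DActiv is exported from chainercv.links as elsewhere in ChainerCV; the shapes follow the usual convolution arithmetic):

    import numpy as np

    from chainercv.links import Conv2DActiv

    x = np.zeros((1, 1, 5, 5), dtype=np.float32)
    # dilate=1: a 3x3 kernel with pad=1 keeps the 5x5 spatial size.
    l1 = Conv2DActiv(1, 1, ksize=3, pad=1, dilate=1)
    # dilate=2: the effective kernel is 5x5, so pad=1 shrinks the output to 3x3.
    l2 = Conv2DActiv(1, 1, ksize=3, pad=1, dilate=2)
    print(l1(x).shape, l2(x).shape)  # (1, 1, 5, 5) (1, 1, 3, 3)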
42 changes: 34 additions & 8 deletions chainercv/links/connection/conv_2d_bn_activ.py
@@ -2,6 +2,13 @@
 from chainer.functions import relu
 from chainer.links import BatchNormalization
 from chainer.links import Convolution2D
+from chainer.links import DilatedConvolution2D
+
+try:
+    from chainermn.links import MultiNodeBatchNormalization
+    _chainermn_available = True
+except (ImportError, TypeError):
+    _chainermn_available = False

Member: Does TypeError occur? If not, we can remove it.


class Conv2DBNActiv(chainer.Chain):
@@ -12,7 +19,10 @@ class Conv2DBNActiv(chainer.Chain):
     The arguments are the same as that of
     :class:`chainer.links.Convolution2D`
-    except for :obj:`activ` and :obj:`bn_kwargs`.
+    except for :obj:`activ`, :obj:`bn_kwargs`, and :obj:`comm`.
+    :obj:`comm` is a communicator of ChainerMN which is used for
+    :class:`chainermn.links.MultiNodeBatchNormalization`. If
+    :obj:`None` is given to the argument :obj:`comm`, :obj:`BatchNormalization` link from Chainer is used.

Member: Use :class:`chainer.links.BatchNormalization`.

     Note that the default value for the :obj:`nobias`
     is changed to :obj:`True`.

@@ -43,6 +53,8 @@ class Conv2DBNActiv(chainer.Chain):
             :obj:`stride=s` and :obj:`stride=(s, s)` are equivalent.
         pad (int or pair of ints): Spatial padding width for input arrays.
             :obj:`pad=p` and :obj:`pad=(p, p)` are equivalent.
+        dilate (int or pair of ints): Dilation factor of filter applications.
+            :obj:`dilate=d` and :obj:`dilate=(d, d)` are equivalent.
         nobias (bool): If :obj:`True`,
             then this link does not use the bias term.
         initialW (4-D array): Initial weight value. If :obj:`None`, the default
@@ -56,23 +68,37 @@ class Conv2DBNActiv(chainer.Chain):
         activ (callable): An activation function. The default value is
             :func:`chainer.functions.relu`.
         bn_kwargs (dict): Keyword arguments passed to initialize
-            :class:`chainer.links.BatchNormalization`.
+            :class:`chainer.links.BatchNormalization`. If a ChainerMN
+            communicator (:class:`~chainermn.communicators.CommunicatorBase`)
+            is given with the key :obj:`comm`,
+            :obj:`~chainermn.links.MultiNodeBatchNormalization` will be used

Member: :class:`chainermn.links.MultiNodeBatchNormalization`

+            for the batch normalization. Otherwise,
+            :obj:`~chainer.links.BatchNormalization` will be used.

Member: I prefer to take comm as an element of bn_kwargs because comm is used only for batchnorm. If 'comm' is in bn_kwargs, it uses MultiNodeBatchNormalization.

Member (author): OK, I changed it.

     """

     def __init__(self, in_channels, out_channels, ksize=None,
-                 stride=1, pad=0, nobias=True, initialW=None,
-                 initial_bias=None, activ=relu, bn_kwargs=dict()):
+                 stride=1, pad=0, dilate=1, nobias=True, initialW=None,
+                 initial_bias=None, activ=relu, bn_kwargs=dict(), comm=None):

Member: Can you remove the comm option?

         if ksize is None:
             out_channels, ksize, in_channels = in_channels, out_channels, None

         self.activ = activ
         super(Conv2DBNActiv, self).__init__()
         with self.init_scope():
-            self.conv = Convolution2D(
-                in_channels, out_channels, ksize, stride, pad,
-                nobias, initialW, initial_bias)
-            self.bn = BatchNormalization(out_channels, **bn_kwargs)
+            if dilate > 1:
+                self.conv = DilatedConvolution2D(
+                    in_channels, out_channels, ksize, stride, pad, dilate,
+                    nobias, initialW, initial_bias)
+            else:
+                self.conv = Convolution2D(
+                    in_channels, out_channels, ksize, stride, pad,
+                    nobias, initialW, initial_bias)
+            if 'comm' in bn_kwargs and _chainermn_available:
+                self.bn = MultiNodeBatchNormalization(
+                    out_channels, [**bn_kwargs)

Member: [ ?

+            else:
+                self.bn = BatchNormalization(out_channels, **bn_kwargs)

     def __call__(self, x):
         h = self.conv(x)
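For reference, a hedged sketch of how the batch-normalization selection presumably reads once the stray [ flagged above is removed (this assumes MultiNodeBatchNormalization accepts the communicator as its comm keyword, so expanding bn_kwargs that carries a 'comm' key is valid):

    # Hypothetical corrected form of the branch flagged above.
    if 'comm' in bn_kwargs and _chainermn_available:
        # bn_kwargs carries the ChainerMN communicator under the 'comm' key.
        self.bn = MultiNodeBatchNormalization(out_channels, **bn_kwargs)
    else:
        self.bn = BatchNormalization(out_channels, **bn_kwargs)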
3 changes: 3 additions & 0 deletions environment.yml
@@ -2,6 +2,8 @@ name: chainercv
 channels:
   - !!python/unicode
     'menpo'
+  - !!python/unicode
+    'mpi4py'
   - !!python/unicode
     'defaults'
 dependencies:
@@ -10,3 +12,4 @@ dependencies:
   - matplotlib
   - numpy
   - Pillow
+  - openmpi
16 changes: 9 additions & 7 deletions tests/links_tests/connection_tests/test_conv_2d_activ.py
@@ -16,6 +16,7 @@ def _add_one(x):


 @testing.parameterize(*testing.product({
+    'dilate': [1, 2],
     'args_style': ['explicit', 'None', 'omit'],
     'activ': ['relu', 'add_one']
 }))
@@ -44,19 +45,19 @@ def setUp(self):
         if self.args_style == 'explicit':
             self.l = Conv2DActiv(
                 self.in_channels, self.out_channels, self.ksize,
-                self.stride, self.pad,
+                self.stride, self.pad, self.dilate,
                 initialW=initialW, initial_bias=initial_bias,
                 activ=activ)
         elif self.args_style == 'None':
             self.l = Conv2DActiv(
                 None, self.out_channels, self.ksize, self.stride, self.pad,
-                initialW=initialW, initial_bias=initial_bias,
+                self.dilate, initialW=initialW, initial_bias=initial_bias,
                 activ=activ)
         elif self.args_style == 'omit':
             self.l = Conv2DActiv(
                 self.out_channels, self.ksize, stride=self.stride,
-                pad=self.pad, initialW=initialW, initial_bias=initial_bias,
-                activ=activ)
+                pad=self.pad, dilate=self.dilate, initialW=initialW,
+                initial_bias=initial_bias, activ=activ)

     def check_forward(self, x_data):
         x = chainer.Variable(x_data)
@@ -65,12 +66,13 @@ def check_forward(self, x_data):
         self.assertIsInstance(y, chainer.Variable)
         self.assertIsInstance(y.array, self.l.xp.ndarray)

+        _x_data = x_data if self.dilate == 1 else x_data[:, :, 1:-1, 1:-1]
         if self.activ == 'relu':
             np.testing.assert_almost_equal(
-                cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0))
+                cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(_x_data), 0))
         elif self.activ == 'add_one':
             np.testing.assert_almost_equal(
-                cuda.to_cpu(y.array), cuda.to_cpu(x_data) + 1)
+                cuda.to_cpu(y.array), cuda.to_cpu(_x_data) + 1)

     def test_forward_cpu(self):
         self.check_forward(self.x)
@@ -83,7 +85,7 @@ def test_forward_gpu(self):
     def check_backward(self, x_data, y_grad):
         x = chainer.Variable(x_data)
         y = self.l(x)
-        y.grad = y_grad
+        y.grad = y_grad if self.dilate == 1 else y_grad[:, :, 1:-1, 1:-1]

Member: How about

    if self.dilate == 1:
        y.grad = y_grad
    elif self.dilate == 2:
        y.grad = y_grad[:, :, 1:-1, 1:-1]

This is better because y_grad[:, :, 1:-1, 1:-1] would not work when dilate > 2.

Member (author): OK, I fixed it, but there's no test case here for dilate > 2.

         y.backward()

     def test_backward_cpu(self):
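The cropping in these tests follows from the shape arithmetic: with a 5x5 input, ksize=3, pad=1, and dilate=2, the effective kernel is 5x5 and the output shrinks to 3x3, which is exactly x_data[:, :, 1:-1, 1:-1]. A small self-contained sketch of the formula, out = in + 2*pad - (ksize + (ksize - 1)*(dilate - 1)) + 1, assuming stride 1:

    def out_size(in_size, ksize, pad, dilate, stride=1):
        # The effective kernel size grows with dilation.
        k_eff = ksize + (ksize - 1) * (dilate - 1)
        return (in_size + 2 * pad - k_eff) // stride + 1

    print(out_size(5, 3, 1, dilate=1))  # 5 -> output matches the input
    print(out_size(5, 3, 1, dilate=2))  # 3 -> matches x_data[:, :, 1:-1, 1:-1]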
91 changes: 84 additions & 7 deletions tests/links_tests/connection_tests/test_conv_2d_bn_activ.py
@@ -15,7 +15,15 @@ def _add_one(x):
     return x + 1


+try:
+    from chainermn import create_communicator
+    _chainermn_available = True
+except (ImportError, TypeError):
+    _chainermn_available = False

Member: Remove TypeError if not necessary.

+
+
 @testing.parameterize(*testing.product({
+    'dilate': [1, 2],
     'args_style': ['explicit', 'None', 'omit'],
     'activ': ['relu', 'add_one'],
 }))
@@ -45,19 +53,19 @@ def setUp(self):
         if self.args_style == 'explicit':
             self.l = Conv2DBNActiv(
                 self.in_channels, self.out_channels, self.ksize,
-                self.stride, self.pad,
+                self.stride, self.pad, self.dilate,
                 initialW=initialW, initial_bias=initial_bias,
                 activ=activ, bn_kwargs=bn_kwargs)
         elif self.args_style == 'None':
             self.l = Conv2DBNActiv(
                 None, self.out_channels, self.ksize, self.stride, self.pad,
-                initialW=initialW, initial_bias=initial_bias,
+                self.dilate, initialW=initialW, initial_bias=initial_bias,
                 activ=activ, bn_kwargs=bn_kwargs)
         elif self.args_style == 'omit':
             self.l = Conv2DBNActiv(
                 self.out_channels, self.ksize, stride=self.stride,
-                pad=self.pad, initialW=initialW, initial_bias=initial_bias,
-                activ=activ, bn_kwargs=bn_kwargs)
+                pad=self.pad, dilate=self.dilate, initialW=initialW,
+                initial_bias=initial_bias, activ=activ, bn_kwargs=bn_kwargs)

     def check_forward(self, x_data):
         x = chainer.Variable(x_data)
@@ -70,14 +78,15 @@ def check_forward(self, x_data):
         self.assertIsInstance(y, chainer.Variable)
         self.assertIsInstance(y.array, self.l.xp.ndarray)

+        _x_data = x_data if self.dilate == 1 else x_data[:, :, 1:-1, 1:-1]

Member: ditto

         if self.activ == 'relu':
             np.testing.assert_almost_equal(
-                cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0),
+                cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(_x_data), 0),
                 decimal=4
             )
         elif self.activ == 'add_one':
             np.testing.assert_almost_equal(
-                cuda.to_cpu(y.array), cuda.to_cpu(x_data) + 1,
+                cuda.to_cpu(y.array), cuda.to_cpu(_x_data) + 1,
                 decimal=4
             )

@@ -92,7 +101,7 @@ def check_backward(self, x_data, y_grad):
     def check_backward(self, x_data, y_grad):
         x = chainer.Variable(x_data)
         y = self.l(x)
-        y.grad = y_grad
+        y.grad = y_grad if self.dilate == 1 else y_grad[:, :, 1:-1, 1:-1]

Member: ditto

         y.backward()

     def test_backward_cpu(self):
@@ -104,4 +113,72 @@ def test_backward_gpu(self):
         self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))


+@unittest.skipIf(not _chainermn_available, 'ChainerMN is not installed')
+class TestConv2DMultiNodeBNActiv(unittest.TestCase):
+
+    in_channels = 1
+    out_channels = 1
+    ksize = 3
+    stride = 1
+    pad = 1
+    dilate = 1
+
+    def setUp(self):
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
+                            dtype=np.float32).reshape((1, 1, 3, 3))
+        bn_kwargs = {'decay': 0.8}
+        initial_bias = 0
+        comm = create_communicator('naive')
+        activ = relu
+        self.l = Conv2DBNActiv(
+            self.in_channels, self.out_channels, self.ksize, self.stride,
+            self.pad, self.dilate, initialW=initialW,
+            initial_bias=initial_bias, activ=activ, bn_kwargs=bn_kwargs,
+            comm=comm)

Member: Include comm in bn_kwargs.

+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        # Make the batch normalization the identity function.
+        self.l.bn.avg_var[:] = 1
+        self.l.bn.avg_mean[:] = 0
+        with chainer.using_config('train', False):
+            y = self.l(x)
+
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.array, self.l.xp.ndarray)
+
+        np.testing.assert_almost_equal(
+            cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0),
+            decimal=4
+        )
+
+    def test_multi_node_batch_normalization_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_multi_node_batch_normalization_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        y.grad = y_grad
+        y.backward()
+
+    def test_multi_node_batch_normalization_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_multi_node_batch_normalization_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))

testing.run_module(__name__, __file__)
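Following the "Include comm in bn_kwargs" suggestion above, construction would presumably look like the sketch below. Assumptions to verify: Conv2DBNActiv is exported from chainercv.links, a 'naive' ChainerMN communicator works in a single process, mpi4py is installed, and the constructor dispatches to MultiNodeBatchNormalization when bn_kwargs contains 'comm' (as the fixed selection logic intends).

    import numpy as np

    from chainercv.links import Conv2DBNActiv
    from chainermn import create_communicator

    # Hypothetical revision of the test setup: comm moves into bn_kwargs.
    comm = create_communicator('naive')
    link = Conv2DBNActiv(1, 1, ksize=3, stride=1, pad=1, dilate=1,
                         bn_kwargs={'decay': 0.8, 'comm': comm})
    x = np.random.uniform(-1, 1, (5, 1, 5, 5)).astype(np.float32)
    y = link(x)  # the bn step should run MultiNodeBatchNormalization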