Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

Chainer 3.0.0 #445

Merged
merged 19 commits into from
Nov 28, 2017
9 changes: 5 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -5,10 +5,10 @@ python:
- "2.7"
- "3.6"
env:
- CHAINER_VERSION=">=2.0" OPTIONAL_MODULES=0
- CHAINER_VERSION=">=2.0" OPTIONAL_MODULES=1
- CHAINER_VERSION="==3.0.0rc1" OPTIONAL_MODULES=0
- CHAINER_VERSION="==3.0.0rc1" OPTIONAL_MODULES=1
- CHAINER_VERSION=">=3.0" OPTIONAL_MODULES=0
- CHAINER_VERSION=">=3.0" OPTIONAL_MODULES=1
- CHAINER_VERSION="==4.0.0b1" OPTIONAL_MODULES=0
- CHAINER_VERSION="==4.0.0b1" OPTIONAL_MODULES=1
notifications:
email: false

@@ -44,6 +44,7 @@ script:
- pip install autopep8
- pip install mock
- pip install nose
- pip install pytest
- flake8 .
- autopep8 -r . | tee check_autopep8
- test ! -s check_autopep8
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -43,8 +43,10 @@ For additional features

Environments under Python 2.7.12 and 3.6.0 are tested.

+ The master branch will work on both the stable version (v2) and the development version (v3).
+ For users using Chainer v1, please use version `0.4.11`, which can be installed by `pip install chainercv==0.4.11`. This branch is unmaintained.
+ The master branch is designed to work on Chainer v3 (the stable version) and v4 (the development version).
+ The following branches are kept for the previous version of Chainer. Note that these branches are unmaintained.
+ `0.4.11` (for Chainer v1). It can be installed by `pip install chainercv==0.4.11`.
+ `0.7` (for Chainer v2). It can be installed by `pip install chainercv==0.7`.

# Data Conventions

6 changes: 3 additions & 3 deletions chainercv/links/model/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -288,8 +288,8 @@ def predict(self, imgs):
roi_cls_locs, roi_scores, rois, _ = self.__call__(
img_var, scale=scale)
# We are assuming that batch size is 1.
roi_cls_loc = roi_cls_locs.data
roi_score = roi_scores.data
roi_cls_loc = roi_cls_locs.array
roi_score = roi_scores.array
roi = rois / scale

# Convert predictions to bounding boxes in image coordinates.
@@ -308,7 +308,7 @@ def predict(self, imgs):
cls_bbox[:, 0::2] = self.xp.clip(cls_bbox[:, 0::2], 0, size[0])
cls_bbox[:, 1::2] = self.xp.clip(cls_bbox[:, 1::2], 0, size[1])

prob = F.softmax(roi_score).data
prob = F.softmax(roi_score).array

raw_cls_bbox = cuda.to_cpu(cls_bbox)
raw_prob = cuda.to_cpu(prob)
8 changes: 4 additions & 4 deletions chainercv/links/model/faster_rcnn/faster_rcnn_train_chain.py
Original file line number Diff line number Diff line change
@@ -90,11 +90,11 @@ def __call__(self, imgs, bboxes, labels, scale):

"""
if isinstance(bboxes, chainer.Variable):
bboxes = bboxes.data
bboxes = bboxes.array
if isinstance(labels, chainer.Variable):
labels = labels.data
labels = labels.array
if isinstance(scale, chainer.Variable):
scale = scale.data
scale = scale.array
scale = np.asscalar(cuda.to_cpu(scale))
n = bboxes.shape[0]
if n != 1:
@@ -151,7 +151,7 @@ def _smooth_l1_loss(x, t, in_weight, sigma):
sigma2 = sigma ** 2
diff = in_weight * (x - t)
abs_diff = F.absolute(diff)
flag = (abs_diff.data < (1. / sigma2)).astype(np.float32)
flag = (abs_diff.array < (1. / sigma2)).astype(np.float32)

y = (flag * (sigma2 / 2.) * F.square(diff) +
(1 - flag) * (abs_diff - 0.5 / sigma2))
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ def __call__(self, x, img_size, scale=1.):
roi_indices = list()
for i in range(n):
roi = self.proposal_layer(
rpn_locs[i].data, rpn_fg_scores[i].data, anchor, img_size,
rpn_locs[i].array, rpn_fg_scores[i].array, anchor, img_size,
scale=scale)
batch_index = i * self.xp.ones((len(roi),), dtype=np.int32)
rois.append(roi)
4 changes: 2 additions & 2 deletions chainercv/links/model/feature_predictor.py
Original file line number Diff line number Diff line change
@@ -156,13 +156,13 @@ def predict(self, imgs):
if isinstance(features, tuple):
output = list()
for feature in features:
feature = feature.data
feature = feature.array
if n_crop > 1:
feature = self._average_crops(feature, n_crop)
output.append(cuda.to_cpu(feature))
output = tuple(output)
else:
output = cuda.to_cpu(features.data)
output = cuda.to_cpu(features.array)
if n_crop > 1:
output = self._average_crops(output, n_crop)

12 changes: 6 additions & 6 deletions chainercv/links/model/segnet/segnet_basic.py
Original file line number Diff line number Diff line change
@@ -10,9 +10,9 @@
from chainercv.utils import download_model


def _without_cudnn(f, x):
def _pool_without_cudnn(p, x):
with chainer.using_config('use_cudnn', 'never'):
return f(x)
return p.apply((x,))[0]


class SegNetBasic(chainer.Chain):
@@ -135,10 +135,10 @@ def __call__(self, x):
p3 = F.MaxPooling2D(2, 2)
p4 = F.MaxPooling2D(2, 2)
h = F.local_response_normalization(x, 5, 1, 1e-4 / 5., 0.75)
h = _without_cudnn(p1, F.relu(self.conv1_bn(self.conv1(h))))
h = _without_cudnn(p2, F.relu(self.conv2_bn(self.conv2(h))))
h = _without_cudnn(p3, F.relu(self.conv3_bn(self.conv3(h))))
h = _without_cudnn(p4, F.relu(self.conv4_bn(self.conv4(h))))
h = _pool_without_cudnn(p1, F.relu(self.conv1_bn(self.conv1(h))))
h = _pool_without_cudnn(p2, F.relu(self.conv2_bn(self.conv2(h))))
h = _pool_without_cudnn(p3, F.relu(self.conv3_bn(self.conv3(h))))
h = _pool_without_cudnn(p4, F.relu(self.conv4_bn(self.conv4(h))))
h = self._upsampling_2d(h, p4)
h = self.conv_decode4_bn(self.conv_decode4(h))
h = self._upsampling_2d(h, p3)
3 changes: 1 addition & 2 deletions chainercv/links/model/ssd/multibox_coder.py
Original file line number Diff line number Diff line change
@@ -163,8 +163,7 @@ def encode(self, bbox, label, iou_thresh=0.5):
masked_iou[:, j] = 0

mask = xp.logical_and(index < 0, iou.max(axis=1) >= iou_thresh)
if xp.count_nonzero(mask) > 0:
index[mask] = iou[mask].argmax(axis=1)
index[mask] = iou[mask].argmax(axis=1)

mb_bbox = bbox[index].copy()
# (y_min, x_min, y_max, x_max) -> (y_min, x_min, height, width)
27 changes: 10 additions & 17 deletions chainercv/links/model/ssd/multibox_loss.py
Original file line number Diff line number Diff line change
@@ -16,12 +16,9 @@ def _elementwise_softmax_cross_entropy(x, t):


def _hard_negative(x, positive, k):
xp = chainer.cuda.get_array_module(x, positive)
x = chainer.cuda.to_cpu(x)
positive = chainer.cuda.to_cpu(positive)
rank = (x * (positive - 1)).argsort(axis=1).argsort(axis=1)
hard_negative = rank < (positive.sum(axis=1) * k)[:, np.newaxis]
return xp.array(hard_negative)
return hard_negative


def multibox_loss(mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, k):
@@ -64,18 +61,14 @@ def multibox_loss(mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, k):
This function returns two :obj:`chainer.Variable`: :obj:`loc_loss` and
:obj:`conf_loss`.
"""
if not isinstance(mb_locs, chainer.Variable):
mb_locs = chainer.Variable(mb_locs)
if not isinstance(mb_confs, chainer.Variable):
mb_confs = chainer.Variable(mb_confs)
if not isinstance(gt_mb_locs, chainer.Variable):
gt_mb_locs = chainer.Variable(gt_mb_locs)
if not isinstance(gt_mb_labels, chainer.Variable):
gt_mb_labels = chainer.Variable(gt_mb_labels)

xp = chainer.cuda.get_array_module(gt_mb_labels.data)

positive = gt_mb_labels.data > 0
mb_locs = chainer.as_variable(mb_locs)
mb_confs = chainer.as_variable(mb_confs)
gt_mb_locs = chainer.as_variable(gt_mb_locs)
gt_mb_labels = chainer.as_variable(gt_mb_labels)

xp = chainer.cuda.get_array_module(gt_mb_labels.array)

positive = gt_mb_labels.array > 0
n_positive = positive.sum()
if n_positive == 0:
z = chainer.Variable(xp.zeros((), dtype=np.float32))
@@ -87,7 +80,7 @@ def multibox_loss(mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, k):
loc_loss = F.sum(loc_loss) / n_positive

conf_loss = _elementwise_softmax_cross_entropy(mb_confs, gt_mb_labels)
hard_negative = _hard_negative(conf_loss.data, positive, k)
hard_negative = _hard_negative(conf_loss.array, positive, k)
conf_loss *= xp.logical_or(positive, hard_negative).astype(conf_loss.dtype)
conf_loss = F.sum(conf_loss) / n_positive

2 changes: 1 addition & 1 deletion chainercv/links/model/ssd/ssd.py
Original file line number Diff line number Diff line change
@@ -203,7 +203,7 @@ def predict(self, imgs):
chainer.function.no_backprop_mode():
x = chainer.Variable(self.xp.stack(x))
mb_locs, mb_confs = self(x)
mb_locs, mb_confs = mb_locs.data, mb_confs.data
mb_locs, mb_confs = mb_locs.array, mb_confs.array

bboxes = list()
labels = list()
4 changes: 1 addition & 3 deletions chainercv/utils/bbox/non_maximum_suppression.py
Original file line number Diff line number Diff line change
@@ -105,9 +105,7 @@ def _non_maximum_suppression_gpu(bbox, thresh, score=None, limit=None):
n_bbox = bbox.shape[0]

if score is not None:
# CuPy does not currently support argsort.
order = cuda.to_cpu(score).argsort()[::-1].astype(np.int32)
order = cuda.to_gpu(order)
order = score.argsort()[::-1].astype(np.int32)
else:
order = cp.arange(n_bbox, dtype=np.int32)

19 changes: 1 addition & 18 deletions examples/faster_rcnn/train.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import numpy as np

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer import training
from chainer.training import extensions
@@ -17,24 +18,6 @@
from chainercv import transforms


class ConcatenatedDataset(chainer.dataset.DatasetMixin):

def __init__(self, *datasets):
self._datasets = datasets

def __len__(self):
return sum(len(dataset) for dataset in self._datasets)

def get_example(self, i):
if i < 0:
raise IndexError
for dataset in self._datasets:
if i < len(dataset):
return dataset[i]
i -= len(dataset)
raise IndexError


class Transform(object):

def __init__(self, faster_rcnn):
6 changes: 3 additions & 3 deletions examples/ssd/caffe2npz.py
Original file line number Diff line number Diff line change
@@ -53,11 +53,11 @@ def __setattr__(self, name, value):

if new_name == 'extractor/conv1_1':
# BGR -> RGB
value.W.data[:, ::-1] = value.W.data
value.W.array[:, ::-1] = value.W.array
print('{:s} -> {:s} (BGR -> RGB)'.format(name, new_name))
elif new_name.startswith('multibox/loc/'):
# xy -> yx
for data in (value.W.data, value.b.data):
for data in (value.W.array, value.b.array):
data = data.reshape((-1, 4) + data.shape[1:])
data[:, [1, 0, 3, 2]] = data.copy()
print('{:s} -> {:s} (xy -> yx)'.format(name, new_name))
@@ -72,7 +72,7 @@ def __setattr__(self, name, value):
def _setup_normarize(self, layer):
blobs = layer.blobs
func = Normalize(caffe._get_num(blobs[0]))
func.scale.data[:] = np.array(blobs[0].data)
func.scale.array[:] = np.array(blobs[0].array)
with self.init_scope():
setattr(self, layer.name, func)

19 changes: 1 addition & 18 deletions examples/ssd/train.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import numpy as np

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer.optimizer import WeightDecay
from chainer import serializers
@@ -24,24 +25,6 @@
from chainercv.links.model.ssd import resize_with_random_interpolation


class ConcatenatedDataset(chainer.dataset.DatasetMixin):

def __init__(self, *datasets):
self._datasets = datasets

def __len__(self):
return sum(len(dataset) for dataset in self._datasets)

def get_example(self, i):
if i < 0:
raise IndexError
for dataset in self._datasets:
if i < len(dataset):
return dataset[i]
i -= len(dataset)
raise IndexError


class MultiboxTrainChain(chainer.Chain):

def __init__(self, model, alpha=1, k=3):
2 changes: 1 addition & 1 deletion examples/vgg/caffe2npz.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ def __setattr__(self, name, value):

if new_name == 'conv1_1/conv':
# BGR -> RGB
value.W.data[:, ::-1] = value.W.data
value.W.array[:, ::-1] = value.W.array
print('{:s} -> {:s} (BGR -> RGB)'.format(name, new_name))
else:
print('{:s} -> {:s}'.format(name, new_name))
6 changes: 3 additions & 3 deletions tests/links_tests/connection_tests/test_conv_2d_activ.py
Original file line number Diff line number Diff line change
@@ -63,14 +63,14 @@ def check_forward(self, x_data):
y = self.l(x)

self.assertIsInstance(y, chainer.Variable)
self.assertIsInstance(y.data, self.l.xp.ndarray)
self.assertIsInstance(y.array, self.l.xp.ndarray)

if self.activ == 'relu':
np.testing.assert_almost_equal(
cuda.to_cpu(y.data), np.maximum(cuda.to_cpu(x_data), 0))
cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0))
elif self.activ == 'add_one':
np.testing.assert_almost_equal(
cuda.to_cpu(y.data), cuda.to_cpu(x_data) + 1)
cuda.to_cpu(y.array), cuda.to_cpu(x_data) + 1)

def test_forward_cpu(self):
self.check_forward(self.x)
6 changes: 3 additions & 3 deletions tests/links_tests/connection_tests/test_conv_2d_bn_activ.py
Original file line number Diff line number Diff line change
@@ -68,16 +68,16 @@ def check_forward(self, x_data):
y = self.l(x)

self.assertIsInstance(y, chainer.Variable)
self.assertIsInstance(y.data, self.l.xp.ndarray)
self.assertIsInstance(y.array, self.l.xp.ndarray)

if self.activ == 'relu':
np.testing.assert_almost_equal(
cuda.to_cpu(y.data), np.maximum(cuda.to_cpu(x_data), 0),
cuda.to_cpu(y.array), np.maximum(cuda.to_cpu(x_data), 0),
decimal=4
)
elif self.activ == 'add_one':
np.testing.assert_almost_equal(
cuda.to_cpu(y.data), cuda.to_cpu(x_data) + 1,
cuda.to_cpu(y.array), cuda.to_cpu(x_data) + 1,
decimal=4
)

Original file line number Diff line number Diff line change
@@ -39,11 +39,11 @@ def check_call(self):
roi_cls_locs, roi_scores, rois, roi_indices = self.link(x1)

self.assertIsInstance(roi_cls_locs, chainer.Variable)
self.assertIsInstance(roi_cls_locs.data, xp.ndarray)
self.assertIsInstance(roi_cls_locs.array, xp.ndarray)
self.assertEqual(roi_cls_locs.shape, (self.n_roi, self.n_class * 4))

self.assertIsInstance(roi_scores, chainer.Variable)
self.assertIsInstance(roi_scores.data, xp.ndarray)
self.assertIsInstance(roi_scores.array, xp.ndarray)
self.assertEqual(roi_scores.shape, (self.n_roi, self.n_class))

self.assertIsInstance(rois, xp.ndarray)
Loading