diff --git a/chainercv/links/__init__.py b/chainercv/links/__init__.py index 655a799fdb..17eddff0ac 100644 --- a/chainercv/links/__init__.py +++ b/chainercv/links/__init__.py @@ -9,3 +9,4 @@ from chainercv.links.model.segnet.segnet_basic import SegNetBasic # NOQA from chainercv.links.model.ssd import SSD300 # NOQA from chainercv.links.model.ssd import SSD512 # NOQA +from chainercv.links.model.vgg import VGG16 # NOQA diff --git a/chainercv/links/model/faster_rcnn/__init__.py b/chainercv/links/model/faster_rcnn/__init__.py index f4e5faaa73..9f39e1c44f 100644 --- a/chainercv/links/model/faster_rcnn/__init__.py +++ b/chainercv/links/model/faster_rcnn/__init__.py @@ -1,7 +1,6 @@ from chainercv.links.model.faster_rcnn.faster_rcnn import FasterRCNN # NOQA from chainercv.links.model.faster_rcnn.faster_rcnn_train_chain import FasterRCNNTrainChain # NOQA from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA -from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import VGG16FeatureExtractor # NOQA from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import VGG16RoIHead # NOQA from chainercv.links.model.faster_rcnn.region_proposal_network import RegionProposalNetwork # NOQA from chainercv.links.model.faster_rcnn.utils.anchor_target_creator import AnchorTargetCreator # NOQA diff --git a/chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py b/chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py index c5bb2e157c..dd374e61bd 100644 --- a/chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py +++ b/chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py @@ -1,14 +1,13 @@ -import collections import numpy as np import chainer import chainer.functions as F import chainer.links as L -from chainer.links import VGG16Layers from chainercv.links.model.faster_rcnn.faster_rcnn import FasterRCNN from chainercv.links.model.faster_rcnn.region_proposal_network import \ RegionProposalNetwork +from chainercv.links.model.vgg.vgg16 import VGG16 from chainercv.utils import download_model @@ -74,7 +73,8 @@ class FasterRCNNVGG16(FasterRCNN): 'voc07': { 'n_fg_class': 20, 'url': 'https://github.com/yuyu2172/share-weights/releases/' - 'download/0.0.3/faster_rcnn_vgg16_voc07_2017_06_06.npz' + 'download/0.0.4/' + 'faster_rcnn_vgg16_voc07_trained_2017_08_06.npz' } } feat_stride = 16 @@ -103,7 +103,10 @@ def __init__(self, if vgg_initialW is None and pretrained_model: vgg_initialW = chainer.initializers.constant.Zero() - extractor = VGG16FeatureExtractor(initialW=vgg_initialW) + extractor = VGG16(initialW=vgg_initialW) + extractor.feature_names = 'conv5_3' + # Delete all layers after conv5_3. + extractor.remove_unused() rpn = RegionProposalNetwork( 512, 512, ratios=ratios, @@ -139,12 +142,8 @@ def __init__(self, chainer.serializers.load_npz(pretrained_model, self) def _copy_imagenet_pretrained_vgg16(self): - pretrained_model = VGG16Layers() + pretrained_model = VGG16(pretrained_model='imagenet') self.extractor.conv1_1.copyparams(pretrained_model.conv1_1) - # The pretrained weights are trained to accept BGR images. - # Convert weights so that they accept RGB images. - self.extractor.conv1_1.W.data[:] =\ - self.extractor.conv1_1.W.data[:, ::-1] self.extractor.conv1_2.copyparams(pretrained_model.conv1_2) self.extractor.conv2_1.copyparams(pretrained_model.conv2_1) self.extractor.conv2_2.copyparams(pretrained_model.conv2_2) @@ -225,75 +224,8 @@ def __call__(self, x, rois, roi_indices): return roi_cls_locs, roi_scores -class VGG16FeatureExtractor(chainer.Chain): - """Truncated VGG-16 that extracts a conv5_3 feature map. - - Args: - initialW (callable): Initializer for the weights. - - """ - - def __init__(self, initialW=None): - super(VGG16FeatureExtractor, self).__init__() - with self.init_scope(): - self.conv1_1 = L.Convolution2D(3, 64, 3, 1, 1, initialW=initialW) - self.conv1_2 = L.Convolution2D(64, 64, 3, 1, 1, initialW=initialW) - self.conv2_1 = L.Convolution2D(64, 128, 3, 1, 1, initialW=initialW) - self.conv2_2 = L.Convolution2D( - 128, 128, 3, 1, 1, initialW=initialW) - self.conv3_1 = L.Convolution2D( - 128, 256, 3, 1, 1, initialW=initialW) - self.conv3_2 = L.Convolution2D( - 256, 256, 3, 1, 1, initialW=initialW) - self.conv3_3 = L.Convolution2D( - 256, 256, 3, 1, 1, initialW=initialW) - self.conv4_1 = L.Convolution2D( - 256, 512, 3, 1, 1, initialW=initialW) - self.conv4_2 = L.Convolution2D( - 512, 512, 3, 1, 1, initialW=initialW) - self.conv4_3 = L.Convolution2D( - 512, 512, 3, 1, 1, initialW=initialW) - self.conv5_1 = L.Convolution2D( - 512, 512, 3, 1, 1, initialW=initialW) - self.conv5_2 = L.Convolution2D( - 512, 512, 3, 1, 1, initialW=initialW) - self.conv5_3 = L.Convolution2D( - 512, 512, 3, 1, 1, initialW=initialW) - - self.functions = collections.OrderedDict([ - ('conv1_1', [self.conv1_1, F.relu]), - ('conv1_2', [self.conv1_2, F.relu]), - ('pool1', [_max_pooling_2d]), - ('conv2_1', [self.conv2_1, F.relu]), - ('conv2_2', [self.conv2_2, F.relu]), - ('pool2', [_max_pooling_2d]), - ('conv3_1', [self.conv3_1, F.relu]), - ('conv3_2', [self.conv3_2, F.relu]), - ('conv3_3', [self.conv3_3, F.relu]), - ('pool3', [_max_pooling_2d]), - ('conv4_1', [self.conv4_1, F.relu]), - ('conv4_2', [self.conv4_2, F.relu]), - ('conv4_3', [self.conv4_3, F.relu]), - ('pool4', [_max_pooling_2d]), - ('conv5_1', [self.conv5_1, F.relu]), - ('conv5_2', [self.conv5_2, F.relu]), - ('conv5_3', [self.conv5_3, F.relu]), - ]) - - def __call__(self, x): - h = x - for key, funcs in self.functions.items(): - for func in funcs: - h = func(h) - return h - - def _roi_pooling_2d_yx(x, indices_and_rois, outh, outw, spatial_scale): xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]] pool = F.roi_pooling_2d( x, xy_indices_and_rois, outh, outw, spatial_scale) return pool - - -def _max_pooling_2d(x): - return F.max_pooling_2d(x, ksize=2) diff --git a/chainercv/links/model/vgg/__init__.py b/chainercv/links/model/vgg/__init__.py new file mode 100644 index 0000000000..40faafabb1 --- /dev/null +++ b/chainercv/links/model/vgg/__init__.py @@ -0,0 +1 @@ +from chainercv.links.model.vgg.vgg16 import VGG16 # NOQA diff --git a/chainercv/links/model/vgg/vgg16.py b/chainercv/links/model/vgg/vgg16.py new file mode 100644 index 0000000000..bd43f102e5 --- /dev/null +++ b/chainercv/links/model/vgg/vgg16.py @@ -0,0 +1,164 @@ +from __future__ import division + +import numpy as np + +import chainer +from chainer.functions import dropout +from chainer.functions import max_pooling_2d +from chainer.functions import relu +from chainer.functions import softmax +from chainer.initializers import constant +from chainer.initializers import normal + +from chainer.links import Linear + +from chainercv.utils import download_model + +from chainercv.links.connection.conv_2d_activ import Conv2DActiv +from chainercv.links.model.sequential_feature_extractor import \ + SequentialFeatureExtractor + + +# RGB order +_imagenet_mean = np.array( + [123.68, 116.779, 103.939], dtype=np.float32)[:, np.newaxis, np.newaxis] + + +class VGG16(SequentialFeatureExtractor): + + """VGG-16 Network for classification and feature extraction. + + This is a feature extraction model. + The network can choose output features from set of all + intermediate features. + The value of :obj:`VGG16.feature_names` selects the features that are going + to be collected by :meth:`__call__`. + :obj:`self.all_feature_names` is the list of the names of features + that can be collected. + + Examples: + + >>> model = VGG16() + # By default, __call__ returns a probability score (after Softmax). + >>> prob = model(imgs) + + >>> model.feature_names = 'conv5_3' + # This is feature conv5_3 (after ReLU). + >>> feat5_3 = model(imgs) + + >>> model.feature_names = ['conv5_3', 'fc6'] + >>> # These are features conv5_3 (after ReLU) and fc6 (before ReLU). + >>> feat5_3, feat6 = model(imgs) + + .. seealso:: + :class:`chainercv.links.model.SequentialFeatureExtractor` + + When :obj:`pretrained_model` is the path of a pre-trained chainer model + serialized as a :obj:`.npz` file in the constructor, this chain model + automatically initializes all the parameters with it. + When a string in the prespecified set is provided, a pretrained model is + loaded from weights distributed on the Internet. + The list of pretrained models supported are as follows: + + * :obj:`imagenet`: Loads weights trained with ImageNet and distributed \ + at `Model Zoo \ + `_. + + Args: + pretrained_model (str): The destination of the pre-trained + chainer model serialized as a :obj:`.npz` file. + If this is one of the strings described + above, it automatically loads weights stored under a directory + :obj:`$CHAINER_DATASET_ROOT/pfnet/chainercv/models/`, + where :obj:`$CHAINER_DATASET_ROOT` is set as + :obj:`$HOME/.chainer/dataset` unless you specify another value + by modifying the environment variable. + n_class (int): The number of classes. If :obj:`None`, + the default values are used. + If a supported pretrained model is used, + the number of classes used to train the pretrained model + is used. Otherwise, the number of classes in ILSVRC 2012 dataset + is used. + mean (numpy.ndarray): A mean value. If :obj:`None`, + the default values are used. + If a supported pretrained model is used, + the mean value used to train the pretrained model is used. + Otherwise, the mean value calculated from ILSVRC 2012 dataset + is used. + initialW (callable): Initializer for the weights. + initial_bias (callable): Initializer for the biases. + + """ + + _models = { + 'imagenet': { + 'n_class': 1000, + 'url': 'https://github.com/yuyu2172/share-weights/releases/' + 'download/0.0.4/vgg16_imagenet_convert_2017_07_18.npz', + 'mean': _imagenet_mean + } + } + + def __init__(self, + pretrained_model=None, n_class=None, mean=None, + initialW=None, initial_bias=None): + if n_class is None: + if pretrained_model in self._models: + n_class = self._models[pretrained_model]['n_class'] + else: + n_class = 1000 + + if mean is None: + if pretrained_model in self._models: + mean = self._models[pretrained_model]['mean'] + else: + mean = _imagenet_mean + self.mean = mean + + if initialW is None: + # Employ default initializers used in the original paper. + initialW = normal.Normal(0.01) + if pretrained_model: + # As a sampling process is time-consuming, + # we employ a zero initializer for faster computation. + initialW = constant.Zero() + kwargs = {'initialW': initialW, 'initial_bias': initial_bias} + + super(VGG16, self).__init__() + with self.init_scope(): + self.conv1_1 = Conv2DActiv(None, 64, 3, 1, 1, **kwargs) + self.conv1_2 = Conv2DActiv(None, 64, 3, 1, 1, **kwargs) + self.pool1 = _max_pooling_2d + self.conv2_1 = Conv2DActiv(None, 128, 3, 1, 1, **kwargs) + self.conv2_2 = Conv2DActiv(None, 128, 3, 1, 1, **kwargs) + self.pool2 = _max_pooling_2d + self.conv3_1 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) + self.conv3_2 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) + self.conv3_3 = Conv2DActiv(None, 256, 3, 1, 1, **kwargs) + self.pool3 = _max_pooling_2d + self.conv4_1 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.conv4_2 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.conv4_3 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.pool4 = _max_pooling_2d + self.conv5_1 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.conv5_2 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.conv5_3 = Conv2DActiv(None, 512, 3, 1, 1, **kwargs) + self.pool5 = _max_pooling_2d + self.fc6 = Linear(None, 4096, **kwargs) + self.fc6_relu = relu + self.fc6_dropout = dropout + self.fc7 = Linear(None, 4096, **kwargs) + self.fc7_relu = relu + self.fc7_dropout = dropout + self.fc8 = Linear(None, n_class, **kwargs) + self.prob = softmax + + if pretrained_model in self._models: + path = download_model(self._models[pretrained_model]['url']) + chainer.serializers.load_npz(path, self) + elif pretrained_model: + chainer.serializers.load_npz(pretrained_model, self) + + +def _max_pooling_2d(x): + return max_pooling_2d(x, ksize=2) diff --git a/docs/source/reference/links.rst b/docs/source/reference/links.rst index 5a7136ad97..e1f59d820f 100644 --- a/docs/source/reference/links.rst +++ b/docs/source/reference/links.rst @@ -1,10 +1,28 @@ Links ===== -.. module:: chainercv.links.model.faster_rcnn + +Model +----- + + +Feature Extraction +~~~~~~~~~~~~~~~~~~ +Feature extraction models extract feature(s) from given images. + +.. toctree:: + + links/vgg + + +.. autoclass:: chainercv.links.SequentialFeatureExtractor + :members: + +.. autoclass:: chainercv.links.FeaturePredictor + Detection ---------- +~~~~~~~~~ Detection links share a common method :meth:`predict` to detect objects in images. For more details, please read :func:`FasterRCNN.predict`. @@ -16,7 +34,7 @@ For more details, please read :func:`FasterRCNN.predict`. Semantic Segmentation ---------------------- +~~~~~~~~~~~~~~~~~~~~~ .. module:: chainercv.links.model.segnet @@ -29,7 +47,7 @@ For more details, please read :func:`SegNetBasic.predict`. Classifiers ------------ +~~~~~~~~~~~ .. toctree:: diff --git a/docs/source/reference/links/faster_rcnn.rst b/docs/source/reference/links/faster_rcnn.rst index 8a50ece3c9..bec09c5675 100644 --- a/docs/source/reference/links/faster_rcnn.rst +++ b/docs/source/reference/links/faster_rcnn.rst @@ -45,10 +45,6 @@ RegionProposalNetwork :members: :special-members: __call__ -VGG16FeatureExtractor -~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: VGG16FeatureExtractor - VGG16RoIHead ~~~~~~~~~~~~ .. autoclass:: VGG16RoIHead diff --git a/docs/source/reference/links/vgg.rst b/docs/source/reference/links/vgg.rst new file mode 100644 index 0000000000..201aefca12 --- /dev/null +++ b/docs/source/reference/links/vgg.rst @@ -0,0 +1,11 @@ +VGG +=== + +.. module:: chainercv.links.model.vgg + + +VGG16 +----- + +.. autoclass:: VGG16 + :members: diff --git a/examples/classification/README.md b/examples/classification/README.md new file mode 100644 index 0000000000..874e3c2ef8 --- /dev/null +++ b/examples/classification/README.md @@ -0,0 +1,43 @@ +# Classification + +## Performance + +| Model | Top 1 Error (single crop) | Reference Top 1 Error (single crop) | +|:-:|:-:|:-:| +| VGG16 | 29.0 % | 28.5 % [1] | + +The results can be reproduced by the following command. +The score is reported using a weight converted from a weight trained by Caffe. + +``` +$ python eval_imagenet.py [--model vgg16] [--pretrained_model ] [--batchsize ] [--gpu ] [--crop center|10] +``` + + +## How to prepare ImageNet Dataset + +This instructions are based on the instruction found [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset). + +The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 categories and 1.2 million images. The images do not need to be preprocessed or packaged in any database, but the validation images need to be moved into appropriate subfolders. + +1. Download the images from http://image-net.org/download-images + +2. Extract the training data: + ```bash + $ mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train + $ tar -xvf ILSVRC2012_img_train.tar && mv ILSVRC2012_img_train.tar .. + $ find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done + $ cd .. + ``` + +3. Extract the validation data and move images to subfolders: + ```bash + $ mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar + $ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash + $ mv ILSVRC2012_img_val.tar .. && cd .. + ``` + + +## References + +1. Karen Simonyan, Andrew Zisserman. "Very Deep Convolutional Networks for Large-Scale Image Recognition" ICLR 2015 diff --git a/examples/classification/eval_imagenet.py b/examples/classification/eval_imagenet.py new file mode 100644 index 0000000000..27483a809d --- /dev/null +++ b/examples/classification/eval_imagenet.py @@ -0,0 +1,79 @@ +import argparse +import sys +import time + +import numpy as np + +import chainer +import chainer.functions as F +from chainer import iterators + +from chainercv.datasets import directory_parsing_label_names +from chainercv.datasets import DirectoryParsingClassificationDataset +from chainercv.links import FeaturePredictor +from chainercv.links import VGG16 + +from chainercv.utils import apply_prediction_to_iterator + + +class ProgressHook(object): + + def __init__(self, n_total): + self.n_total = n_total + self.start = time.time() + self.n_processed = 0 + + def __call__(self, imgs, pred_values, gt_values): + self.n_processed += len(imgs) + fps = self.n_processed / (time.time() - self.start) + sys.stdout.write( + '\r{:d} of {:d} images, {:.2f} FPS'.format( + self.n_processed, self.n_total, fps)) + sys.stdout.flush() + + +def main(): + chainer.config.train = False + + parser = argparse.ArgumentParser( + description='Learning convnet from ILSVRC2012 dataset') + parser.add_argument('val', help='Path to root of the validation dataset') + parser.add_argument('--model', choices=('vgg16',)) + parser.add_argument('--pretrained_model', default='imagenet') + parser.add_argument('--gpu', type=int, default=-1) + parser.add_argument('--batchsize', type=int, default=32) + parser.add_argument('--crop', choices=('center', '10'), default='center') + args = parser.parse_args() + + dataset = DirectoryParsingClassificationDataset(args.val) + label_names = directory_parsing_label_names(args.val) + iterator = iterators.MultiprocessIterator( + dataset, args.batchsize, repeat=False, shuffle=False, + n_processes=6, shared_mem=300000000) + + if args.model == 'vgg16': + extractor = VGG16(pretrained_model=args.pretrained_model, + n_class=len(label_names)) + model = FeaturePredictor( + extractor, crop_size=224, scale_size=256, crop=args.crop) + + if args.gpu >= 0: + chainer.cuda.get_device(args.gpu).use() + model.to_gpu() + + print('Model has been prepared. Evaluation starts.') + imgs, pred_values, gt_values = apply_prediction_to_iterator( + model.predict, iterator, hook=ProgressHook(len(dataset))) + del imgs + + pred_probs, = pred_values + gt_probs, = gt_values + + accuracy = F.accuracy( + np.array(list(pred_probs)), np.array(list(gt_probs))).data + print() + print('Top 1 Error {}'.format(1. - accuracy)) + + +if __name__ == '__main__': + main() diff --git a/examples/vgg/README.md b/examples/vgg/README.md new file mode 100644 index 0000000000..fd07e4d451 --- /dev/null +++ b/examples/vgg/README.md @@ -0,0 +1,13 @@ +# VGG + +For evaluation, please go to [`examples/classification`](https://github.com/chainer/chainercv/tree/master/examples). + +## Convert Caffe model +Convert `*.caffemodel` to `*.npz`. + +``` +$ python caffe2npz.py .caffemodel .npz +``` + +The pretrained `.caffemodel` for VGG-16 can be downloaded from here. +http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel diff --git a/examples/vgg/caffe2npz.py b/examples/vgg/caffe2npz.py new file mode 100644 index 0000000000..34ae81ee1c --- /dev/null +++ b/examples/vgg/caffe2npz.py @@ -0,0 +1,57 @@ +import argparse +import re + +import chainer +from chainer import Link +import chainer.links.caffe.caffe_function as caffe + + +""" +Please download a weight from here. +http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel +""" + + +def rename(name): + m = re.match(r'conv(\d+)_(\d+)$', name) + if m: + i, j = map(int, m.groups()) + return 'conv{:d}_{:d}/conv'.format(i, j) + + return name + + +class VGGCaffeFunction(caffe.CaffeFunction): + + def __init__(self, model_path): + print('loading weights from {:s} ... '.format(model_path)) + super(VGGCaffeFunction, self).__init__(model_path) + + def __setattr__(self, name, value): + if self.within_init_scope and isinstance(value, Link): + new_name = rename(name) + + if new_name == 'conv1_1/conv': + # BGR -> RGB + value.W.data[:, ::-1] = value.W.data + print('{:s} -> {:s} (BGR -> RGB)'.format(name, new_name)) + else: + print('{:s} -> {:s}'.format(name, new_name)) + else: + new_name = name + + super(VGGCaffeFunction, self).__setattr__(new_name, value) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('caffemodel') + parser.add_argument('output') + args = parser.parse_args() + + model = VGGCaffeFunction(args.caffemodel) + chainer.serializers.save_npz(args.output, model) + + +if __name__ == '__main__': + main() diff --git a/tests/links_tests/model_tests/vgg_tests/test_vgg16.py b/tests/links_tests/model_tests/vgg_tests/test_vgg16.py new file mode 100644 index 0000000000..8fb6718572 --- /dev/null +++ b/tests/links_tests/model_tests/vgg_tests/test_vgg16.py @@ -0,0 +1,50 @@ +import unittest + +import numpy as np + +from chainer.initializers import Zero +from chainer import testing +from chainer.testing import attr +from chainer import Variable + +from chainercv.links import VGG16 + + +@testing.parameterize( + {'feature_names': 'prob', 'shapes': (1, 200), 'n_class': 200}, + {'feature_names': 'pool5', 'shapes': (1, 512, 7, 7), 'n_class': None}, + {'feature_names': ['conv5_3', 'conv4_2'], + 'shapes': ((1, 512, 14, 14), (1, 512, 28, 28)), 'n_class': None}, +) +@attr.slow +class TestVGG16Call(unittest.TestCase): + + def setUp(self): + self.link = VGG16( + pretrained_model=None, n_class=self.n_class, + initialW=Zero()) + self.link.feature_names = self.feature_names + + def check_call(self): + xp = self.link.xp + + x1 = Variable(xp.asarray(np.random.uniform( + -1, 1, (1, 3, 224, 224)).astype(np.float32))) + features = self.link(x1) + if isinstance(features, tuple): + for activation, shape in zip(features, self.shapes): + self.assertEqual(activation.shape, shape) + else: + self.assertEqual(features.shape, self.shapes) + self.assertEqual(features.dtype, np.float32) + + def test_call_cpu(self): + self.check_call() + + @attr.gpu + def test_call_gpu(self): + self.link.to_gpu() + self.check_call() + + +testing.run_module(__name__, __file__)