diff --git a/chainercv/experimental/links/model/pspnet/pspnet.py b/chainercv/experimental/links/model/pspnet/pspnet.py
index 90df601d71..2ea76d216a 100644
--- a/chainercv/experimental/links/model/pspnet/pspnet.py
+++ b/chainercv/experimental/links/model/pspnet/pspnet.py
@@ -131,6 +131,9 @@ class PSPNet(chainer.Chain):
         * :obj:`'cityscapes'`: Load weights trained on the train split of \
             Cityscapes dataset. \
             :obj:`n_class` must be :obj:`19` or :obj:`None`.
+        * :obj:`'ade20k'`: Load weights trained on the train split of \
+            ADE20K dataset. \
+            :obj:`n_class` must be :obj:`150` or :obj:`None`.
         * :obj:`'imagenet'`: Load ImageNet pretrained weights for \
             the extractor.
         * `filepath`: A path of npz file. In this case, :obj:`n_class` \
@@ -305,8 +308,13 @@ class PSPNetResNet101(PSPNet):
     _models = {
         'cityscapes': {
             'param': {'n_class': 19, 'input_size': (713, 713)},
-            'url': 'https://github.com/yuyu2172/share-weights/releases/'
-            'download/0.0.6/pspnet_resnet101_cityscapes_convert_2018_05_22.npz'
+            'url': 'https://chainercv-models.preferred.jp/'
+            'pspnet_resnet101_cityscapes_trained_2018_12_19.npz',
+        },
+        'ade20k': {
+            'param': {'n_class': 150, 'input_size': (473, 473)},
+            'url': 'https://chainercv-models.preferred.jp/'
+            'pspnet_resnet101_ade20k_trained_2018_12_23.npz',
         },
     }
@@ -324,6 +332,16 @@ class PSPNetResNet50(PSPNet):
     _extractor_kwargs = {'n_layer': 50}
     _extractor_pick = ('res4', 'res5')
     _models = {
+        'cityscapes': {
+            'param': {'n_class': 19, 'input_size': (713, 713)},
+            'url': 'https://chainercv-models.preferred.jp/'
+            'pspnet_resnet50_cityscapes_trained_2018_12_19.npz',
+        },
+        'ade20k': {
+            'param': {'n_class': 150, 'input_size': (473, 473)},
+            'url': 'https://chainercv-models.preferred.jp/'
+            'pspnet_resnet50_ade20k_trained_2018_12_23.npz',
+        },
     }
diff --git a/examples/pspnet/README.md b/examples/pspnet/README.md
index ac9fc7bc31..cbcb72867c 100644
--- a/examples/pspnet/README.md
+++ b/examples/pspnet/README.md
@@ -1,32 +1,71 @@
 # Examples of Pyramid Scene Parsing Network (PSPNet)
 
-## Performance
+## Demo
+This demo downloads a pretrained model automatically if a pretrained model path is not given.
+```
+$ python demo.py [--gpu <gpu>] [--pretrained-model <model_path>] [--input-size <size>] <image>.jpg
+```
+
+## Weight Conversion
+
+Convert `*.caffemodel` to `*.npz`. Some layers are renamed to fit ChainerCV.
+```
+$ python caffe2npz.py <source>.caffemodel <target>.npz
+```
+
+The converted weights can be downloaded from [here](https://chainercv-models.preferred.jp/pspnet_resnet101_cityscapes_converted_2018_05_22.npz).
+
+The performance on the Cityscapes dataset with single-scale inference is as follows.
+Scores are measured by mean Intersection over Union (mIoU).
 
 | Model | Reference | ChainerCV (weight conversion) |
 |:-:|:-:|:-:|
 | Cityscapes (single scale) | 79.70 % [1] | 79.03 % |
 
-Scores are measured by mean Intersection over Union (mIoU).
+## Training models
+The models can be trained with the script `train_multi.py`.
+When `cv2` and `MultiprocessIterator` are used together, the process can get stuck in some situations.
+In that case, the problem can be solved by changing the configuration of multi-threaded methods ([details](
+https://docs.chainer.org/en/stable/tips.html#my-training-process-gets-stuck-when-using-multiprocessiterator)).
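+For example, following the linked tip, OpenCV's internal threading can be disabled before the iterator forks its worker processes (a minimal sketch, assuming the `cv2` in use is OpenCV):
+```
+import cv2
+cv2.setNumThreads(0)  # avoid thread-related deadlocks in forked workers
+```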
+
+### Cityscapes
+
+The following table shows the performance of the models trained with our scripts.
+
+| Base model | Training Data | Reference | ChainerCV |
+|:-:|:-:|:-:|:-:|
+| Dilated ResNet50 | fine only (3K) | 76.9 % [2] | 73.99 % |
+| Dilated ResNet101 | fine only (3K) | 77.9 % [2] | 76.01 % |
+
+Here are the commands used to train the models included in the table.
 
-## Demo
-This demo downloads Cityscapes pretrained model automatically if a pretrained model path is not given.
 ```
-$ python demo.py [--gpu <gpu>] [--pretrained-model <model_path>] [--input-size <size>] <image>.jpg
+$ mpiexec -n 8 python3 train_multi.py --dataset cityscapes --model pspnet_resnet50 --iteration 90000
+$ mpiexec -n 8 python3 train_multi.py --dataset cityscapes --model pspnet_resnet101 --iteration 90000
 ```
+
+### ADE20K
+
+The following table shows the performance of the models trained with our scripts.
+
+| Base model | Reference | ChainerCV |
+|:-:|:-:|:-:|
+| Dilated ResNet50 | 41.68 % [1] | 34.97 % |
+| Dilated ResNet101 | | 36.55 % |
+
+Here are the commands used to train the models included in the table.
 
-## Convert Caffe model
-Convert `*.caffemodel` to `*.npz`. Some layers are renamed to fit ChainerCV.
 ```
-$ python caffe2npz.py <source>.caffemodel <target>.npz
+$ mpiexec -n 8 python3 train_multi.py --dataset ade20k --model pspnet_resnet50 --iteration 150000
+$ mpiexec -n 8 python3 train_multi.py --dataset ade20k --model pspnet_resnet101 --iteration 150000
 ```
-
 ## Evaluation
-The evaluation can be conducted using [`chainercv/examples/semantic_segmentation/eval_cityscapes.py`](https://github.com/chainer/chainercv/blob/master/examples/semantic_segmentation).
+The evaluation can be conducted using [`chainercv/examples/semantic_segmentation/eval_semantic_segmentation.py`](https://github.com/chainer/chainercv/blob/master/examples/semantic_segmentation).
 
 ## References
 1. Hengshuang Zhao et al. "Pyramid Scene Parsing Network" CVPR 2017.
-2. [chainer-pspnet by mitmul](https://github.com/mitmul/chainer-pspnet)
+2. https://github.com/holyseven/PSPNet-TF-Reproduce (the original paper does not report validation scores for Cityscapes)
+3. [chainer-pspnet by mitmul](https://github.com/mitmul/chainer-pspnet)
diff --git a/examples/pspnet/train_multi.py b/examples/pspnet/train_multi.py
new file mode 100644
index 0000000000..2bd5695f50
--- /dev/null
+++ b/examples/pspnet/train_multi.py
@@ -0,0 +1,285 @@
+import argparse
+import copy
+import multiprocessing
+import numpy as np
+
+import chainer
+import chainer.functions as F
+import chainer.links as L
+from chainer import training
+from chainer.training import extensions
+from chainer.training.extensions import PolynomialShift
+
+from chainercv.datasets import ade20k_semantic_segmentation_label_names
+from chainercv.datasets import ADE20KSemanticSegmentationDataset
+from chainercv.datasets import cityscapes_semantic_segmentation_label_names
+from chainercv.datasets import CityscapesSemanticSegmentationDataset
+
+from chainercv.experimental.links import PSPNetResNet101
+from chainercv.experimental.links import PSPNetResNet50
+
+from chainercv.chainer_experimental.datasets.sliceable import TransformDataset
+from chainercv.extensions import SemanticSegmentationEvaluator
+from chainercv.links import Conv2DBNActiv
+from chainercv import transforms
+
+from chainercv.links.model.ssd import GradientScaling
+
+import PIL
+
+import chainermn
+
+
+def create_mnbn_model(link, comm):
+    """Returns a copy of a model with BN replaced by Multi-node BN."""
+    if isinstance(link, chainer.links.BatchNormalization):
+        mnbn = chainermn.links.MultiNodeBatchNormalization(
+            size=link.avg_mean.shape,
+            comm=comm,
+            decay=link.decay,
+            eps=link.eps,
+            dtype=link.avg_mean.dtype,
+            use_gamma=hasattr(link, 'gamma'),
+            use_beta=hasattr(link, 'beta'),
+        )
+        mnbn.copyparams(link)
+        for name in link._persistent:
+            mnbn.__dict__[name] = copy.deepcopy(link.__dict__[name])
+        return mnbn
+    elif isinstance(link, chainer.Chain):
+        new_children = [
+            (child_name, create_mnbn_model(
+                link.__dict__[child_name], comm))
+            for child_name in link._children
+        ]
+        new_link = copy.deepcopy(link)
+        for name, new_child in new_children:
+            new_link.__dict__[name] = new_child
+        return new_link
+    elif isinstance(link, chainer.ChainList):
+        new_children = [
+            create_mnbn_model(l, comm) for l in link]
+        new_link = copy.deepcopy(link)
+        for i, new_child in enumerate(new_children):
+            new_link._children[i] = new_child
+        return new_link
+    else:
+        return copy.deepcopy(link)
+
+
+class Transform(object):
+
+    def __init__(
+            self, mean,
+            crop_size, scale_range=[0.5, 2.0]):
+        self.mean = mean
+        self.scale_range = scale_range
+        self.crop_size = crop_size
+
+    def __call__(self, in_data):
+        img, label = in_data
+        _, H, W = img.shape
+        scale = np.random.uniform(self.scale_range[0], self.scale_range[1])
+
+        # Scale
+        scaled_H = int(scale * H)
+        scaled_W = int(scale * W)
+        img = transforms.resize(img, (scaled_H, scaled_W), PIL.Image.BICUBIC)
+        label = transforms.resize(
+            label[None], (scaled_H, scaled_W), PIL.Image.NEAREST)[0]
+
+        # Crop
+        if (scaled_H < self.crop_size[0]) or (scaled_W < self.crop_size[1]):
+            shorter_side = min(img.shape[1:])
+            img, param = transforms.random_crop(
+                img, (shorter_side, shorter_side), True)
+        else:
+            img, param = transforms.random_crop(img, self.crop_size, True)
+        label = label[param['y_slice'], param['x_slice']]
+
+        # Rotate
+        angle = np.random.uniform(-10, 10)
+        img = transforms.rotate(img, angle, expand=False)
+        label = transforms.rotate(
+            label[None], angle, expand=False,
+            interpolation=PIL.Image.NEAREST,
+            fill=-1)[0]
+
+        # Resize
+        if ((img.shape[1] < self.crop_size[0])
+                or (img.shape[2] < self.crop_size[1])):
+            img = transforms.resize(img, self.crop_size, PIL.Image.BICUBIC)
+        if ((label.shape[0] < self.crop_size[0])
+                or (label.shape[1] < self.crop_size[1])):
+            label = transforms.resize(
+                label[None].astype(np.float32),
+                self.crop_size, PIL.Image.NEAREST)
+            label = label.astype(np.int32)[0]
+
+        # Horizontal flip
+        if np.random.rand() > 0.5:
+            img = transforms.flip(img, x_flip=True)
+            label = transforms.flip(label[None], x_flip=True)[0]
+
+        # Mean subtraction
+        img = img - self.mean
+        return img, label
+
+
+class TrainChain(chainer.Chain):
+
+    def __init__(self, model):
+        initialW = chainer.initializers.HeNormal()
+        super(TrainChain, self).__init__()
+        with self.init_scope():
+            self.model = model
+            self.aux_conv1 = Conv2DBNActiv(
+                None, 512, 3, 1, 1, initialW=initialW)
+            self.aux_conv2 = L.Convolution2D(
+                None, model.n_class, 3, 1, 1, False, initialW=initialW)
+
+    def __call__(self, imgs, labels):
+        h_aux, h_main = self.model.extractor(imgs)
+        h_aux = F.dropout(self.aux_conv1(h_aux), ratio=0.1)
+        h_aux = self.aux_conv2(h_aux)
+        h_aux = F.resize_images(h_aux, imgs.shape[2:])
+
+        h_main = self.model.ppm(h_main)
+        h_main = F.dropout(self.model.head_conv1(h_main), ratio=0.1)
+        h_main = self.model.head_conv2(h_main)
+        h_main = F.resize_images(h_main, imgs.shape[2:])
+
+        aux_loss = F.softmax_cross_entropy(h_aux, labels)
+        main_loss = F.softmax_cross_entropy(h_main, labels)
+        loss = 0.4 * aux_loss + main_loss
+
+        chainer.reporter.report({'loss': loss}, self)
+        return loss
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data-dir', default='auto')
+    parser.add_argument('--dataset',
+                        choices=('ade20k', 'cityscapes'))
+    parser.add_argument('--model',
+                        choices=('pspnet_resnet101', 'pspnet_resnet50'))
+    parser.add_argument('--lr', default=1e-2, type=float)
+    parser.add_argument('--batch-size', default=2, type=int)
+    parser.add_argument('--out', default='result')
+    parser.add_argument('--iteration', default=None, type=int)
+    parser.add_argument('--communicator', default='hierarchical')
+    args = parser.parse_args()
+
+    dataset_cfgs = {
+        'ade20k': {
+            'input_size': (473, 473),
+            'label_names': ade20k_semantic_segmentation_label_names,
+            'iteration': 150000},
+        'cityscapes': {
+            'input_size': (713, 713),
+            'label_names': cityscapes_semantic_segmentation_label_names,
+            'iteration': 90000}
+    }
+    dataset_cfg = dataset_cfgs[args.dataset]
+
+    # This fixes a crash caused by a bug with multiprocessing and MPI.
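+    # ('forkserver' spawns worker processes from a separate, clean server
+    # process instead of forking this MPI process directly; starting and
+    # joining a dummy Process forces that server to launch now, before the
+    # communicator and the CUDA context are created below.)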
+    multiprocessing.set_start_method('forkserver')
+    p = multiprocessing.Process()
+    p.start()
+    p.join()
+
+    comm = chainermn.create_communicator(args.communicator)
+    device = comm.intra_rank
+
+    n_class = len(dataset_cfg['label_names'])
+    if args.model == 'pspnet_resnet101':
+        model = PSPNetResNet101(
+            n_class, pretrained_model='imagenet',
+            input_size=dataset_cfg['input_size'])
+    elif args.model == 'pspnet_resnet50':
+        model = PSPNetResNet50(
+            n_class, pretrained_model='imagenet',
+            input_size=dataset_cfg['input_size'])
+    train_chain = create_mnbn_model(TrainChain(model), comm)
+    model = train_chain.model
+    if device >= 0:
+        chainer.cuda.get_device_from_id(device).use()
+        train_chain.to_gpu()
+
+    if args.iteration is None:
+        n_iter = dataset_cfg['iteration']
+    else:
+        n_iter = args.iteration
+
+    if comm.rank == 0:
+        if args.dataset == 'ade20k':
+            train = ADE20KSemanticSegmentationDataset(
+                data_dir=args.data_dir, split='train')
+            val = ADE20KSemanticSegmentationDataset(
+                data_dir=args.data_dir, split='val')
+            label_names = ade20k_semantic_segmentation_label_names
+        elif args.dataset == 'cityscapes':
+            train = CityscapesSemanticSegmentationDataset(
+                args.data_dir,
+                label_resolution='fine', split='train')
+            val = CityscapesSemanticSegmentationDataset(
+                args.data_dir,
+                label_resolution='fine', split='val')
+            label_names = cityscapes_semantic_segmentation_label_names
+        train = TransformDataset(
+            train,
+            ('img', 'label'),
+            Transform(model.mean, dataset_cfg['input_size']))
+    else:
+        train, val = None, None
+    train = chainermn.scatter_dataset(train, comm, shuffle=True)
+    train_iter = chainer.iterators.MultiprocessIterator(
+        train, batch_size=args.batch_size, n_processes=2)
+
+    optimizer = chainermn.create_multi_node_optimizer(
+        chainer.optimizers.MomentumSGD(args.lr, 0.9), comm)
+    optimizer.setup(train_chain)
+    for param in train_chain.params():
+        if param.name not in ('beta', 'gamma'):
+            param.update_rule.add_hook(chainer.optimizer.WeightDecay(1e-4))
+    for l in [
+            model.ppm, model.head_conv1, model.head_conv2,
+            train_chain.aux_conv1, train_chain.aux_conv2]:
+        for param in l.params():
+            param.update_rule.add_hook(GradientScaling(10))
+
+    updater = training.updaters.StandardUpdater(
+        train_iter, optimizer, device=device)
+    trainer = training.Trainer(updater, (n_iter, 'iteration'), args.out)
+    trainer.extend(
+        PolynomialShift('lr', 0.9, n_iter, optimizer=optimizer),
+        trigger=(1, 'iteration'))
+
+    log_interval = 10, 'iteration'
+
+    if comm.rank == 0:
+        trainer.extend(extensions.LogReport(trigger=log_interval))
+        trainer.extend(extensions.observe_lr(), trigger=log_interval)
+        trainer.extend(extensions.PrintReport(
+            ['epoch', 'iteration', 'elapsed_time', 'lr', 'main/loss',
+             'validation/main/miou', 'validation/main/mean_class_accuracy',
+             'validation/main/pixel_accuracy']),
+            trigger=log_interval)
+        trainer.extend(extensions.ProgressBar(update_interval=10))
+        trainer.extend(
+            extensions.snapshot_object(
+                train_chain.model, 'snapshot_model_{.updater.iteration}.npz'),
+            trigger=(n_iter, 'iteration'))
+        val_iter = chainer.iterators.SerialIterator(
+            val, batch_size=1, repeat=False, shuffle=False)
+        trainer.extend(
+            SemanticSegmentationEvaluator(
+                val_iter, model,
+                label_names),
+            trigger=(n_iter, 'iteration'))
+
+    trainer.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/semantic_segmentation/README.md b/examples/semantic_segmentation/README.md
index 3eb4007588..be2124b76c 100644
--- a/examples/semantic_segmentation/README.md
+++ b/examples/semantic_segmentation/README.md
@@ -13,9 +13,13 @@ The scores are mIoU.
 
 ### Cityscapes
 
-| Model | Reference | ChainerCV (weight conversion) |
-|:-:|:-:|:-:|
-| PSPNet with ResNet101 (single scale) | 79.70 % [1] | 79.03 % |
+| Model | Training Data | Reference | ChainerCV |
+|:-:|:-:|:-:|:-:|
+| PSPNet w/ Dilated ResNet50 | fine only (3K) | 76.9 % [2] | 73.99 % |
+| PSPNet w/ Dilated ResNet101 | fine only (3K) | 77.9 % [2] | 76.01 % |
+
+Example:
 
 ```
 $ python eval_semantic_segmentation.py --gpu <gpu> --dataset cityscapes --model pspnet_resnet101
@@ -23,12 +27,23 @@ $ python eval_semantic_segmentation.py --gpu <gpu> --dataset cityscapes --model
 $ mpiexec -n <#gpu> python eval_semantic_segmentation_multi.py --dataset cityscapes --model pspnet_resnet101
 ```
 
+### ADE20K
+
+| Base model | Reference | ChainerCV |
+|:-:|:-:|:-:|
+| Dilated ResNet50 | 41.68 % [1] | 34.97 % |
+| Dilated ResNet101 | | 36.55 % |
+
+```
+$ python eval_semantic_segmentation.py --gpu <gpu> --dataset ade20k --model pspnet_resnet101
+```
+
 ### CamVid
 
 | Model | Reference | ChainerCV |
 |:-:|:-:|:-:|
-| SegNet | 46.3 % [2] | 49.4 % |
+| SegNet | 46.3 % [3] | 49.4 % |
 
 ```
 $ python eval_semantic_segmentation.py --gpu <gpu> --dataset camvid --model segnet
@@ -38,4 +53,5 @@ $ python eval_semantic_segmentation.py --gpu <gpu> --dataset camvid --model segn
 # Reference
 
 1. Hengshuang Zhao et al. "Pyramid Scene Parsing Network" CVPR 2017.
-2. Vijay Badrinarayanan et al. "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation." PAMI, 2017.
+2. https://github.com/holyseven/PSPNet-TF-Reproduce (the original paper does not report validation scores for Cityscapes)
+3. Vijay Badrinarayanan et al. "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation." PAMI, 2017.
diff --git a/examples_tests/pspnet_tests/test_train_multi.sh b/examples_tests/pspnet_tests/test_train_multi.sh
new file mode 100644
index 0000000000..14ac01fb59
--- /dev/null
+++ b/examples_tests/pspnet_tests/test_train_multi.sh
@@ -0,0 +1,6 @@
+cd examples/pspnet
+sed -e "s/data_dir=args.data_dir, split='val')/data_dir=args.data_dir, split='val').slice[:5]/" -i train_multi.py
+
+$MPIEXEC $PYTHON train_multi.py --dataset ade20k --model pspnet_resnet50 --batch-size 1 --iteration 10
+$MPIEXEC $PYTHON train_multi.py --dataset ade20k --model pspnet_resnet101 --batch-size 1 --iteration 10
+
diff --git a/tests/experimental_tests/links_tests/model_tests/pspnet_tests/test_pspnet.py b/tests/experimental_tests/links_tests/model_tests/pspnet_tests/test_pspnet.py
index 30847adaf8..1b65024fe2 100644
--- a/tests/experimental_tests/links_tests/model_tests/pspnet_tests/test_pspnet.py
+++ b/tests/experimental_tests/links_tests/model_tests/pspnet_tests/test_pspnet.py
@@ -56,12 +56,12 @@ def test_predict_gpu(self):
 def _create_paramters():
     params = testing.product({
         'model': [PSPNetResNet50],
-        'pretrained_model': ['imagenet'],
+        'pretrained_model': ['imagenet', 'cityscapes', 'ade20k'],
         'n_class': [None, 5],
     })
     params += testing.product({
         'model': [PSPNetResNet101],
-        'pretrained_model': ['imagenet', 'cityscapes'],
+        'pretrained_model': ['imagenet', 'cityscapes', 'ade20k'],
         'n_class': [None, 5, 19],
     })
     return params
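
As a quick sanity check of the new `'ade20k'` pretrained option, here is a minimal inference sketch (`example.jpg` is a placeholder path; `read_image` and `predict` are ChainerCV's standard image-loading and inference APIs):
```
from chainercv.experimental.links import PSPNetResNet50
from chainercv.utils import read_image

# Downloads and caches the ADE20K weights registered in _models above;
# n_class defaults to 150 for this pretrained model.
model = PSPNetResNet50(pretrained_model='ade20k')

img = read_image('example.jpg')  # placeholder path; CHW, RGB, float32
label = model.predict([img])[0]  # (H, W) array of ADE20K class indices
print(label.shape)
```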