From 21fd574ec3e54204f06a40acd9ec3883cabea8ac Mon Sep 17 00:00:00 2001
From: robinbg
Date: Fri, 8 Oct 2021 10:50:17 +0800
Subject: [PATCH 1/2] Add Shuffle

---
Review notes (ignored by git-am, not part of the commit message):
 * channel_shuffle now unpacks x.shape; in Paddle, Tensor.size is the
   element count and cannot be called to obtain the shape.
 * depthwise_conv forwards its flag as bias_attr=, the keyword that
   paddle.nn.Conv2D accepts (there is no bias= keyword).
 * conv1 uses nn.BatchNorm2D (was the torch spelling nn.BatchNorm2d) and
   conv5 uses nn.ReLU() (paddle.nn.ReLU has no inplace argument).
 * The x0_25 / x0_33 docstrings state their own width multiplier, and the
   paper link target is restored in every factory docstring.
 * All fixes are same-line edits, so the hunk counts are unchanged; the
   blob hashes on the "index" lines are stale and should be regenerated
   with git format-patch before publishing.

 python/paddle/vision/models/shufflenetv2.py | 239 ++++++++++++++++++++
 1 file changed, 239 insertions(+)
 create mode 100644 python/paddle/vision/models/shufflenetv2.py

diff --git a/python/paddle/vision/models/shufflenetv2.py b/python/paddle/vision/models/shufflenetv2.py
new file mode 100644
index 00000000000000..4adb35fc994c95
--- /dev/null
+++ b/python/paddle/vision/models/shufflenetv2.py
@@ -0,0 +1,239 @@
+from typing import Callable, Any, List
+
+import paddle
+import paddle.nn as nn
+from paddle import Tensor
+
+from paddle.utils.download import get_weights_path_from_url
+
+
+__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_25", "shufflenet_v2_x0_33", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"]
+
+model_urls = {
+    "shufflenetv2_x0.25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams",
+    "shufflenetv2_x0.33": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams",
+    "shufflenetv2_x0.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams",
+    "shufflenetv2_x1.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams",
+    "shufflenetv2_x1.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams",
+    "shufflenetv2_x2.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams",
+}
+
+
+def channel_shuffle(x: Tensor, groups: int) -> Tensor:
+    batchsize, num_channels, height, width = x.shape
+    channels_per_group = num_channels // groups
+
+    # reshape
+    x = paddle.reshape(x, (batchsize, groups, channels_per_group, height, width))
+
+    x = paddle.transpose(x, [0, 2, 1, 3, 4])
+
+    # flatten
+    x = paddle.reshape(x, (batchsize, -1, height, width))
+
+    return x
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp: int, oup: int, stride: int) -> None:
+        super(InvertedResidual, self).__init__()
+
+        if not (1 <= stride <= 3):
+            raise ValueError("illegal stride value")
+        self.stride = stride
+
+        branch_features = oup // 2
+        assert (self.stride != 1) or (inp == branch_features << 1)
+
+        if self.stride > 1:
+            self.branch1 = nn.Sequential(
+                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+                nn.BatchNorm2D(inp),
+                nn.Conv2D(inp, branch_features, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2D(branch_features),
+                nn.ReLU(),
+            )
+        else:
+            self.branch1 = nn.Sequential()
+
+        self.branch2 = nn.Sequential(
+            nn.Conv2D(
+                inp if (self.stride > 1) else branch_features,
+                branch_features,
+                kernel_size=1,
+                stride=1,
+                padding=0
+            ),
+            nn.BatchNorm2D(branch_features),
+            nn.ReLU(),
+            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+            nn.BatchNorm2D(branch_features),
+            nn.Conv2D(branch_features, branch_features, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2D(branch_features),
+            nn.ReLU(),
+        )
+
+    @staticmethod
+    def depthwise_conv(
+        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
+    ) -> nn.Conv2d:
+        return nn.Conv2D(i, o, kernel_size, stride, padding, bias_attr=bias, groups=i)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.stride == 1:
+            x1, x2 = x.chunk(2, axis=1)
+            out = paddle.concat((x1, self.branch2(x2)), axis=1)
+        else:
+            out = paddle.concat((self.branch1(x), self.branch2(x)), axis=1)
+
+        out = channel_shuffle(out, 2)
+
+        return out
+
+
+class ShuffleNetV2(nn.Module):
+    def __init__(
+        self,
+        stages_repeats: List[int],
+        stages_out_channels: List[int],
+        num_classes: int = 1000,
+        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
+    ) -> None:
+        super(ShuffleNetV2, self).__init__()
+
+        if len(stages_repeats) != 3:
+            raise ValueError("expected stages_repeats as list of 3 positive ints")
+        if len(stages_out_channels) != 5:
+            raise ValueError("expected stages_out_channels as list of 5 positive ints")
+        self._stage_out_channels = stages_out_channels
+
+        input_channels = 3
+        output_channels = self._stage_out_channels[0]
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(input_channels, output_channels, 3, 2, 1),
+            nn.BatchNorm2D(output_channels),
+            nn.ReLU(),
+        )
+        input_channels = output_channels
+
+        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+        # Static annotations for mypy
+        self.stage2: nn.Sequential
+        self.stage3: nn.Sequential
+        self.stage4: nn.Sequential
+        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
+        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
+            seq = [inverted_residual(input_channels, output_channels, 2)]
+            for i in range(repeats - 1):
+                seq.append(inverted_residual(output_channels, output_channels, 1))
+            setattr(self, name, nn.Sequential(*seq))
+            input_channels = output_channels
+
+        output_channels = self._stage_out_channels[-1]
+        self.conv5 = nn.Sequential(
+            nn.Conv2D(input_channels, output_channels, 1, 1, 0),
+            nn.BatchNorm2D(output_channels),
+            nn.ReLU(),
+        )
+
+        self.fc = nn.Linear(output_channels, num_classes)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.maxpool(x)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.stage4(x)
+        x = self.conv5(x)
+        x = x.mean([2, 3])  # globalpool
+        x = self.fc(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
+    model = ShuffleNetV2(*args, **kwargs)
+
+    if pretrained:
+        model_url = model_urls[arch]
+        if model_url is None:
+            raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
+        else:
+            weight_path = get_weights_path_from_url(model_urls[arch])
+            param = paddle.load(weight_path)
+            model.set_dict(param)
+
+    return model
+
+def shufflnet_v2_x0_25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 0.25x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x0.25", pretrained, progress, [4, 8, 4], [24, 24, 48, 96, 512], **kwargs)
+
+def shufflnet_v2_x0_33(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 0.33x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x0.33", pretrained, progress, [4, 8, 4], [24, 32, 64, 128, 512], **kwargs)
+
+def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
+
+
+def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
+
+
+def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
+
+
+def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
\ No newline at end of file

From ecfc3f708851c69c58ed1f363ffddf06777a72e5 Mon Sep 17 00:00:00 2001
From: robinbg
Date: Fri, 8 Oct 2021 11:32:17 +0800
Subject: [PATCH 2/2] Add ShuffleNetV2

---
Review notes (ignored by git-am, not part of the commit message):
 * paddle/vision/__init__.py imported shufflenet_v2_x0_5 twice and never
   imported shufflenet_v2_x2_0; the last added import now names x2_0.
 * Context lines in the shufflenetv2.py hunks follow the corrected
   PATCH 1/2 (bias_attr=, 0.25x / 0.33x docstrings).
 * Hunk counts are unchanged; the "index" blob hashes are stale and should
   be regenerated with git format-patch before publishing.

 python/paddle/tests/test_vision_models.py   | 18 ++++++++++++++++++
 python/paddle/vision/__init__.py            |  7 +++++++
 python/paddle/vision/models/__init__.py     | 17 ++++++++++++++++-
 python/paddle/vision/models/shufflenetv2.py | 12 ++++++------
 4 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py
index a25a8f373c29c4..6fb71357338f00 100644
--- a/python/paddle/tests/test_vision_models.py
+++ b/python/paddle/tests/test_vision_models.py
@@ -70,6 +70,24 @@ def test_resnet101(self):
 
     def test_resnet152(self):
         self.models_infer('resnet152')
+
+    def test_shufflenet_v2_x0_25(self):
+        self.models_infer('shufflenet_v2_x0_25')
+
+    def test_shufflenet_v2_x0_33(self):
+        self.models_infer('shufflenet_v2_x0_33')
+
+    def test_shufflenet_v2_x0_5(self):
+        self.models_infer('shufflenet_v2_x0_5')
+
+    def test_shufflenet_v2_x1_0(self):
+        self.models_infer('shufflenet_v2_x1_0')
+
+    def test_shufflenet_v2_x1_5(self):
+        self.models_infer('shufflenet_v2_x1_5')
+
+    def test_shufflenet_v2_x2_0(self):
+        self.models_infer('shufflenet_v2_x2_0')
 
     def test_vgg16_num_classes(self):
         vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py
index 76393865ded04a..82cd84c4850151 100644
--- a/python/paddle/vision/__init__.py
+++ b/python/paddle/vision/__init__.py
@@ -38,6 +38,13 @@
 from .models import mobilenet_v1  # noqa: F401
 from .models import MobileNetV2  # noqa: F401
 from .models import mobilenet_v2  # noqa: F401
+from .models import ShuffleNetV2  # noqa: F401
+from .models import shufflenet_v2_x0_25  # noqa: F401
+from .models import shufflenet_v2_x0_33  # noqa: F401
+from .models import shufflenet_v2_x0_5  # noqa: F401
+from .models import shufflenet_v2_x1_0  # noqa: F401
+from .models import shufflenet_v2_x1_5  # noqa: F401
+from .models import shufflenet_v2_x2_0  # noqa: F401
 from .models import VGG  # noqa: F401
 from .models import vgg11  # noqa: F401
 from .models import vgg13  # noqa: F401
diff --git a/python/paddle/vision/models/__init__.py b/python/paddle/vision/models/__init__.py
index d38f3b1722ee8c..10002edf838942 100644
--- a/python/paddle/vision/models/__init__.py
+++ b/python/paddle/vision/models/__init__.py
@@ -22,6 +22,13 @@
 from .mobilenetv1 import mobilenet_v1  # noqa: F401
 from .mobilenetv2 import MobileNetV2  # noqa: F401
 from .mobilenetv2 import mobilenet_v2  # noqa: F401
+from .shufflenetv2 import ShuffleNetV2  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x0_25  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x0_33  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x0_5  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x1_0  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x1_5  # noqa: F401
+from .shufflenetv2 import shufflenet_v2_x2_0  # noqa: F401
 from .vgg import VGG  # noqa: F401
 from .vgg import vgg11  # noqa: F401
 from .vgg import vgg13  # noqa: F401
@@ -45,5 +52,13 @@
     'mobilenet_v1',
     'MobileNetV2',
     'mobilenet_v2',
-    'LeNet'
+    'LeNet',
+    'ShuffleNetV2',
+    'shufflenet_v2_x0_25',
+    'shufflenet_v2_x0_33',
+    'shufflenet_v2_x0_5',
+    'shufflenet_v2_x1_0',
+    'shufflenet_v2_x1_5',
+    'shufflenet_v2_x2_0',
+
 ]
diff --git a/python/paddle/vision/models/shufflenetv2.py b/python/paddle/vision/models/shufflenetv2.py
index 4adb35fc994c95..fdfafabc5bc6b0 100644
--- a/python/paddle/vision/models/shufflenetv2.py
+++ b/python/paddle/vision/models/shufflenetv2.py
@@ -34,7 +34,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
     return x
 
 
-class InvertedResidual(nn.Module):
+class InvertedResidual(nn.Layer):
     def __init__(self, inp: int, oup: int, stride: int) -> None:
         super(InvertedResidual, self).__init__()
 
@@ -76,7 +76,7 @@ def __init__(self, inp: int, oup: int, stride: int) -> None:
     @staticmethod
     def depthwise_conv(
         i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
-    ) -> nn.Conv2d:
+    ) -> nn.Conv2D:
         return nn.Conv2D(i, o, kernel_size, stride, padding, bias_attr=bias, groups=i)
 
     def forward(self, x: Tensor) -> Tensor:
@@ -91,13 +91,13 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-class ShuffleNetV2(nn.Module):
+class ShuffleNetV2(nn.Layer):
     def __init__(
         self,
         stages_repeats: List[int],
         stages_out_channels: List[int],
         num_classes: int = 1000,
-        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
+        inverted_residual: Callable[..., nn.Layer] = InvertedResidual,
     ) -> None:
         super(ShuffleNetV2, self).__init__()
 
@@ -169,7 +169,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
 
     return model
 
-def shufflnet_v2_x0_25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+def shufflenet_v2_x0_25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 0.25x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
@@ -180,7 +180,7 @@ def shufflnet_v2_x0_25(pretrained: bool = False, progress: bool = True, **kwargs
     """
     return _shufflenetv2("shufflenetv2_x0.25", pretrained, progress, [4, 8, 4], [24, 24, 48, 96, 512], **kwargs)
 
-def shufflnet_v2_x0_33(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
+def shufflenet_v2_x0_33(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
     """
     Constructs a ShuffleNetV2 with 0.33x output channels, as described in
     `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"