Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PaddlePaddle Hackathon] Add ShuffleNetV2 #36278

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/paddle/tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ def test_resnet101(self):

def test_resnet152(self):
self.models_infer('resnet152')

def test_shufflenet_v2_x0_25(self):
self.models_infer('shufflenet_v2_x0_25')

def test_shufflenet_v2_x0_33(self):
self.models_infer('shufflenet_v2_x0_33')

def test_shufflenet_v2_x0_5(self):
self.models_infer('shufflenet_v2_x0_5')

def test_shufflenet_v2_x1_0(self):
self.models_infer('shufflenet_v2_x1_0')

def test_shufflenet_v2_x1_5(self):
self.models_infer('shufflenet_v2_x1_5')

def test_shufflenet_v2_x2_0(self):
self.models_infer('shufflenet_v2_x2_0')

def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
from .models import mobilenet_v1 # noqa: F401
from .models import MobileNetV2 # noqa: F401
from .models import mobilenet_v2 # noqa: F401
from .models import ShuffleNetV2 # noqa: F401
from .models import shufflenet_v2_x0_25 #noqa: F401
from .models import shufflenet_v2_x0_33 #noqa: F401
from .models import shufflenet_v2_x0_5 #noqa: F401
from .models import shufflenet_v2_x1_0 #noqa: F401
from .models import shufflenet_v2_x1_5 #noqa: F401
from .models import shufflenet_v2_x0_5 #noqa: F401
from .models import VGG # noqa: F401
from .models import vgg11 # noqa: F401
from .models import vgg13 # noqa: F401
Expand Down
17 changes: 16 additions & 1 deletion python/paddle/vision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
from .mobilenetv1 import mobilenet_v1 # noqa: F401
from .mobilenetv2 import MobileNetV2 # noqa: F401
from .mobilenetv2 import mobilenet_v2 # noqa: F401
from .shufflenetv2 import ShuffleNetV2 # noqa: F401
from .shufflenetv2 import shufflenet_v2_x0_25 #noqa: F401
from .shufflenetv2 import shufflenet_v2_x0_33 #noqa: F401
from .shufflenetv2 import shufflenet_v2_x0_5 #noqa: F401
from .shufflenetv2 import shufflenet_v2_x1_0 #noqa: F401
from .shufflenetv2 import shufflenet_v2_x1_5 #noqa: F401
from .shufflenetv2 import shufflenet_v2_x2_0 #noqa: F401
from .vgg import VGG # noqa: F401
from .vgg import vgg11 # noqa: F401
from .vgg import vgg13 # noqa: F401
Expand All @@ -45,5 +52,13 @@
'mobilenet_v1',
'MobileNetV2',
'mobilenet_v2',
'LeNet'
'LeNet',
'ShuffleNetV2',
'shufflenet_v2_x0_25',
'shufflenet_v2_x0_33',
'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5',
'shufflenet_v2_x2_0',

]
239 changes: 239 additions & 0 deletions python/paddle/vision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from typing import Callable, Any, List

import paddle
import paddle.nn as nn
from paddle import Tensor

from paddle.utils.download import get_weights_path_from_url


__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_25","shufflenet_v2_x0_33", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"]

model_urls = {
"shufflenetv2_x0.25": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_25_pretrained.pdparams",
"shufflenetv2_x0.33": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_33_pretrained.pdparams",
"shufflenetv2_x0.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x0_5_pretrained.pdparams",
"shufflenetv2_x1.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_0_pretrained.pdparams",
"shufflenetv2_x1.5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x1_5_pretrained.pdparams",
"shufflenetv2_x2.0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ShuffleNetV2_x2_0_pretrained.pdparams",
}


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups

# reshape
x = paddle.reshape(x, (batchsize, groups, channels_per_group, height, width))

x = paddle.transpose(x,[0,2,1,3,4])

# flatten
x = paddle.reshape(x, (batchsize, -1, height, width))

return x


class InvertedResidual(nn.Layer):
def __init__(self, inp: int, oup: int, stride: int) -> None:
super(InvertedResidual, self).__init__()

if not (1 <= stride <= 3):
raise ValueError("illegal stride value")
self.stride = stride

branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)

if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2D(inp),
nn.Conv2D(inp, branch_features, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2D(branch_features),
nn.ReLU(),
)
else:
self.branch1 = nn.Sequential()

self.branch2 = nn.Sequential(
nn.Conv2D(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0
),
nn.BatchNorm2D(branch_features),
nn.ReLU(),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2D(branch_features),
nn.Conv2D(branch_features, branch_features, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2D(branch_features),
nn.ReLU(),
)

@staticmethod
def depthwise_conv(
i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
) -> nn.Conv2D:
return nn.Conv2D(i, o, kernel_size, stride, padding, bias=bias, groups=i)

def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
x1, x2 = x.chunk(2, axis=1)
out = paddle.concat((x1, self.branch2(x2)), axis=1)
else:
out = paddle.concat((self.branch1(x), self.branch2(x)),axis=1)

out = channel_shuffle(out, 2)

return out


class ShuffleNetV2(nn.Layer):
def __init__(
self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Layer] = InvertedResidual,
) -> None:
super(ShuffleNetV2, self).__init__()

if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels

input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2D(input_channels, output_channels, 3, 2, 1),
nn.BatchNorm2d(output_channels),
nn.ReLU(),
)
input_channels = output_channels

self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels

output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2D(input_channels, output_channels, 1, 1, 0),
nn.BatchNorm2D(output_channels),
nn.ReLU(inplace=True),
)

self.fc = nn.Linear(output_channels, num_classes)

def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
x = x.mean([2, 3]) # globalpool
x = self.fc(x)
return x

def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)


def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
model = ShuffleNetV2(*args, **kwargs)

if pretrained:
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
else:
weight_path = get_weights_path_from_url(model_urls[arch])
param = paddle.load(weight_path)
model.set_dict(param)

return model

def shufflenet_v2_x0_25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x0.25", pretrained, progress, [4, 8, 4], [24, 24, 48, 96, 512], **kwargs)

def shufflenet_v2_x0_33(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x0.33", pretrained, progress, [4, 8, 4], [24, 32, 64, 128, 512], **kwargs)

def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
"""
Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)