Skip to content

Commit c750830

Browse files
committed
v0.1 data_parallel support (bn track_stats) + preResNet/denseNet/resNeXt
1 parent 21e3c81 commit c750830

7 files changed

+379
-27
lines changed

models/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
from .convBlock import *
2-
from .wideResnet import get_wrn
2+
from .wideResnet import get_wrn
3+
from .preResNet import get_resnet
4+
from .denseNet import get_densenet
5+
from .resNeXt import get_resnext

models/convBlock.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@ def __init__(self, drop_type=0, drop_rate=0., inplace=False):
66
super(drop_op, self).__init__()
77
assert (drop_type in (0, 1, 2, 3) and 0.<=drop_rate<1.)
88
self.drop_type, self.keep_rate = drop_type, 1.-drop_rate
9+
if drop_rate == 0.:
10+
self.drop_op = nn.Sequential();return
911
if drop_type == 0:
1012
self.drop_op = nn.Dropout(p=drop_rate, inplace=inplace)
1113
elif drop_type == 1:
1214
self.drop_op = nn.Dropout2d(p=drop_rate, inplace=inplace)
1315

1416
def forward(self, x):
17+
if self.keep_rate == 1.: return x
1518
if self.drop_type in (0, 1): return self.drop_op(x)
1619
# drop-branch/layer, x in [b_0, b_1, ...], b_i: B*C_i*H*W
1720
if self.training:
18-
mask = torch.FloatTensor(1, x.size(1), 1, 1, device=x.device).\
21+
mask = torch.FloatTensor(len(x)).to(x[0].device).\
1922
bernoulli_(self.keep_rate)*(1./self.keep_rate)
20-
x = list(map(lambda b: b.mul_(mask), x))
23+
x = [x[idx]*mask[idx] for idx in range(len(x))]
2124
return torch.cat(x, dim=1)
2225

2326
class Norm2d(nn.Module):
@@ -46,15 +49,11 @@ def forward(self, input):
4649
var = input.var([0, 2, 3], unbiased=False)
4750
n = input.numel()/input.size(1)
4851
if self.training:
49-
self.train_running_mean = self.momentum*mean+\
50-
(1-self.momentum)*self.train_running_mean
51-
self.train_running_var = self.momentum*var*n/(n-1)+\
52-
(1-self.momentum)*self.train_running_var
52+
self.train_running_mean.mul_(1 - self.momentum).add_(self.momentum*mean)
53+
self.train_running_var.mul_(1-self.momentum).add_(self.momentum*var*n/(n-1))
5354
else:
54-
self.test_running_mean = self.momentum*mean+\
55-
(1-self.momentum)*self.test_running_mean
56-
self.test_running_var = self.momentum*var*n/(n-1)+\
57-
(1-self.momentum)*self.test_running_var
55+
self.test_running_mean.mul_(1 - self.momentum).add_(self.momentum*mean)
56+
self.test_running_var.mul_(1-self.momentum).add_(self.momentum*var*n/(n-1))
5857
return self.norm(input)
5958

6059
def norm2d_track_stats(model, is_track):
@@ -72,13 +71,14 @@ def norm2d_stats(model):
7271
class conv_block(nn.Module):
7372
def __init__(self, in_channels, out_channels, kernel_size, block_type=0,
7473
use_gn=False, gn_groups=8, drop_type=0, drop_rate=0.,
75-
stride=1, padding=0, groups=1, bias=False):
74+
stride=1, padding=0, groups=1, bias=False, track_stats=False):
7675
super(conv_block, self).__init__()
7776
self.relu = nn.ReLU(inplace=True)
78-
self.norm = Norm2d(in_channels, use_gn, gn_groups, drop_rate>0.)
7977
self.drop = drop_op(drop_type, drop_rate)
8078
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
8179
groups=groups, stride=stride, padding=padding, bias=bias)
80+
bn_channels = in_channels if block_type in [0, 1] else out_channels
81+
self.norm = Norm2d(bn_channels, use_gn, gn_groups, track_stats)
8282

8383
if block_type==0: # bn/gn-relu-drop-conv, recommended
8484
self.ops = nn.Sequential(self.norm, self.relu,

models/denseNet.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from models import conv_block
6+
7+
class BasicBlock(nn.Module):
    """Dense layer made of a single 3x3 ``conv_block``.

    The block's output is concatenated onto its input along the channel
    axis, so the result carries ``in_planes + out_planes`` channels.
    """

    def __init__(self, in_planes, out_planes, args):
        super(BasicBlock, self).__init__()
        # All norm/dropout options are taken from ``args`` and forwarded
        # to conv_block by keyword.
        self.conv = conv_block(
            in_planes, out_planes, 3,
            block_type=args.block_type,
            use_gn=args.use_gn,
            gn_groups=args.gn_groups,
            drop_type=args.drop_type,
            drop_rate=args.drop_rate,
            padding=1,
            track_stats=args.report_ratio,
        )

    def forward(self, x):
        new_features = self.conv(x)
        # Dense connectivity: keep the input and append the new features.
        return torch.cat((x, new_features), dim=1)
17+
18+
class Bottleneck(nn.Module):
    """Dense bottleneck layer: a 1x1 ``conv_block`` followed by a 3x3 one.

    The 1x1 stage widens to ``4 * out_planes`` channels, the 3x3 stage
    produces ``out_planes`` channels, and the result is concatenated onto
    the input along the channel axis.
    """

    def __init__(self, in_planes, out_planes, args):
        super(Bottleneck, self).__init__()
        inter_planes = out_planes * 4
        # Options shared by both convolution stages.
        shared = dict(
            block_type=args.block_type,
            use_gn=args.use_gn,
            gn_groups=args.gn_groups,
            drop_type=args.drop_type,
            drop_rate=args.drop_rate,
            track_stats=args.report_ratio,
        )
        self.conv1 = conv_block(in_planes, inter_planes, 1, **shared)
        self.conv2 = conv_block(inter_planes, out_planes, 3, padding=1, **shared)

    def forward(self, x):
        squeezed = self.conv1(x)
        new_features = self.conv2(squeezed)
        return torch.cat((x, new_features), dim=1)
32+
33+
class TransitionBlock(nn.Module):
    """Transition between dense blocks: 1x1 ``conv_block`` then 2x2 average pooling."""

    def __init__(self, in_planes, out_planes, args):
        super(TransitionBlock, self).__init__()
        self.conv = conv_block(
            in_planes, out_planes, 1,
            block_type=args.block_type,
            use_gn=args.use_gn,
            gn_groups=args.gn_groups,
            drop_type=args.drop_type,
            drop_rate=args.drop_rate,
            track_stats=args.report_ratio,
        )

    def forward(self, x):
        # Change the channel count first, then halve the spatial resolution.
        return F.avg_pool2d(self.conv(x), 2)
43+
44+
class DenseBlock(nn.Module):
    """Sequential stack of ``num_layers`` dense layers.

    Layer ``i`` is built as ``block(in_planes + i * growth_rate, growth_rate, args)``,
    i.e. every layer is told that its input is ``growth_rate`` channels
    wider than the previous layer's input.
    """

    def __init__(self, num_layers, in_planes, growth_rate, block, args):
        super(DenseBlock, self).__init__()
        stages = []
        for idx in range(num_layers):
            width = in_planes + idx * growth_rate
            stages.append(block(width, growth_rate, args))
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
52+
53+
# For CIFAR-10/100 dataset
54+
class DenseNet(nn.Module):
    """DenseNet classifier.

    Layout: a 3x3 stem convolution, three dense blocks joined by two
    transition blocks, then BN-ReLU, global average pooling and a linear
    classifier.

    Args:
        args: namespace; ``depth`` and ``class_num`` are read here and the
            whole object is passed on to every dense/transition block.
        growth_rate: channels added by each dense layer; the stem emits
            ``2 * growth_rate`` channels.
        reduction: factor applied to the channel count by each transition
            block (floored to an integer).
        bottleneck: if truthy, build the dense blocks from ``Bottleneck``
            layers (with half as many layers per block); otherwise from
            ``BasicBlock`` layers.

    Attributes:
        in_planes: number of channels entering the classifier.
    """

    def __init__(self, args, growth_rate=12,
                 reduction=0.5, bottleneck=True):
        super(DenseNet, self).__init__()
        in_planes = 2 * growth_rate
        n = int((args.depth - 4) / 3)
        if bottleneck:
            # A bottleneck layer holds two convolutions, so only half as
            # many layers fit into the same nominal depth.
            n = n // 2
            block = Bottleneck
        else:
            block = BasicBlock
        # 1st conv before any dense block
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = DenseBlock(n, in_planes, growth_rate, block, args)
        in_planes = int(in_planes + n * growth_rate)
        compressed = int(math.floor(in_planes * reduction))
        self.trans1 = TransitionBlock(in_planes, compressed, args)
        in_planes = compressed
        # 2nd block
        self.block2 = DenseBlock(n, in_planes, growth_rate, block, args)
        in_planes = int(in_planes + n * growth_rate)
        compressed = int(math.floor(in_planes * reduction))
        self.trans2 = TransitionBlock(in_planes, compressed, args)
        in_planes = compressed
        # 3rd block
        self.block3 = DenseBlock(n, in_planes, growth_rate, block, args)
        in_planes = int(in_planes + n * growth_rate)
        # global average pooling and classifier
        self.bn = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(in_planes, args.class_num)
        self.in_planes = in_planes

    def forward(self, x):
        """Return class scores with one row per input sample."""
        out = self.conv1(x)
        out = self.trans1(self.block1(out))
        out = self.trans2(self.block2(out))
        out = self.block3(out)
        out = self.relu(self.bn(out))
        # Pool over whatever spatial extent is left. For 32x32 inputs the
        # map is 8x8, so this equals the former fixed avg_pool2d(out, 8);
        # other input sizes no longer leak spatial positions into the
        # batch dimension.
        out = F.adaptive_avg_pool2d(out, 1)
        # Flatten per sample so the batch size is always preserved.
        out = out.view(out.size(0), -1)
        return self.fc(out)
96+
97+
98+
# https://github.com/liuzhuang13/DenseNet#results-on-cifar
99+
# CIFAR DenseNet3(depth=100, num_classes=10., growth_rate=12, reduction=0.5, bottleneck=True, drop_rate=0.2)
100+
# SVHN DenseNet3(depth=100, num_classes=10., growth_rate=24, reduction=0.5, bottleneck=True, drop_rate=0.2)
101+
# DenseNet3(depth=250, num_classes=10., growth_rate=24, reduction=0.5, bottleneck=True, drop_rate=0.2)
102+
# DenseNet3(depth=190, num_classes=10., growth_rate=40, reduction=0.5, bottleneck=True, drop_rate=0.2)
103+
def get_densenet(args):
    """Build a ``DenseNet`` from ``args``, using ``args.arg1`` as the growth rate."""
    growth_rate = args.arg1
    return DenseNet(args, growth_rate=growth_rate)
105+
106+
if __name__ == '__main__':
    # Smoke test: build a depth-100 DenseNet with growth rate 12, push one
    # 32x32 RGB sample through it, and print the output size, the module
    # tree, the parameter count and whatever norm2d_stats reports.
    import argparse

    parser = argparse.ArgumentParser(description='DenseNet')
    args = parser.parse_args()
    args.depth = 100
    args.class_num = 10
    args.block_type = 0
    args.use_gn = False
    args.gn_groups = 6
    args.drop_type = 1
    args.drop_rate = 0.1
    args.report_ratio = True
    args.arg1 = 12

    net = DenseNet(args, args.arg1)
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
    print(net)
    # Generator expression: no need to materialise the per-tensor counts.
    print(sum(p.data.nelement() for p in net.parameters()))

    from convBlock import Norm2d, norm2d_stats, norm2d_track_stats

    # norm2d_track_stats(net, False)
    mean, var = norm2d_stats(net)
    print(len(mean), mean)
    print(var)

models/preResNet.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
import torch.nn as nn
3+
from models import conv_block
4+
5+
class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual unit.

    Order of operations: BN-ReLU, 1x1 conv, a 3x3 ``conv_block`` (which
    carries the stride), and a 1x1 ``conv_block`` that widens to
    ``planes * expansion`` channels. The result is added to the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, args, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        # First stage is spelled out by hand: BN -> ReLU -> 1x1 conv.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # Middle 3x3 stage takes its norm/dropout options from ``args``.
        self.block2 = conv_block(
            planes, planes, 3,
            args.block_type, args.use_gn, args.gn_groups,
            args.drop_type, args.drop_rate,
            stride=stride, padding=1, track_stats=args.report_ratio,
        )
        # Final 1x1 stage uses fixed options regardless of ``args``.
        self.block3 = conv_block(
            planes, planes * Bottleneck.expansion, 1,
            block_type=0, use_gn=False, drop_rate=0., track_stats=False,
        )

        self.downsample = downsample

    def forward(self, x):
        preact = self.relu(self.bn1(x))

        # Identity shortcut by default; a supplied projection is fed the
        # pre-activated tensor rather than the raw input.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(preact)

        out = self.block3(self.block2(self.conv1(preact)))
        out += shortcut
        return out
38+
39+
40+
class preResNet(nn.Module):
    """Pre-activation ResNet built from ``Bottleneck`` units.

    Layout: 3x3 stem conv, three stages of bottleneck units (the second
    and third with stride 2), BN-ReLU, 8x8 average pooling and a linear
    classifier.

    Args:
        args: namespace; ``depth`` and ``class_num`` are read here and the
            whole object is passed on to every ``Bottleneck``.
        widen_factor: multiplier applied to the base widths 16/32/64.
    """

    def __init__(self, args, widen_factor=1.):
        super(preResNet, self).__init__()
        stem_width = int(16 * widen_factor)
        final_width = int(64 * widen_factor) * Bottleneck.expansion
        self.inplanes = stem_width
        units_per_stage = int((args.depth - 2) / 9)

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.layer1 = self._make_layer(int(16 * widen_factor), units_per_stage, args)
        self.layer2 = self._make_layer(int(32 * widen_factor), units_per_stage, args, stride=2)
        self.layer3 = self._make_layer(int(64 * widen_factor), units_per_stage, args, stride=2)
        self.bn = nn.BatchNorm2d(final_width)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(final_width, args.class_num)

    def _make_layer(self, planes, blocks, args, stride=1):
        """Build one stage of ``blocks`` bottleneck units and advance ``self.inplanes``."""
        out_planes = planes * Bottleneck.expansion

        # A projection shortcut is needed whenever the first unit changes
        # the resolution or the channel count.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            downsample = None

        # Only the first unit carries the stride and the projection.
        units = [Bottleneck(self.inplanes, planes, args, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            units.append(Bottleneck(self.inplanes, planes, args))

        return nn.Sequential(*units)

    def forward(self, x):
        x = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        x = self.avgpool(self.relu(self.bn(x)))
        # Flatten everything but the batch dimension before the classifier.
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
85+
86+
def get_resnet(args):
    """Build a ``preResNet`` from ``args``, using ``args.arg1`` as the widen factor."""
    widen_factor = args.arg1
    return preResNet(args, widen_factor=widen_factor)
88+
89+
if __name__ == '__main__':
    # Smoke test: build a depth-110 preResNet with the default widen
    # factor, run one 32x32 RGB sample, and print the output size, the
    # module tree, the parameter count and whatever norm2d_stats reports.
    import argparse

    parser = argparse.ArgumentParser(description='PreResNet')
    args = parser.parse_args()
    # Fill in the fields the model reads directly on the parsed namespace.
    vars(args).update(
        depth=110,
        class_num=10,
        block_type=0,
        use_gn=False,
        gn_groups=16,
        drop_type=1,
        drop_rate=0.1,
        report_ratio=True,
    )

    net = preResNet(args)
    sample = torch.randn(1, 3, 32, 32)
    y = net(sample)
    print(y.size())
    print(net)
    print(sum(p.data.nelement() for p in net.parameters()))

    from convBlock import Norm2d, norm2d_stats, norm2d_track_stats

    # norm2d_track_stats(net, False)
    mean, var = norm2d_stats(net)
    print(mean)
    print(var)

0 commit comments

Comments
 (0)