Skip to content

Commit 88c940d

Browse files
committed
Add init weights for hrnet_contrast (PaddlePaddle#1746)
1 parent cd8d52d commit 88c940d

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

paddleseg/models/hrnet_contrast.py

+25-21
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class HRNetW48Contrast(nn.Layer):
4040
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
4141
pretrained (str, optional): The path or url of pretrained model. Default: None.
4242
"""
43+
4344
def __init__(self,
4445
in_channels,
4546
num_classes,
@@ -54,23 +55,23 @@ def __init__(self,
5455
self.num_classes = num_classes
5556
self.proj_dim = proj_dim
5657
self.align_corners = align_corners
57-
self.pretrained = pretrained
5858

5959
self.cls_head = nn.Sequential(
60-
layers.ConvBNReLU(in_channels,
61-
in_channels,
62-
kernel_size=3,
63-
stride=1,
64-
padding=1),
60+
layers.ConvBNReLU(
61+
in_channels, in_channels, kernel_size=3, stride=1, padding=1),
6562
nn.Dropout2D(drop_prob),
66-
nn.Conv2D(in_channels,
67-
num_classes,
68-
kernel_size=1,
69-
stride=1,
70-
bias_attr=False),
63+
nn.Conv2D(
64+
in_channels,
65+
num_classes,
66+
kernel_size=1,
67+
stride=1,
68+
bias_attr=False),
7169
)
72-
self.proj_head = ProjectionHead(dim_in=in_channels,
73-
proj_dim=self.proj_dim)
70+
self.proj_head = ProjectionHead(
71+
dim_in=in_channels, proj_dim=self.proj_dim)
72+
73+
self.pretrained = pretrained
74+
self.init_weight()
7475

7576
def init_weight(self):
7677
if self.pretrained is not None:
@@ -83,17 +84,19 @@ def forward(self, x):
8384
if self.training:
8485
emb = self.proj_head(feats)
8586
logit_list.append(
86-
F.interpolate(out,
87-
paddle.shape(x)[2:],
88-
mode='bilinear',
89-
align_corners=self.align_corners))
87+
F.interpolate(
88+
out,
89+
paddle.shape(x)[2:],
90+
mode='bilinear',
91+
align_corners=self.align_corners))
9092
logit_list.append({'seg': out, 'embed': emb})
9193
else:
9294
logit_list.append(
93-
F.interpolate(out,
94-
paddle.shape(x)[2:],
95-
mode='bilinear',
96-
align_corners=self.align_corners))
95+
F.interpolate(
96+
out,
97+
paddle.shape(x)[2:],
98+
mode='bilinear',
99+
align_corners=self.align_corners))
97100
return logit_list
98101

99102

@@ -105,6 +108,7 @@ class ProjectionHead(nn.Layer):
105108
proj_dim (int, optional): The output dimensions of projection head. Default: 256.
106109
proj (str, optional): The type of projection head, only support 'linear' and 'convmlp'. Default: 'convmlp'.
107110
"""
111+
108112
def __init__(self, dim_in, proj_dim=256, proj='convmlp'):
109113
super(ProjectionHead, self).__init__()
110114
if proj == 'linear':

0 commit comments

Comments
 (0)