Skip to content

Commit

Permalink
fix satrn export for paddle2.5 (#11096)
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 authored Oct 19, 2023
1 parent a0218a8 commit bf59c42
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions ppocr/modeling/heads/rec_satrn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,15 @@ def forward(self, feat, valid_ratios=None):
Tensor: A tensor of shape :math:`(N, T, D_m)`.
"""
if valid_ratios is None:
valid_ratios = [1.0 for _ in range(feat.shape[0])]
bs = paddle.shape(feat)[0]
valid_ratios = paddle.full((bs, 1), 1., dtype=paddle.float32)

feat = self.position_enc(feat)
n, c, h, w = feat.shape

mask = paddle.zeros((n, h, w))
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
valid_width = int(min(w, paddle.ceil(w * valid_ratio)))
mask[i, :, :valid_width] = 1

mask = mask.reshape([n, h * w])
Expand Down Expand Up @@ -347,7 +349,6 @@ def _get_sinusoid_encoding_table(self, n_position, d_hid):
return sinusoid_table.unsqueeze(0)

def forward(self, x):

x = x + self.position_table[:, :x.shape[1]].clone().detach()
return self.dropout(x)

Expand Down Expand Up @@ -514,7 +515,6 @@ def forward_train(self, feat, out_enc, targets, valid_ratio):
return outputs

def forward_test(self, feat, out_enc, valid_ratio):

src_mask = self._get_mask(out_enc, valid_ratio)
N = out_enc.shape[0]
init_target_seq = paddle.full(
Expand Down Expand Up @@ -556,13 +556,11 @@ def __init__(self, enc_cfg, dec_cfg, **kwargs):
self.decoder = SATRNDecoder(**dec_cfg)

def forward(self, feat, targets=None):

if targets is not None:
targets, valid_ratio = targets
else:
targets, valid_ratio = None, None
holistic_feat = self.encoder(feat, valid_ratio) # bsz c

final_out = self.decoder(feat, holistic_feat, targets, valid_ratio)

return final_out

0 comments on commit bf59c42

Please sign in to comment.