Skip to content

Commit

Permalink
fix pool bug, test=develop (#28359)
Browse files Browse the repository at this point in the history
  • Loading branch information
LDOUBLEV authored Nov 3, 2020
1 parent 6115c14 commit 17db031
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def avg_pool1d(x,
x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling',
False, 'strides', stride, 'paddings', padding, 'padding_algorithm',
padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode,
'use_mkldnn', False, 'exclusive', not exclusive, 'data_format',
'use_mkldnn', False, 'exclusive', exclusive, 'data_format',
data_format)
return squeeze(output, [2])

Expand Down Expand Up @@ -338,7 +338,7 @@ def avg_pool2d(x,
x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling',
False, 'padding_algorithm', padding_algorithm, 'strides', stride,
'paddings', padding, 'use_cudnn', True, 'ceil_mode', ceil_mode,
'use_mkldnn', False, 'exclusive', not exclusive, 'data_format',
'use_mkldnn', False, 'exclusive', exclusive, 'data_format',
data_format)
if divisor_override is None:
return output
Expand Down Expand Up @@ -452,7 +452,7 @@ def avg_pool3d(x,
x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride,
'paddings', padding, 'global_pooling', False, 'padding_algorithm',
padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode,
'use_mkldnn', False, 'exclusive', not exclusive, 'data_format',
'use_mkldnn', False, 'exclusive', exclusive, 'data_format',
data_format)
if divisor_override is None:
return output
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/nn/layer/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def forward(self, x):
stride=self.stride,
padding=self.padding,
return_mask=self.return_mask,
ceil_mode=self.ceil_mode,
data_format=self.data_format,
name=self.name)

Expand Down Expand Up @@ -594,6 +595,7 @@ def forward(self, x):
stride=self.stride,
padding=self.padding,
return_mask=self.return_mask,
ceil_mode=self.ceil_mode,
data_format=self.data_format,
name=self.name)

Expand Down

0 comments on commit 17db031

Please sign in to comment.