Skip to content
This repository was archived by the owner on Dec 30, 2019. It is now read-only.

Commit 4ae37c0

Browse files
author
Bernhard Schuster
committed
feat/pooling: stride before padding argument
The original issue is autumnai/collenchyma-nn#9
1 parent a16437c commit 4ae37c0

File tree

4 files changed

+44
-14
lines changed

4 files changed

+44
-14
lines changed

src/frameworks/cuda/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,8 @@ impl<T> Pooling<T> for Backend<Cuda>
911911
{
912912
fn new_pooling_config(&self,
913913
window: &[i32],
914-
padding: &[i32],
915-
stride: &[i32])
914+
stride: &[i32],
915+
padding: &[i32])
916916
-> Result<Self::CPOOL, ::co::error::Error> {
917917
let pooling_avg = ::cudnn::PoolingDescriptor::new(::cudnn::cudnnPoolingMode_t::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING, window, padding, stride).unwrap();
918918
let pooling_max =

src/frameworks/native/mod.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,13 @@ impl<T> ::plugin::Pooling<T> for Backend<Native>
412412
{
413413
fn new_pooling_config(&self,
414414
window: &[i32],
415-
padding: &[i32],
416-
stride: &[i32])
415+
stride: &[i32],
416+
padding: &[i32])
417417
-> Result<Self::CPOOL, ::co::error::Error> {
418418
Ok(helper::PoolingConfig {
419419
window: window.to_vec(),
420-
padding: padding.to_vec(),
421420
stride: stride.to_vec(),
421+
padding: padding.to_vec(),
422422
})
423423
}
424424

@@ -513,8 +513,8 @@ impl<T> ::plugin::Pooling<T> for Backend<Native>
513513
input_idx_base: &mut [usize],
514514
window: &[i32],
515515
depth: usize,
516-
padding: &[i32],
517516
stride: &[i32],
517+
padding: &[i32],
518518
output: &mut [T],
519519
output_stride: &[usize],
520520
output_dim: &[usize],
@@ -541,8 +541,8 @@ impl<T> ::plugin::Pooling<T> for Backend<Native>
541541
input_idx_base,
542542
window,
543543
depth + 1,
544-
padding,
545544
&stride[1..],
545+
padding,
546546
output,
547547
&output_stride[1..],
548548
&output_dim[1..],
@@ -570,8 +570,8 @@ impl<T> ::plugin::Pooling<T> for Backend<Native>
570570
output_idx.resize(output_dim.len(), 0);
571571

572572
let window = &config.window[..];
573-
let padding = &config.padding[..];
574573
let stride = &config.stride[..];
574+
let padding = &config.padding[..];
575575
// do everything for each batch
576576
for batch in 0..input_dim[0] {
577577
// iterate over the batches!
@@ -591,8 +591,8 @@ impl<T> ::plugin::Pooling<T> for Backend<Native>
591591
&mut input_idx,
592592
&window,
593593
0,
594-
&padding,
595594
&stride,
595+
&padding,
596596
output,
597597
&output_stride[2..],
598598
&output_dim[2..],

src/plugin.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ pub trait LRN<F> : NN<F> {
409409
/// Provides the functionality for a Backend to support Pooling operations.
410410
pub trait Pooling<F> : NN<F> {
411411
/// Creates a new PoolingConfig, which needs to be passed to further pooling Operations.
412-
fn new_pooling_config(&self, window: &[i32], padding: &[i32], stride: &[i32])
412+
fn new_pooling_config(&self, window: &[i32], stride: &[i32], padding: &[i32])
413413
-> Result<Self::CPOOL, ::co::error::Error>;
414414

415415
/// Computes non-linear down-sampling ([max Pooling][pooling]) over the input Tensor `x`.

src/tests/pooling.rs

+34-4
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub fn test_pooling_max<T, F: IFramework>(backend: Backend<F>)
4949

5050
let x = filled_tensor(&backend,&[4, 4, 4, 4], &inp);
5151
let mut r = SharedTensor::<T>::new(&[4, 4, 2, 4]);
52-
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[0, 0], &[2, 1])
52+
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[2, 1], &[0, 0])
5353
.unwrap();
5454

5555
backend.pooling_max(&x, &mut r, &conf).unwrap();
@@ -69,7 +69,7 @@ pub fn test_pooling_max_grad<T, F: IFramework>(backend: Backend<F>)
6969
let dx = filled_tensor(&backend,&[4, 4, 4, 4], &inp);
7070
let r = filled_tensor(&backend,&[4, 4, 2, 2], &inp[0..64]);
7171
let mut dr = SharedTensor::<T>::new(&[4, 4, 2, 2]);
72-
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[0, 0], &[2, 2])
72+
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[2, 2], &[0, 0])
7373
.unwrap();
7474

7575
backend.pooling_max_grad(&x, &dx, &r, &mut dr, &conf).unwrap();
@@ -90,7 +90,7 @@ pub fn test_pooling_avg<T, F: IFramework>(backend: Backend<F>)
9090

9191
let x = filled_tensor(&backend, &[4, 4, 4, 4], &inp);
9292
let mut r = SharedTensor::<T>::new(&[4, 4, 2, 2]);
93-
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[0, 0], &[2, 2])
93+
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[2, 2], &[0, 0])
9494
.unwrap();
9595

9696
backend.pooling_avg(&x, &mut r, &conf).unwrap();
@@ -111,7 +111,7 @@ pub fn test_pooling_avg_grad<T, F: IFramework>(backend: Backend<F>)
111111
let dx = filled_tensor(&backend, &[8, 4, 4, 4], &inp);
112112
let r = filled_tensor(&backend, &[8, 4, 2, 2], &inp[0..128]);
113113
let mut dr = SharedTensor::<T>::new(&[8, 4, 2, 2]);
114-
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[0, 0], &[2, 2])
114+
let conf = Pooling::<T>::new_pooling_config(&backend, &[2, 2], &[2, 2], &[0, 0])
115115
.unwrap();
116116

117117
backend.pooling_avg_grad(&x, &dx, &r, &mut dr, &conf).unwrap();
@@ -124,6 +124,36 @@ pub fn test_pooling_avg_grad<T, F: IFramework>(backend: Backend<F>)
124124
tensor_assert_eq(&dr, &dr_test, 1.0);
125125
}
126126

127+
pub fn cross_test_pooling_max<F: IFramework, G: IFramework>(backend_a: Backend<F>, backend_b: Backend<G>)
128+
where
129+
Backend<F>: Pooling<f32> + IBackend,
130+
Backend<G>: Pooling<f32> + IBackend {
131+
132+
let mut inp = vec![1.0; 256];
133+
inp[0] = 2.0;
134+
135+
let lower : f32 = -128.;
136+
let upper : f32 = 127.;
137+
let x = uniformly_random_tensor(&backend_a, &[4, 4, 4, 4], lower, upper);
138+
139+
let mut r_a = SharedTensor::<f32>::new(&[4, 4, 2, 4]);
140+
let mut r_b = SharedTensor::<f32>::new(&[4, 4, 2, 4]);
141+
142+
let conf_a = Pooling::<f32>::new_pooling_config(&backend_a, &[2, 2], &[2, 1], &[0, 0])
143+
.unwrap();
144+
let conf_b = Pooling::<f32>::new_pooling_config(&backend_b, &[2, 2], &[2, 1], &[0, 0])
145+
.unwrap();
146+
147+
backend_a.pooling_max(&x, &mut r_a, &conf_a).unwrap();
148+
backend_b.pooling_max(&x, &mut r_b, &conf_b).unwrap();
149+
tensor_assert_eq_tensor(&r_a, &r_b, 3.0);
150+
}
151+
152+
mod cross {
153+
use super::*;
154+
test_cross!(cross_test_pooling_max, cross_test_pooling_max_f32);
155+
}
156+
127157
mod cuda {
128158
use super::*;
129159
test_cuda!(test_pooling_avg, pooling_avg_f32, pooling_avg_f64);

0 commit comments

Comments
 (0)