Skip to content

Commit 79f7109

Browse files
committed
fix/convolution: add missing weight initialization
1 parent 20d97e9 commit 79f7109

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/layers/common/convolution.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
//! Does this convolution with a set of learnable filters, each producing one
44
//! feature map in the output tensor.
55
use std::rc::Rc;
6-
use co::{IBackend, DeviceType, SharedTensor};
6+
use co::prelude::*;
77
use conn;
88
use layer::*;
99
use util::{ArcLock, native_backend, cast_vec_usize_to_i32};
10+
use weight::FillerType;
1011
use super::FilterLayer;
1112

1213
#[derive(Debug, Clone)]
@@ -126,7 +127,13 @@ impl<B: IBackend + conn::Convolution<f32>> ILayer<B> for Convolution<B> {
126127
let config = backend.new_convolution_config(&inp, &output_data, &mut filter,
127128
conn::ConvForwardAlgo::Auto, conn::ConvBackwardFilterAlgo::Auto, conn::ConvBackwardDataAlgo::Auto,
128129
&stride, &padding).unwrap();
130+
// resize and fill weights
129131
weights_data[0].write().unwrap().resize(filter.desc()).unwrap();
132+
let filler = FillerType::Glorot {
133+
input_size: inp.desc().size(),
134+
output_size: output_shape.size(),
135+
};
136+
filler.fill(&mut weights_data[0].write().unwrap());
130137
weights_gradient[0].write().unwrap().resize(filter.desc()).unwrap();
131138
self.convolution_configs = Some(Rc::new(config));
132139
}

0 commit comments

Comments
 (0)