Commit 38d5f48

Auto merge of #53 - autumnai:fix/input_reshape, r=hobofan

Fix/input reshape

2 parents d081891 + 3456877
6 files changed: +36 -19 lines

src/layer.rs  (+2 -2)

@@ -133,7 +133,7 @@ impl<B: IBackend + LayerOps<f32> + 'static> Layer<B> {
             LayerType::Softmax => Box::new(Softmax::default()),
             LayerType::ReLU => Box::new(ReLU),
             LayerType::Sigmoid => Box::new(Sigmoid),
-            LayerType::NegativeLogLikelihood => Box::new(NegativeLogLikelihood::default()),
+            LayerType::NegativeLogLikelihood(layer_config) => Box::new(NegativeLogLikelihood::from_config(&layer_config)),
             LayerType::Reshape(layer_config) => Box::new(Reshape::from_config(&layer_config)),
         }
     }
@@ -953,7 +953,7 @@ pub enum LayerType {
     Sigmoid,
     // Loss layers
     /// NegativeLogLikelihood Layer
-    NegativeLogLikelihood,
+    NegativeLogLikelihood(NegativeLogLikelihoodConfig),
     // Utility layers
     /// Reshape Layer
     Reshape(ReshapeConfig),
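
With this change, NegativeLogLikelihood is specified like the other parameterized layers (Reshape, etc.). A minimal sketch of how the new enum variant would be written when describing a network, assuming the module paths and re-exports visible in this commit; the surrounding setup is illustrative, not part of the diff:

// Minimal sketch, assuming the paths re-exported in this commit; the helper
// function is illustrative only.
use leaf::layer::LayerType;
use leaf::layers::NegativeLogLikelihoodConfig;

fn nll_layer_type() -> LayerType {
    // num_classes must now be stated explicitly instead of being inferred
    // from the input shape during the gradient computation.
    LayerType::NegativeLogLikelihood(NegativeLogLikelihoodConfig { num_classes: 10 })
}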

src/layers/common/convolution.rs  (+8 -1)

@@ -3,10 +3,11 @@
 //! Does this convolution with a set of learnable filters, each producing one
 //! feature map in the output tensor.
 use std::rc::Rc;
-use co::{IBackend, DeviceType, SharedTensor};
+use co::prelude::*;
 use conn;
 use layer::*;
 use util::{ArcLock, native_backend, cast_vec_usize_to_i32};
+use weight::FillerType;
 use super::FilterLayer;

 #[derive(Debug, Clone)]
@@ -126,7 +127,13 @@ impl<B: IBackend + conn::Convolution<f32>> ILayer<B> for Convolution<B> {
         let config = backend.new_convolution_config(&inp, &output_data, &mut filter,
             conn::ConvForwardAlgo::Auto, conn::ConvBackwardFilterAlgo::Auto, conn::ConvBackwardDataAlgo::Auto,
             &stride, &padding).unwrap();
+        // resize and fill weights
         weights_data[0].write().unwrap().resize(filter.desc()).unwrap();
+        let filler = FillerType::Glorot {
+            input_size: inp.desc().size(),
+            output_size: output_shape.size(),
+        };
+        filler.fill(&mut weights_data[0].write().unwrap());
         weights_gradient[0].write().unwrap().resize(filter.desc()).unwrap();
         self.convolution_configs = Some(Rc::new(config));
     }
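
The new filler initializes the convolution weights with Glorot (Xavier) scaling instead of leaving the resized tensor uninitialized. The scheme commonly behind such a filler draws each weight from a zero-centered range whose width shrinks as fan-in plus fan-out grows; a standalone sketch of the uniform variant, stated as an assumption about the scheme in general, not as Leaf's actual FillerType::Glorot code:

// Standalone sketch of Glorot (Xavier) uniform initialization; the helper name
// and the deterministic stand-in for a real RNG are illustrative only.
fn glorot_uniform(input_size: usize, output_size: usize, n: usize) -> Vec<f32> {
    // Weights are drawn from [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)).
    let limit = (6.0 / (input_size + output_size) as f32).sqrt();
    (0..n)
        .map(|i| {
            // cheap deterministic pseudo-random value in [0, 1)
            let t = ((i as f32 * 12.9898).sin() * 43758.5453).rem_euclid(1.0);
            (t * 2.0 - 1.0) * limit
        })
        .collect()
}

fn main() {
    // e.g. a filter feeding 25 input values into 20 * 24 * 24 output values
    let weights = glorot_uniform(25, 20 * 24 * 24, 20 * 25);
    let limit = (6.0f32 / (25 + 20 * 24 * 24) as f32).sqrt();
    assert!(weights.iter().all(|w| w.abs() <= limit));
}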

src/layers/loss/mod.rs  (+1 -1)

@@ -18,6 +18,6 @@ macro_rules! impl_ilayer_loss {
     )
 }

-pub use self::negative_log_likelihood::NegativeLogLikelihood;
+pub use self::negative_log_likelihood::{NegativeLogLikelihood, NegativeLogLikelihoodConfig};

 pub mod negative_log_likelihood;

src/layers/loss/negative_log_likelihood.rs  (+17 -14)

@@ -7,9 +7,18 @@ use util::{ArcLock, native_backend};
 #[derive(Debug, Clone)]
 #[allow(missing_copy_implementations)]
 /// NegativeLogLikelihood Loss Layer
-pub struct NegativeLogLikelihood;
+pub struct NegativeLogLikelihood {
+    num_classes: usize,
+}

 impl NegativeLogLikelihood {
+    /// Create a NegativeLogLikelihood layer from a NegativeLogLikelihoodConfig.
+    pub fn from_config(config: &NegativeLogLikelihoodConfig) -> NegativeLogLikelihood {
+        NegativeLogLikelihood {
+            num_classes: config.num_classes,
+        }
+    }
+
     fn calculate_outer_num(softmax_axis: usize, input_shape: &[usize]) -> usize {
         input_shape.iter().take(softmax_axis + 1).fold(1, |prod, i| prod * i)
     }
@@ -25,14 +34,6 @@ impl NegativeLogLikelihood {
             _ => panic!("NegativeLogLikelihood layer only supports 1D/2D inputs")
         }
     }
-
-    fn num_classes(input_shape: &[usize]) -> usize {
-        match input_shape.len() {
-            1 => input_shape[0],
-            2 => input_shape[1],
-            _ => panic!("NegativeLogLikelihood layer only supports 1D/2D inputs"),
-        }
-    }
 }

 impl<B: IBackend> ILayer<B> for NegativeLogLikelihood {
@@ -97,7 +98,7 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
                               input_gradients: &mut [&mut SharedTensor<f32>])
         let labels = input_data[1];
         let batch_size = Self::batch_size(input_data[0].desc());
-        let num_classes = Self::num_classes(input_data[0].desc());
+        let num_classes = self.num_classes;

         let native = native_backend();
         let native_labels = labels.get(native.device()).unwrap().as_native().unwrap().as_slice::<f32>();
@@ -114,8 +115,10 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {

 impl<B: IBackend> ComputeParametersGradient<f32, B> for NegativeLogLikelihood { }

-impl ::std::default::Default for NegativeLogLikelihood {
-    fn default() -> NegativeLogLikelihood {
-        NegativeLogLikelihood
-    }
+#[derive(Debug, Clone)]
+#[allow(missing_copy_implementations)]
+/// Specifies configuration parameters for a NegativeLogLikelihood Layer.
+pub struct NegativeLogLikelihoodConfig {
+    /// How many different classes can be classified.
+    pub num_classes: usize,
 }
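
The number of classes was previously inferred from the input shape inside the gradient computation; it now has to be stated up front when the layer is built, and the Default impl is gone. A minimal sketch of the new construction path, assuming the types re-exported by this commit:

// Minimal sketch, assuming the re-exports added in src/layers/loss/mod.rs;
// the helper function is illustrative, not taken from the Leaf sources.
use leaf::layers::{NegativeLogLikelihood, NegativeLogLikelihoodConfig};

fn build_loss_layer() -> NegativeLogLikelihood {
    let config = NegativeLogLikelihoodConfig { num_classes: 10 };
    NegativeLogLikelihood::from_config(&config)
}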

src/layers/mod.rs  (+1 -1)

@@ -64,7 +64,7 @@ pub use self::common::{

 #[allow(unused_import_braces)]
 pub use self::loss::{
-    NegativeLogLikelihood,
+    NegativeLogLikelihood, NegativeLogLikelihoodConfig,
 };

 #[allow(unused_import_braces)]

src/network.rs  (+7 -0)

@@ -337,7 +337,14 @@ impl<B: IBackend + LayerOps<f32> + 'static> Network<B> {
         for layer in &mut self.layers {
             for (blob_index, blob_name) in layer.input_blob_names().to_owned().iter().enumerate() {
                 if blob_name == &self.input_blob_names[i] {
+                    let reshaped_shape = layer.input_blobs_data[blob_index].read().unwrap().desc().clone();
                     layer.input_blobs_data[blob_index] = inp.clone();
+                    // reshape input tensor to the reshaped shape
+                    let old_shape = layer.input_blobs_data[blob_index].read().unwrap().desc().clone();
+                    if old_shape.size() != reshaped_shape.size() {
+                        panic!("The provided input does not have the expected shape");
+                    }
+                    layer.input_blobs_data[blob_index].write().unwrap().reshape(&reshaped_shape).unwrap();
                 }
             }
         }
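
The fix remembers the shape the layer was configured with, swaps in the caller's input blob, checks that the element counts agree, and then views the provided data with the expected shape. A standalone sketch of that size check, using an illustrative helper rather than Leaf's SharedTensor API:

// Standalone sketch of the check this commit enforces: a provided input may use
// a different shape than the layer expects, as long as the element counts match,
// in which case it is reinterpreted with the expected shape.
fn check_and_reshape(expected: &[usize], provided: &[usize]) -> Vec<usize> {
    let expected_size: usize = expected.iter().product();
    let provided_size: usize = provided.iter().product();
    if provided_size != expected_size {
        panic!("The provided input does not have the expected shape");
    }
    expected.to_vec()
}

fn main() {
    // A [1, 28, 28] input can stand in for an expected [1, 784] (both hold 784 values) ...
    assert_eq!(check_and_reshape(&[1, 784], &[1, 28, 28]), vec![1, 784]);
    // ... while something like [1, 100] would panic with the message above.
}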
