Skip to content

Commit 3456877

Browse files
committed
fix/nll: add NLLConfig to specify number of classes
1 parent 79f7109 commit 3456877

File tree

4 files changed

+21
-18
lines changed

4 files changed

+21
-18
lines changed

src/layer.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ impl<B: IBackend + LayerOps<f32> + 'static> Layer<B> {
133133
LayerType::Softmax => Box::new(Softmax::default()),
134134
LayerType::ReLU => Box::new(ReLU),
135135
LayerType::Sigmoid => Box::new(Sigmoid),
136-
LayerType::NegativeLogLikelihood => Box::new(NegativeLogLikelihood::default()),
136+
LayerType::NegativeLogLikelihood(layer_config) => Box::new(NegativeLogLikelihood::from_config(&layer_config)),
137137
LayerType::Reshape(layer_config) => Box::new(Reshape::from_config(&layer_config)),
138138
}
139139
}
@@ -953,7 +953,7 @@ pub enum LayerType {
953953
Sigmoid,
954954
// Loss layers
955955
/// NegativeLogLikelihood Layer
956-
NegativeLogLikelihood,
956+
NegativeLogLikelihood(NegativeLogLikelihoodConfig),
957957
// Utility layers
958958
/// Reshape Layer
959959
Reshape(ReshapeConfig),

src/layers/loss/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ macro_rules! impl_ilayer_loss {
1818
)
1919
}
2020

21-
pub use self::negative_log_likelihood::NegativeLogLikelihood;
21+
pub use self::negative_log_likelihood::{NegativeLogLikelihood, NegativeLogLikelihoodConfig};
2222

2323
pub mod negative_log_likelihood;

src/layers/loss/negative_log_likelihood.rs

+17-14
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,18 @@ use util::{ArcLock, native_backend};
77
#[derive(Debug, Clone)]
88
#[allow(missing_copy_implementations)]
99
/// NegativeLogLikelihood Loss Layer
10-
pub struct NegativeLogLikelihood;
10+
pub struct NegativeLogLikelihood {
11+
num_classes: usize,
12+
}
1113

1214
impl NegativeLogLikelihood {
15+
/// Create a NegativeLogLikelihood layer from a NegativeLogLikelihoodConfig.
16+
pub fn from_config(config: &NegativeLogLikelihoodConfig) -> NegativeLogLikelihood {
17+
NegativeLogLikelihood {
18+
num_classes: config.num_classes,
19+
}
20+
}
21+
1322
fn calculate_outer_num(softmax_axis: usize, input_shape: &[usize]) -> usize {
1423
input_shape.iter().take(softmax_axis + 1).fold(1, |prod, i| prod * i)
1524
}
@@ -25,14 +34,6 @@ impl NegativeLogLikelihood {
2534
_ => panic!("NegativeLogLikelihood layer only supports 1D/2D inputs")
2635
}
2736
}
28-
29-
fn num_classes(input_shape: &[usize]) -> usize {
30-
match input_shape.len() {
31-
1 => input_shape[0],
32-
2 => input_shape[1],
33-
_ => panic!("NegativeLogLikelihood layer only supports 1D/2D inputs"),
34-
}
35-
}
3637
}
3738

3839
impl<B: IBackend> ILayer<B> for NegativeLogLikelihood {
@@ -97,7 +98,7 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
9798
input_gradients: &mut [&mut SharedTensor<f32>]) {
9899
let labels = input_data[1];
99100
let batch_size = Self::batch_size(input_data[0].desc());
100-
let num_classes = Self::num_classes(input_data[0].desc());
101+
let num_classes = self.num_classes;
101102

102103
let native = native_backend();
103104
let native_labels = labels.get(native.device()).unwrap().as_native().unwrap().as_slice::<f32>();
@@ -114,8 +115,10 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
114115

115116
impl<B: IBackend> ComputeParametersGradient<f32, B> for NegativeLogLikelihood { }
116117

117-
impl ::std::default::Default for NegativeLogLikelihood {
118-
fn default() -> NegativeLogLikelihood {
119-
NegativeLogLikelihood
120-
}
118+
#[derive(Debug, Clone)]
119+
#[allow(missing_copy_implementations)]
120+
/// Specifies configuration parameters for a NegativeLogLikelihood Layer.
121+
pub struct NegativeLogLikelihoodConfig {
122+
/// How many different classes can be classified.
123+
pub num_classes: usize,
121124
}

src/layers/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub use self::common::{
6464

6565
#[allow(unused_import_braces)]
6666
pub use self::loss::{
67-
NegativeLogLikelihood,
67+
NegativeLogLikelihood, NegativeLogLikelihoodConfig,
6868
};
6969

7070
#[allow(unused_import_braces)]

0 commit comments

Comments
 (0)