@@ -7,9 +7,18 @@ use util::{ArcLock, native_backend};
7
7
#[ derive( Debug , Clone ) ]
8
8
#[ allow( missing_copy_implementations) ]
9
9
/// NegativeLogLikelihood Loss Layer
10
- pub struct NegativeLogLikelihood ;
10
+ pub struct NegativeLogLikelihood {
11
+ num_classes : usize ,
12
+ }
11
13
12
14
impl NegativeLogLikelihood {
15
+ /// Create a NegativeLogLikelihood layer from a NegativeLogLikelihoodConfig.
16
+ pub fn from_config ( config : & NegativeLogLikelihoodConfig ) -> NegativeLogLikelihood {
17
+ NegativeLogLikelihood {
18
+ num_classes : config. num_classes ,
19
+ }
20
+ }
21
+
13
22
fn calculate_outer_num ( softmax_axis : usize , input_shape : & [ usize ] ) -> usize {
14
23
input_shape. iter ( ) . take ( softmax_axis + 1 ) . fold ( 1 , |prod, i| prod * i)
15
24
}
@@ -25,14 +34,6 @@ impl NegativeLogLikelihood {
25
34
_ => panic ! ( "NegativeLogLikelihood layer only supports 1D/2D inputs" )
26
35
}
27
36
}
28
-
29
- fn num_classes ( input_shape : & [ usize ] ) -> usize {
30
- match input_shape. len ( ) {
31
- 1 => input_shape[ 0 ] ,
32
- 2 => input_shape[ 1 ] ,
33
- _ => panic ! ( "NegativeLogLikelihood layer only supports 1D/2D inputs" ) ,
34
- }
35
- }
36
37
}
37
38
38
39
impl < B : IBackend > ILayer < B > for NegativeLogLikelihood {
@@ -97,7 +98,7 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
97
98
input_gradients : & mut [ & mut SharedTensor < f32 > ] ) {
98
99
let labels = input_data[ 1 ] ;
99
100
let batch_size = Self :: batch_size ( input_data[ 0 ] . desc ( ) ) ;
100
- let num_classes = Self :: num_classes ( input_data [ 0 ] . desc ( ) ) ;
101
+ let num_classes = self . num_classes ;
101
102
102
103
let native = native_backend ( ) ;
103
104
let native_labels = labels. get ( native. device ( ) ) . unwrap ( ) . as_native ( ) . unwrap ( ) . as_slice :: < f32 > ( ) ;
@@ -114,8 +115,10 @@ impl<B: IBackend> ComputeInputGradient<f32, B> for NegativeLogLikelihood {
114
115
115
116
impl < B : IBackend > ComputeParametersGradient < f32 , B > for NegativeLogLikelihood { }
116
117
117
- impl :: std:: default:: Default for NegativeLogLikelihood {
118
- fn default ( ) -> NegativeLogLikelihood {
119
- NegativeLogLikelihood
120
- }
118
+ #[ derive( Debug , Clone ) ]
119
+ #[ allow( missing_copy_implementations) ]
120
+ /// Specifies configuration parameters for a NegativeLogLikelihood Layer.
121
+ pub struct NegativeLogLikelihoodConfig {
122
+ /// How many different classes can be classified.
123
+ pub num_classes : usize ,
121
124
}
0 commit comments