
Commit 83db20d

feature/solver: implement solver and sgd

Set up structure for solvers; documentation is still lacking.

1 parent: 5dc879e

10 files changed: +408 -100 lines

src/layer.rs (+9 -8)

@@ -68,9 +68,9 @@ pub type WriteBlob<'_> = RwLockWriteGuard<'_, HeapBlob>;
 
 #[derive(Debug)]
 /// The generic Layer
-pub struct Layer<'a> {
+pub struct Layer {
     /// The configuration of the Layer
-    pub config: Box<&'a LayerConfig>,
+    pub config: Box<LayerConfig>,
     /// The [implementation][1] of the Layer.
     /// [1]: ../layers/index.html
     ///
@@ -97,16 +97,16 @@ pub struct Layer<'a> {
     weight_propagate_down: Vec<bool>,
 }
 
-impl<'a> Layer<'a> {
+impl Layer {
     /// Creates a new Layer from a [LayerConfig][1].
     /// [1]: ./struct.LayerConfig.html
     ///
     /// Used during [Network][2] initalization.
     ///
     /// [2]: ../network/struct.Network.html
-    pub fn from_config(config: &'a LayerConfig) -> Layer {
+    pub fn from_config(config: &LayerConfig) -> Layer {
         let cl = config.clone();
-        let cfg = Box::<&'a LayerConfig>::new(cl);
+        let cfg = Box::<LayerConfig>::new(cl);
         Layer {
             loss: Vec::new(),
             blobs: Vec::new(),
@@ -171,14 +171,15 @@ pub trait ILayer {
     /// [2]: ./type.ReadBlob.html
     /// [3]: ./type.WriteBlob.html
     /// [3]: #method.forward_cpu
+    #[allow(map_clone)]
     fn forward(&self, bottom: &[ArcLock<HeapBlob>], top: &mut Vec<ArcLock<HeapBlob>>) -> f32 {
         // Lock();
         // Reshape(bottom, top); // Reshape the layer to fit top & bottom blob
         let mut loss = 0f32;
 
         let btm: Vec<_> = bottom.iter().map(|b| b.read().unwrap()).collect();
         // let tp: Vec<_> = top.iter().map(|b| b.write().unwrap()).collect();
-        let tp_ref = top.iter().map(|t| t.clone()).collect::<Vec<_>>();
+        let tp_ref = top.iter().cloned().collect::<Vec<_>>();
         let mut tp = &mut tp_ref.iter().map(|b| b.write().unwrap()).collect::<Vec<_>>();
         let mut tpo = &mut tp.iter_mut().map(|a| a).collect::<Vec<_>>();
         self.forward_cpu(&btm, tpo);
@@ -249,7 +250,7 @@ impl fmt::Debug for ILayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Layer Configuration Struct
 pub struct LayerConfig {
     /// The name of the Layer
@@ -331,7 +332,7 @@ impl LayerConfig {
 }
 
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Specifies training configuration for a weight blob.
 pub struct WeightConfig {
     /// The name of the weight blob -- useful for sharing weights among
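This change removes the `'a` lifetime from `Layer` by deriving `Clone` on `LayerConfig`: `from_config` clones the borrowed config into an owned `Box<LayerConfig>`, so a `Layer` no longer borrows from its creator. A minimal sketch of what that buys (the `load_config` helper is hypothetical, not part of this commit):

// Sketch: a Layer now owns its configuration, so it can outlive
// the borrow it was created from.
let layer = {
    let cfg: LayerConfig = load_config(); // hypothetical helper
    Layer::from_config(&cfg)
}; // `cfg` drops here; `layer` keeps its own boxed clone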

src/layers/mod.rs (+4)

@@ -51,6 +51,10 @@
 //!
 //! [2]: https://en.wikipedia.org/wiki/Activation_function
 //! [3]: ../layer/index.html
+
+/// Implement [ILayer][1] for [activation layers][2].
+/// [1]: ./layer/trait.ILayer.html
+/// [2]: ./layers/activation/index.html
 macro_rules! impl_neuron_layer {
     () => (
         fn exact_num_top_blobs(&self) -> usize { 1 }
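The macro is meant to be expanded inside an `ILayer` impl to paste in the boilerplate shared by activation (neuron) layers; this hunk shows only the first generated method. A hedged sketch of the intended call site (the `Sigmoid` type and any further generated methods are assumptions, not shown in this diff):

// Hypothetical activation layer; impl_neuron_layer!() supplies the
// shared ILayer defaults such as exact_num_top_blobs.
impl ILayer for Sigmoid {
    impl_neuron_layer!();
}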

src/math.rs (+12 -4)

@@ -1,10 +1,18 @@
-use rblas::Axpy;
-use rblas::Dot;
+use rblas::*;
+
+pub fn leaf_cpu_axpy(alpha: &f32, x: &[f32], y: &mut Vec<f32>) {
+    Axpy::axpy(alpha, x, y);
+}
+
+pub fn leaf_cpu_axpby(alpha: &f32, x: &[f32], beta: &f32, y: &mut Vec<f32>) {
+    leaf_cpu_scal(beta, y);
+    leaf_cpu_axpy(alpha, x, y);
+}
 
 pub fn leaf_cpu_dot(x: &[f32], y: &[f32]) -> f32 {
     Dot::dot(x, y)
 }
 
-pub fn leaf_cpu_axpy(alpha: &f32, x: &[f32], y: &mut Vec<f32>) {
-    Axpy::axpy(alpha, x, y);
+pub fn leaf_cpu_scal(alpha: &f32, x: &mut Vec<f32>) {
+    Scal::scal(alpha, x)
 }
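Read together, these helpers cover the BLAS level-1 kernels an SGD solver needs: `leaf_cpu_axpy` computes y <- alpha*x + y, `leaf_cpu_scal` computes x <- alpha*x, and `leaf_cpu_axpby` composes them into y <- alpha*x + beta*y. A minimal sketch of a weight update built from them (the values here are illustrative):

// w <- w - lr * grad, i.e. axpy with alpha = -lr.
let mut w = vec![0.5f32, -0.2, 1.0]; // weights
let grad = vec![0.1f32, 0.4, -0.3];  // gradients
let lr = 0.01f32;
leaf_cpu_axpy(&(-lr), &grad, &mut w);

// A blended update h <- lr * grad + momentum * h,
// e.g. for a momentum-style history buffer.
let mut hist = vec![0.0f32; 3];
leaf_cpu_axpby(&lr, &grad, &0.9, &mut hist);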

src/network.rs (+21 -11)

@@ -62,12 +62,12 @@ use phloem::Blob;
 /// A Network is usually used together with a [Solver][6] to optimize the networks' weights.
 ///
 /// [6]: ../solver/struct.Solver.html
-pub struct Network<'a> {
+pub struct Network {
     /// Identifies the Network
     ///
     /// The name is mainly used for logging purposes.
     pub name: String,
-    layers: Vec<Layer<'a>>,
+    layers: Vec<Layer>,
     layer_names: Vec<String>,
     layer_names_index: HashMap<String, usize>,
     layer_need_backwards: Vec<bool>,
@@ -114,8 +114,8 @@ pub struct Network<'a> {
     weights_weight_decay: Vec<Option<f32>>,
 }
 
-impl<'a> Default for Network<'a> {
-    fn default() -> Network<'a> {
+impl Default for Network {
+    fn default() -> Network {
         Network {
             name: "".to_owned(),
             layers: vec![],
@@ -159,7 +159,7 @@ impl<'a> Default for Network<'a> {
     }
 }
 
-impl<'a> Network<'a> {
+impl Network {
     /// Creates a Network from a [NetworkConfig][1].
     /// [1]: ./struct.NetworkConfig.html
     ///
@@ -183,12 +183,12 @@ impl<'a> Network<'a> {
     /// to be executed for each blob and layer.
     ///
     /// [1]: ./struct.NetworkConfig.html
-    fn init(&mut self, in_config: &'a NetworkConfig) {
+    fn init(&mut self, in_config: &NetworkConfig) {
         let config = in_config.clone();
         let available_blobs = &mut HashSet::new();
         let blob_name_to_idx = &mut HashMap::<String, usize>::new();
         for (input_id, _) in config.inputs.iter().enumerate() {
-            self.append_top(config,
+            self.append_top(&config,
                             None,
                             input_id,
                             Some(available_blobs),
@@ -198,7 +198,7 @@ impl<'a> Network<'a> {
         self.resize_vecs(config.layers.len());
 
         for (layer_id, _) in config.inputs.iter().enumerate() {
-            self.init_layer(layer_id, config, available_blobs, blob_name_to_idx);
+            self.init_layer(layer_id, &config, available_blobs, blob_name_to_idx);
         }
 
         // Go through the net backwards to determine which blobs contribute to the
@@ -259,7 +259,7 @@ impl<'a> Network<'a> {
     /// [4]: ../layers/index.html
     fn init_layer(&mut self,
                   layer_id: usize,
-                  config: &'a NetworkConfig,
+                  config: &NetworkConfig,
                   available_blobs: &mut HashSet<String>,
                   blob_name_to_idx: &mut HashMap<String, usize>) {
         // Caffe
@@ -868,9 +868,19 @@ impl<'a> Network<'a> {
     pub fn learnable_weights(&self) -> &Vec<ArcLock<HeapBlob>> {
         &self.learnable_weights
     }
+
+    #[allow(missing_docs)]
+    pub fn weights_weight_decay(&self) -> &Vec<Option<f32>> {
+        &self.weights_weight_decay
+    }
+
+    #[allow(missing_docs)]
+    pub fn weights_lr(&self) -> &Vec<Option<f32>> {
+        &self.weights_lr
+    }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Defines the configuration of a network.
 ///
 /// TODO: [DOC] When and why would you use this?
@@ -959,7 +969,7 @@ impl NetworkConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// Defines the state of a network.
 pub struct NetworkState {
     /// Defines the current mode of the network.
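The two new accessors exist so a solver can scale its global hyperparameters per weight blob. A hedged sketch of how an SGD step might consume them; only `learnable_weights`, `weights_lr`, and `weights_weight_decay` are confirmed by this diff, while `global_lr`, `global_decay`, and the update itself are assumptions about the solver code, which this page does not show:

// Hypothetical per-weight SGD scaling inside a solver.
let (global_lr, global_decay) = (0.01f32, 0.0005f32);
for (i, blob) in net.learnable_weights().iter().enumerate() {
    let local_lr = global_lr * net.weights_lr()[i].unwrap_or(1f32);
    let local_decay = global_decay * net.weights_weight_decay()[i].unwrap_or(1f32);
    let mut weights = blob.write().unwrap(); // ArcLock<HeapBlob> is an RwLock
    // w <- w - local_lr * (grad + local_decay * w), e.g. via leaf_cpu_axpy
}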
