Skip to content

Commit a7f8a69

Browse files
committed
fix/test: fix tests after adding collenchyma
1 parent 7734e5d commit a7f8a69

8 files changed

+84
-17
lines changed

.travis.yml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ before_script:
2121
script:
2222
- |
2323
travis-cargo build &&
24+
travis-cargo test &&
2425
travis-cargo bench &&
2526
travis-cargo doc
2627
addons:

src/layer.rs

+22-4
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,21 @@ use std::sync::{RwLockReadGuard, RwLockWriteGuard};
3131
/// ```
3232
/// extern crate phloem;
3333
/// # extern crate leaf;
34+
/// # extern crate collenchyma as co;
3435
/// use phloem::Blob;
3536
/// use std::sync::{RwLock, RwLockReadGuard};
3637
/// # use leaf::layer::ReadBlob;
38+
/// # use co::backend::{Backend, BackendConfig};
39+
/// # use co::frameworks::Native;
40+
/// # use co::framework::IFramework;
41+
/// # use std::rc::Rc;
3742
///
3843
/// # fn main() {
39-
/// let lock = RwLock::new(Box::new(Blob::<f32>::of_shape(vec![3])));
44+
/// # let framework = Native::new();
45+
/// # let hardwares = framework.hardwares();
46+
/// # let backend_config = BackendConfig::new(framework, hardwares);
47+
/// # let backend = Rc::new(Backend::new(backend_config).unwrap());
48+
/// let lock = RwLock::new(Box::new(Blob::<f32>::of_shape(Some(backend.device()), &[3, 2, 3])));
4049
/// let read_blob: ReadBlob = lock.read().unwrap();
4150
/// # }
4251
/// ```
@@ -60,12 +69,21 @@ pub type ReadBlob<'_> = RwLockReadGuard<'_, HeapBlob>;
6069
/// ```
6170
/// extern crate phloem;
6271
/// # extern crate leaf;
72+
/// # extern crate collenchyma as co;
6373
/// use phloem::Blob;
6474
/// use std::sync::{RwLock, RwLockWriteGuard};
6575
/// # use leaf::layer::WriteBlob;
76+
/// # use co::backend::{Backend, BackendConfig};
77+
/// # use co::frameworks::Native;
78+
/// # use co::framework::IFramework;
79+
/// # use std::rc::Rc;
6680
///
6781
/// # fn main() {
68-
/// let lock = RwLock::new(Box::new(Blob::<f32>::of_shape(vec![3])));
82+
/// # let framework = Native::new();
83+
/// # let hardwares = framework.hardwares();
84+
/// # let backend_config = BackendConfig::new(framework, hardwares);
85+
/// # let backend = Rc::new(Backend::new(backend_config).unwrap());
86+
/// let lock = RwLock::new(Box::new(Blob::<f32>::of_shape(Some(backend.device()), &[4, 2, 1])));
6987
/// let read_blob: WriteBlob = lock.write().unwrap();
7088
/// # }
7189
/// ```
@@ -455,7 +473,7 @@ pub trait ILayer {
455473
/// [2]: ./type.ReadBlob.html
456474
/// [3]: ./type.WriteBlob.html
457475
/// [3]: #method.forward_cpu
458-
#[allow(map_clone)]
476+
#[cfg_attr(lint, allow(map_clone))]
459477
fn forward(&self, bottom: &[ArcLock<HeapBlob>], top: &mut Vec<ArcLock<HeapBlob>>) -> f32 {
460478
// Lock();
461479
// Reshape(bottom, top); // Reshape the layer to fit top & bottom blob
@@ -495,7 +513,7 @@ pub trait ILayer {
495513
/// [2]: ./type.ReadBlob.html
496514
/// [3]: ./type.WriteBlob.html
497515
/// [3]: #method.backward_cpu
498-
#[allow(map_clone)]
516+
#[cfg_attr(lint, allow(map_clone))]
499517
fn backward(&self, top: &[ArcLock<HeapBlob>], propagate_down: &[bool], bottom: &mut Vec<ArcLock<HeapBlob>>) {
500518
let tp: Vec<_> = top.iter().map(|b| b.read().unwrap()).collect();
501519
let bt_ref = bottom.iter().cloned().collect::<Vec<_>>();

src/network.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,25 @@ impl<B: IBackend + IBlas<f32>> Network<B> {
120120
/// ## Examples
121121
///
122122
/// ```
123+
/// # extern crate collenchyma;
124+
/// # extern crate leaf;
125+
///
123126
/// # use leaf::network::*;
127+
/// # use collenchyma::backend::{Backend, BackendConfig};
128+
/// # use collenchyma::frameworks::Native;
129+
/// # use collenchyma::framework::IFramework;
130+
/// # use std::rc::Rc;
131+
///
132+
/// # fn main() {
133+
/// // create backend
134+
/// let framework = Native::new();
135+
/// let hardwares = framework.hardwares();
136+
/// let backend_config = BackendConfig::new(framework, hardwares);
137+
/// let backend = Rc::new(Backend::new(backend_config).unwrap());
138+
/// // create network
124139
/// let cfg = NetworkConfig::default();
125-
/// Network::from_config(&cfg);
140+
/// Network::from_config(backend, &cfg);
141+
/// # }
126142
/// ```
127143
pub fn from_config(backend: Rc<B>, param: &NetworkConfig) -> Network<B> {
128144
let mut network = Network::default();
@@ -250,7 +266,7 @@ impl<B: IBackend + IBlas<f32>> Network<B> {
250266
/// Used during initialization of the Network.
251267
/// [1]: ../layer/struct.Layer.html
252268
/// [2]: ../layer/struct.Layer.html#method.connect
253-
#[allow(ptr_arg)]
269+
#[cfg_attr(lint, allow(ptr_arg))]
254270
fn init_input_blob(&mut self,
255271
blob_name: &str,
256272
input_shape: &Vec<usize>,

src/solver.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,18 @@ impl<S, B: IBackend + IBlas<f32>> Solver<S, B> {
3131
/// ## Example
3232
///
3333
/// ```
34+
/// # extern crate leaf;
35+
/// # extern crate collenchyma;
3436
/// # use leaf::solver::*;
37+
/// # use collenchyma::backend::Backend;
38+
/// # use collenchyma::frameworks::Native;
39+
///
40+
/// # fn main() {
3541
/// let cfg = SolverConfig{
3642
/// solver: SolverKind::SGD(SGDKind::Momentum),
3743
/// ..SolverConfig::default()};
38-
/// let solver = Solver::<Box<ISolver>>::from_config(&cfg);
44+
/// let solver = Solver::<Box<ISolver<Backend<Native>>>, Backend<Native>>::from_config(&cfg);
45+
/// # }
3946
/// ```
4047
pub fn from_config(config: &SolverConfig) -> Solver<Box<ISolver<Backend<Native>>>, Backend<Native>> {
4148
let framework = Native::new();

src/solvers/sgd/momentum.rs

-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
//! into the same direction you will reach the optimum faster.
1414
//! It also makes solving more stable.
1515
use co::backend::*;
16-
use co::framework::*;
17-
use co::frameworks::*;
1816
use co::libraries::blas::IBlas;
1917
use shared_memory::*;
2018
use network::Network;

tests/layer_specs.rs

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
extern crate leaf;
22
extern crate phloem;
3+
extern crate collenchyma as co;
34

45
#[cfg(test)]
56
mod layer_spec {
67

78
use leaf::layer::*;
89
use phloem::Blob;
10+
use std::rc::Rc;
11+
use co::backend::{Backend, BackendConfig};
12+
use co::frameworks::Native;
13+
use co::framework::IFramework;
914

1015
fn new_layer_config() -> LayerConfig {
1116
LayerConfig::new("foo".to_owned(), LayerType::Sigmoid)
1217
}
1318

19+
fn backend() -> Rc<Backend<Native>> {
20+
let framework = Native::new();
21+
let hardwares = framework.hardwares();
22+
let backend_config = BackendConfig::new(framework, hardwares);
23+
Rc::new(Backend::new(backend_config).unwrap())
24+
}
25+
1426
#[test]
1527
fn new_layer() {
1628
let cfg = new_layer_config();
17-
Layer::from_config(&cfg);
29+
Layer::from_config(backend(), &cfg);
1830
}
1931

2032
#[test]
2133
fn dim_check_strict() {
2234
let cfg = WeightConfig { share_mode: DimCheckMode::Strict, ..WeightConfig::default() };
23-
let blob_one = Blob::<f32>::of_shape(vec![2, 3, 3]);
24-
let blob_two = Blob::<f32>::of_shape(vec![3, 2, 3]);
35+
let blob_one = Blob::<f32>::of_shape(Some(backend().device()), &[2, 3, 3]);
36+
let blob_two = Blob::<f32>::of_shape(Some(backend().device()), &[3, 2, 3]);
2537
let param_name = "foo".to_owned();
2638
let owner_name = "owner".to_owned();
2739
let layer_name = "layer".to_owned();
@@ -43,9 +55,9 @@ mod layer_spec {
4355
#[test]
4456
fn dim_check_permissive() {
4557
let cfg = WeightConfig { share_mode: DimCheckMode::Permissive, ..WeightConfig::default() };
46-
let blob_one = Blob::<f32>::of_shape(vec![2, 3, 3]);
47-
let blob_two = Blob::<f32>::of_shape(vec![3, 2, 3]);
48-
let blob_three = Blob::<f32>::of_shape(vec![3, 10, 3]);
58+
let blob_one = Blob::<f32>::of_shape(Some(backend().device()), &[2, 3, 3]);
59+
let blob_two = Blob::<f32>::of_shape(Some(backend().device()), &[3, 2, 3]);
60+
let blob_three = Blob::<f32>::of_shape(Some(backend().device()), &[3, 10, 3]);
4961
let param_name = "foo".to_owned();
5062
let owner_name = "owner".to_owned();
5163
let layer_name = "layer".to_owned();

tests/network_specs.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
extern crate leaf;
22
extern crate phloem;
3+
extern crate collenchyma as co;
34

45
#[cfg(test)]
56
mod network_spec {
7+
use std::rc::Rc;
8+
use co::backend::{Backend, BackendConfig};
9+
use co::framework::IFramework;
10+
use co::frameworks::Native;
611
use leaf::network::*;
712

13+
fn backend() -> Rc<Backend<Native>> {
14+
let framework = Native::new();
15+
let hardwares = framework.hardwares();
16+
let backend_config = BackendConfig::new(framework, hardwares);
17+
Rc::new(Backend::new(backend_config).unwrap())
18+
}
19+
820
#[test]
921
fn new_layer() {
1022
let cfg = NetworkConfig::default();
11-
Network::from_config(&cfg);
23+
Network::from_config(backend(), &cfg);
1224
}
1325
}

tests/solver_specs.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
extern crate leaf;
2+
extern crate collenchyma as co;
23

34
#[cfg(test)]
45
mod network_spec {
56
use leaf::solver::*;
7+
use co::backend::Backend;
8+
use co::frameworks::Native;
69

710
#[test]
811
// fixed: always return base_lr.
@@ -40,6 +43,6 @@ mod network_spec {
4043
#[test]
4144
fn instantiate_solver_sgd_momentum() {
4245
let cfg = SolverConfig{ solver: SolverKind::SGD(SGDKind::Momentum), ..SolverConfig::default()};
43-
Solver::<Box<ISolver>>::from_config(&cfg);
46+
Solver::<Box<ISolver<Backend<Native>>>, Backend<Native>>::from_config(&cfg);
4447
}
4548
}

0 commit comments

Comments
 (0)