
Commit 4a21001

fix/benches: fix cargo bench compilation [SKIP_CHANGELOG]
Looks like the benchmarks are superseded by examples/benchmarks.rs and should be removed altogether, but while they are here they should at least compile cleanly. For now the benches compile, but panic on a tensor dimension mismatch.
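In short, the benches are migrated off the removed Network API and onto the Layer/Sequential API. A minimal sketch of the new construction and forward-pass pattern, assembled only from calls that appear in this diff (the actual layer setup used in the benches is elided, and this is not verified against the leaf API of that revision):

```rust
use std::rc::Rc;
use std::sync::{Arc, RwLock};
use leaf::layer::*;
use leaf::layers::*;

#[cfg(feature = "cuda")]
fn cuda_backend() -> Rc<Backend<Cuda>> {
    Rc::new(Backend::<Cuda>::default().unwrap())
}

#[cfg(feature = "cuda")]
fn forward_once() {
    let backend = cuda_backend();

    // Previously: Network::from_config(backend.clone(), &cfg) built from a
    // NetworkConfig, and forward() took a `loss` out-parameter.
    let mut cfg = SequentialConfig::default();
    cfg.add_input("in", &vec![1, 30, 30]);
    // ... add layers to cfg here ...

    // Now: wrap the SequentialConfig in a LayerConfig and build a Layer.
    let mut network = Layer::from_config(
        backend.clone(),
        &LayerConfig::new("network", LayerType::Sequential(cfg)));

    // forward() now takes only the input tensors; no loss argument.
    let inp = SharedTensor::<f32>::new(backend.device(), &vec![1, 30, 30]).unwrap();
    let inp_lock = Arc::new(RwLock::new(inp));
    network.forward(&[inp_lock]);
}
```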
1 parent 432e33c commit 4a21001


benches/network_benches.rs (+12 −22)
@@ -14,14 +14,8 @@ mod cuda {
     use std::sync::{Arc, RwLock};
     use leaf::layers::*;
     use leaf::layer::*;
-    use leaf::network::*;
     use std::rc::Rc;
 
-    #[cfg(feature = "native")]
-    fn native_backend() -> Rc<Backend<Native>> {
-        Rc::new(Backend::<Native>::default().unwrap())
-    }
-
     #[cfg(feature = "cuda")]
     fn cuda_backend() -> Rc<Backend<Cuda>> {
         Rc::new(Backend::<Cuda>::default().unwrap())
@@ -76,7 +70,7 @@ mod cuda {
     #[ignore]
     #[cfg(feature = "cuda")]
     fn bench_mnsit_forward_1(b: &mut Bencher) {
-        let mut cfg = NetworkConfig::default();
+        let mut cfg = SequentialConfig::default();
         // set up input
         cfg.add_input("in", &vec![1, 30, 30]);
         cfg.add_input("label", &vec![1, 1, 10]);
@@ -98,18 +92,14 @@ mod cuda {
         // cfg.add_layer(loss_cfg);
 
         let backend = cuda_backend();
-        let native_backend = native_backend();
-        let mut network = Network::from_config(backend.clone(), &cfg);
-        let loss = &mut 0f32;
+        let mut network = Layer::from_config(
+            backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));
 
         let _ = timeit_loops!(10, {
             let inp = SharedTensor::<f32>::new(backend.device(), &vec![1, 30, 30]).unwrap();
-            let label = SharedTensor::<f32>::new(native_backend.device(), &vec![1, 1, 10]).unwrap();
-
             let inp_lock = Arc::new(RwLock::new(inp));
-            let label_lock = Arc::new(RwLock::new(label));
 
-            network.forward(&[inp_lock, label_lock], loss);
+            network.forward(&[inp_lock]);
         });
         // b.iter(|| {
         //     for _ in 0..1 {
@@ -128,7 +118,7 @@ mod cuda {
     // #[ignore]
     #[cfg(feature = "cuda")]
     fn alexnet_forward(b: &mut Bencher) {
-        let mut cfg = NetworkConfig::default();
+        let mut cfg = SequentialConfig::default();
         // Layer: data
         cfg.add_input("data", &vec![128, 3, 224, 224]);
         // Layer: conv1
@@ -265,15 +255,15 @@ mod cuda {
 
         let backend = cuda_backend();
         // let native_backend = native_backend();
-        let mut network = Network::from_config(backend.clone(), &cfg);
+        let mut network = Layer::from_config(
+            backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));
 
         let func = || {
             let forward_time = timeit_loops!(1, {
-                let loss = &mut 0f32;
                 let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
 
                 let inp_lock = Arc::new(RwLock::new(inp));
-                network.forward(&[inp_lock], loss);
+                network.forward(&[inp_lock]);
             });
             println!("Forward step: {}", forward_time);
         };
@@ -285,7 +275,7 @@ mod cuda {
     #[cfg(feature = "cuda")]
     fn small_alexnet_forward(b: &mut Bencher) {
         // let _ = env_logger::init();
-        let mut cfg = NetworkConfig::default();
+        let mut cfg = SequentialConfig::default();
         // Layer: data
         cfg.add_input("data", &vec![128, 3, 112, 112]);
         // Layer: conv1
@@ -422,14 +412,14 @@ mod cuda {
 
         let backend = cuda_backend();
         // let native_backend = native_backend();
-        let mut network = Network::from_config(backend.clone(), &cfg);
+        let mut network = Layer::from_config(
+            backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));
 
         let mut func = || {
-            let loss = &mut 0f32;
            let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
 
            let inp_lock = Arc::new(RwLock::new(inp));
-           network.forward(&[inp_lock], loss);
+           network.forward(&[inp_lock]);
         };
         { func(); bench_profile(b, func, 10); }
     }
