Commit f506a2c
refactor/tests: convert tests and benches to the new memory access API [SKIP_CHANGELOG]
REFERENCE: autumnai/collenchyma#37, autumnai/collenchyma#62
Parent: e20fc95

3 files changed: +25 -24 lines

benches/network_benches.rs (+4 -4)

@@ -69,7 +69,7 @@ mod cuda {
     #[bench]
     #[ignore]
     #[cfg(feature = "cuda")]
-    fn bench_mnsit_forward_1(b: &mut Bencher) {
+    fn bench_mnsit_forward_1(_b: &mut Bencher) {
         let mut cfg = SequentialConfig::default();
         // set up input
         cfg.add_input("in", &vec![1, 30, 30]);
@@ -96,7 +96,7 @@ mod cuda {
             backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));

         let _ = timeit_loops!(10, {
-            let inp = SharedTensor::<f32>::new(backend.device(), &vec![1, 30, 30]).unwrap();
+            let inp = SharedTensor::<f32>::new(&[1, 30, 30]);
             let inp_lock = Arc::new(RwLock::new(inp));

             network.forward(&[inp_lock]);
@@ -260,7 +260,7 @@ mod cuda {

         let func = || {
             let forward_time = timeit_loops!(1, {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
+                let inp = SharedTensor::new(&[128, 3, 112, 112]);

                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock]);
@@ -416,7 +416,7 @@ mod cuda {
             backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));

         let mut func = || {
-            let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
+            let inp = SharedTensor::<f32>::new(&[128, 3, 112, 112]);

             let inp_lock = Arc::new(RwLock::new(inp));
             network.forward(&[inp_lock]);
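The pattern is the same at every call site in this file: the SharedTensor constructor no longer takes a device and no longer returns a Result; it only takes the shape, and a device only comes into play once the tensor is accessed through the new read / write_only methods. Below is a minimal standalone before/after sketch. The `co` alias and `co::prelude::*` import are conventions borrowed from collenchyma's own examples, not from this repository.

extern crate collenchyma as co;

use co::prelude::*;

fn main() {
    // Old API (the removed lines above): device-bound and fallible.
    //     let inp = SharedTensor::<f32>::new(backend.device(), &vec![1, 30, 30]).unwrap();
    //
    // New API: shape-only and infallible. Device memory is only requested
    // later, when the tensor is read or written on a concrete device.
    let _inp = SharedTensor::<f32>::new(&[1, 30, 30]);
}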

examples/benchmarks.rs (+3 -5)

@@ -160,8 +160,7 @@ fn bench_alexnet() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 224, 224]).unwrap();
-
+                let inp = SharedTensor::<f32>::new(&[128, 3, 224, 224]);
                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
             }
@@ -242,8 +241,7 @@ fn bench_overfeat() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 231, 231]).unwrap();
-
+                let inp = SharedTensor::new(&[128, 3, 231, 231]);
                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
             }
@@ -339,7 +337,7 @@ fn bench_vgg_a() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![64, 3, 224, 224]).unwrap();
+                let inp = SharedTensor::new(&[64, 3, 224, 224]);

                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
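The examples follow the same input plumbing as the benches: construct the tensor from its shape, wrap it in Arc<RwLock<...>>, and hand that handle to the network inside the timed closure. Some call sites can also drop the ::<f32> turbofish, since the element type is inferred from the surrounding code. Here is a small sketch of that plumbing in isolation; the forward call needs a constructed Layer, so it only appears in a comment, and the `co` prelude import is again an assumption from collenchyma's examples.

extern crate collenchyma as co;

use std::sync::{Arc, RwLock};
use co::prelude::*;

fn main() {
    // Shape-only construction, then shared ownership for the layer graph.
    let inp = SharedTensor::<f32>::new(&[128, 3, 224, 224]);
    let inp_lock = Arc::new(RwLock::new(inp));

    // In examples/benchmarks.rs this handle is what the timed closure passes on:
    //     network.forward(&[inp_lock.clone()]);
    let _ = inp_lock;
}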

tests/layer_specs.rs (+18 -15)

@@ -66,8 +66,10 @@ mod layer_spec {
         let loaded_weights = loaded_layer.learnable_weights_data();
         let loaded_weight_lock = loaded_weights[0].read().unwrap();

-        let original_weight = original_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
-        let loaded_weight = loaded_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
+        let original_weight = original_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();
+        let loaded_weight = loaded_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();

         assert_eq!(original_weight, loaded_weight);
     }
@@ -131,27 +133,28 @@ mod layer_spec {
         let mut reshape_network = Layer::from_config(cuda_backend.clone(), &LayerConfig::new("reshape_model", LayerType::Sequential(reshape_model)));

         let input = vec![1f32, 1f32, 2f32];
-        let mut normal_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut normal_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut normal_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        let mut reshape_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut reshape_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut reshape_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        write_to_memory(normal_tensor.get_mut(native_backend.device()).unwrap(), &input);
-        write_to_memory(reshape_tensor.get_mut(native_backend.device()).unwrap(), &input);
+        write_to_memory(normal_tensor.write_only(native_backend.device()).unwrap(), &input);
+        write_to_memory(reshape_tensor.write_only(native_backend.device()).unwrap(), &input);

         let normal_tensor_output = normal_network.forward(&[Arc::new(RwLock::new(normal_tensor))])[0].clone();
-        let _ = normal_tensor_output.write().unwrap().add_device(native_backend.device());
-        normal_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let normal_tensor_output_native_ = normal_tensor_output.read().unwrap();
-        let normal_tensor_output_native = normal_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], normal_tensor_output_native.as_slice::<f32>());
+        let normal_tensor_output_native = normal_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   normal_tensor_output_native.as_slice::<f32>());

         let reshape_tensor_output = reshape_network.forward(&[Arc::new(RwLock::new(reshape_tensor))])[0].clone();
-        let _ = reshape_tensor_output.write().unwrap().add_device(native_backend.device());
-        reshape_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let reshape_tensor_output_native_ = reshape_tensor_output.read().unwrap();
-        let reshape_tensor_output_native = reshape_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], reshape_tensor_output_native.as_slice::<f32>());
-        assert_eq!(normal_tensor_output_native.as_slice::<f32>(), reshape_tensor_output_native.as_slice::<f32>());
+        let reshape_tensor_output_native = reshape_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   reshape_tensor_output_native.as_slice::<f32>());
+        assert_eq!(normal_tensor_output_native.as_slice::<f32>(),
+                   reshape_tensor_output_native.as_slice::<f32>());
     }
 }
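The test changes show the other half of the new memory access API: write_only replaces get_mut for filling a tensor, read replaces get for inspecting it, and the explicit add_device + sync calls disappear because synchronisation now happens inside the access methods. Below is a minimal sketch of that flow on the native backend, assuming collenchyma's post-#62 API as exposed through its prelude; the `fill` helper is hypothetical and only mirrors the `write_to_memory` helper these tests use.

extern crate collenchyma as co;

use co::prelude::*;

// Hypothetical helper modelled on the tests' `write_to_memory`: copy host
// data into the flat buffer behind a native MemoryType.
fn fill(mem: &mut MemoryType, data: &[f32]) {
    if let &mut MemoryType::Native(ref mut native) = mem {
        for (dst, src) in native.as_mut_slice::<f32>().iter_mut().zip(data) {
            *dst = *src;
        }
    }
}

fn main() {
    let backend = Backend::<Native>::default().unwrap();
    let mut tensor = SharedTensor::<f32>::new(&[3]);

    // `write_only` hands back memory on the requested device and marks other
    // copies outdated, so no prior sync is needed before writing.
    fill(tensor.write_only(backend.device()).unwrap(), &[1.0, 1.0, 2.0]);

    // `read` returns memory that is up to date on the requested device,
    // replacing the old add_device + sync + get sequence.
    let slice = tensor.read(backend.device()).unwrap()
        .as_native().unwrap().as_slice::<f32>();
    assert_eq!(&[1.0f32, 1.0, 2.0], slice);
}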
