Skip to content

Commit d1c1030

Browse files
committed
fix/sequential: synchronize after forward/backward
1 parent f5f25c3 commit d1c1030

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/layers/activation/relu.rs

-4
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,9 @@ impl<B: IBackend + Relu<f32> + ReluPointwise<f32>> ILayer<B> for ReLU {
3434
if let Some(inp) = input_data.get(0) {
3535
let read_inp = inp.read().unwrap();
3636
let input_desc = read_inp.desc();
37-
debug!("ONE");
3837
input_gradient[0].write().unwrap().resize(input_desc).unwrap();
39-
debug!("TWO");
4038
output_data[0].write().unwrap().resize(input_desc).unwrap();
41-
debug!("THREE");
4239
output_gradient[0].write().unwrap().resize(input_desc).unwrap();
43-
debug!("FOUR");
4440
}
4541
}
4642
}

src/layers/common/sequential.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ impl<B: IBackend + LayerOps<f32> + 'static> ILayer<B> for Sequential<B> {
249249
for layer in &self.layers {
250250
layer.borrow_mut().forward(&[]);
251251
}
252+
if let Some(last_layer) = self.layers.last() {
253+
last_layer.borrow_mut().synchronize();
254+
}
252255
}
253256

254257
fn backward_input(&self,
@@ -266,6 +269,9 @@ impl<B: IBackend + LayerOps<f32> + 'static> ILayer<B> for Sequential<B> {
266269
for layer in self.layers.iter().rev() {
267270
layer.borrow_mut().backward_input(&[]);
268271
}
272+
if let Some(first_layer) = self.layers.iter().rev().last() {
273+
first_layer.borrow_mut().synchronize();
274+
}
269275
}
270276

271277
fn backward_parameters(&self,
@@ -274,9 +280,12 @@ impl<B: IBackend + LayerOps<f32> + 'static> ILayer<B> for Sequential<B> {
274280
output_gradients: &[ArcLock<SharedTensor<f32>>],
275281
input_data: &[ArcLock<SharedTensor<f32>>],
276282
weights_gradients: &mut [ArcLock<SharedTensor<f32>>]) {
277-
for layer in &self.layers {
283+
for layer in self.layers.iter().rev() {
278284
layer.borrow_mut().backward_parameters();
279285
}
286+
if let Some(first_layer) = self.layers.iter().rev().last() {
287+
first_layer.borrow_mut().synchronize();
288+
}
280289
}
281290
}
282291

0 commit comments

Comments
 (0)