Skip to content

Commit 6868995

Browse files
fix/sgd: initialize weight gradient history with zeroes
The SGD solver used uninitialized history tensors. If they contained any NaNs, the whole network got poisoned after the first generation, even if momentum was set to zero. This patch prefills the gradient history with zeros. FIX: autumnai/leaf-examples#13
1 parent 6f41247 commit 6868995

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/solvers/sgd/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ macro_rules! impl_isolver_sgd {
3131

3232
for weight_gradient in net.learnable_weights_gradients() {
3333
let shape = weight_gradient.read().unwrap().desc().clone();
34-
let history_tensor = Arc::new(RwLock::new(SharedTensor::new(IBackend::device(&*self.backend), &shape).unwrap()));
34+
let mut tensor = SharedTensor::new(IBackend::device(&*self.backend),
35+
&shape).unwrap();
36+
37+
let filler = ::weight::FillerType::Constant { value: 0f32 };
38+
filler.fill(&mut tensor);
39+
40+
let history_tensor = Arc::new(RwLock::new(tensor));
3541
self.history.push(history_tensor);
3642
}
3743
}

0 commit comments

Comments
 (0)