Skip to content

Commit 08fd965

Browse files
perf/sgd: use GPU for computation of weight updates
SGD now actually uses the backend it was instantiated with; before this patch it used the hardcoded `Native` backend. On `leaf-examples mnist` with a GTX 960 this provides a 2.4x performance increase. Closes #88.
1 parent 3b25a48 commit 08fd965

File tree

1 file changed

+36
-18
lines changed

1 file changed

+36
-18
lines changed

src/solvers/sgd/momentum.rs

+36-18
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::rc::Rc;
2121
use std::sync::{Arc, RwLock};
2222
use util::*;
2323

24-
#[derive(Debug, Clone)]
24+
#[derive(Debug)]
2525
/// Stochastic Gradient Descent with Momentum.
2626
///
2727
/// See [module description][1] for more information.
@@ -31,6 +31,11 @@ pub struct Momentum<SolverB: IBackend + SolverOps<f32>> {
3131
history: Vec<ArcLock<SharedTensor<f32>>>,
3232
/// The backend used for computing the gradient.
3333
backend: Rc<SolverB>,
34+
35+
/// Scalar that temporarily holds learning rate for weight update computations
36+
lr: SharedTensor<f32>,
37+
/// Scalar that temporarily holds momentum for weight update computations
38+
momentum: SharedTensor<f32>,
3439
}
3540

3641
impl<SolverB: IBackend + SolverOps<f32>> Momentum<SolverB> {
@@ -41,9 +46,19 @@ impl<SolverB: IBackend + SolverOps<f32>> Momentum<SolverB> {
4146
///
4247
/// [2]: ../../../solver/struct.Solver.html#method.from_config
4348
pub fn new(backend: Rc<SolverB>) -> Momentum<SolverB> {
49+
let (lr, momentum) = {
50+
let device = IBackend::device(backend.as_ref());
51+
52+
(SharedTensor::<f32>::new(device, &1).unwrap(),
53+
SharedTensor::<f32>::new(device, &1).unwrap())
54+
};
55+
4456
Momentum {
4557
history: Vec::new(),
46-
backend: backend
58+
backend: backend,
59+
60+
lr: lr,
61+
momentum: momentum,
4762
}
4863
}
4964

@@ -56,28 +71,31 @@ impl<B: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32> + 'static> SGD
5671
history_blob_id: usize,
5772
global_lr: &f32,
5873
blob_lr: &f32) {
59-
let history_blob = &self.history[history_blob_id];
60-
let local_momentum = config.momentum;
61-
let local_lr = global_lr * blob_lr;
74+
::weight::FillerType::Constant {
75+
value: global_lr * blob_lr
76+
}.fill(&mut self.lr);
77+
78+
::weight::FillerType::Constant {
79+
value: config.momentum
80+
}.fill(&mut self.momentum);
6281

63-
let native_backend = native_backend();
6482
let backend = ISolver::<B, NetB>::backend(self);
6583
let device = IBackend::device(backend);
6684

67-
let lr_shared = native_scalar(local_lr);
68-
let momentum_shared = native_scalar(local_momentum);
85+
let history_blob = &self.history[history_blob_id];
86+
87+
let _ = weight_gradient.write().unwrap().add_device(device);
88+
weight_gradient.write().unwrap().sync(device).unwrap();
89+
let _ = history_blob.write().unwrap().add_device(device);
90+
history_blob.write().unwrap().sync(device).unwrap();
6991

70-
let _ = weight_gradient.write().unwrap().add_device(native_backend.device());
71-
weight_gradient.write().unwrap().sync(native_backend.device()).unwrap();
72-
let _ = history_blob.write().unwrap().add_device(native_backend.device());
73-
history_blob.write().unwrap().sync(native_backend.device()).unwrap();
74-
Axpby::<f32>::axpby_plain(&native_backend,
75-
&lr_shared,
76-
&weight_gradient.read().unwrap(),
77-
&momentum_shared,
78-
&mut history_blob.write().unwrap()).unwrap();
92+
Axpby::axpby_plain(backend,
93+
&self.lr,
94+
&weight_gradient.read().unwrap(),
95+
&self.momentum,
96+
&mut history_blob.write().unwrap()).unwrap();
7997

80-
native_backend.copy_plain(
98+
backend.copy_plain(
8199
&history_blob.read().unwrap(), &mut weight_gradient.write().unwrap()).unwrap();
82100
}
83101
}

0 commit comments

Comments
 (0)