@@ -21,7 +21,7 @@ use std::rc::Rc;
21
21
use std:: sync:: { Arc , RwLock } ;
22
22
use util:: * ;
23
23
24
- #[ derive( Debug , Clone ) ]
24
+ #[ derive( Debug ) ]
25
25
/// Stochastic Gradient Descent with Momentum.
26
26
///
27
27
/// See [module description][1] for more information.
@@ -31,6 +31,11 @@ pub struct Momentum<SolverB: IBackend + SolverOps<f32>> {
31
31
history : Vec < ArcLock < SharedTensor < f32 > > > ,
32
32
/// The backend used for computing the gradient.
33
33
backend : Rc < SolverB > ,
34
+
35
+ /// Scalar that temporarily holds learning rate for weight update computations
36
+ lr : SharedTensor < f32 > ,
37
+ /// Scalar that temporarily holds momentum for weight update computations
38
+ momentum : SharedTensor < f32 > ,
34
39
}
35
40
36
41
impl < SolverB : IBackend + SolverOps < f32 > > Momentum < SolverB > {
@@ -41,9 +46,19 @@ impl<SolverB: IBackend + SolverOps<f32>> Momentum<SolverB> {
41
46
///
42
47
/// [2]: ../../../solver/struct.Solver.html#method.from_config
43
48
pub fn new ( backend : Rc < SolverB > ) -> Momentum < SolverB > {
49
+ let ( lr, momentum) = {
50
+ let device = IBackend :: device ( backend. as_ref ( ) ) ;
51
+
52
+ ( SharedTensor :: < f32 > :: new ( device, & 1 ) . unwrap ( ) ,
53
+ SharedTensor :: < f32 > :: new ( device, & 1 ) . unwrap ( ) )
54
+ } ;
55
+
44
56
Momentum {
45
57
history : Vec :: new ( ) ,
46
- backend : backend
58
+ backend : backend,
59
+
60
+ lr : lr,
61
+ momentum : momentum,
47
62
}
48
63
}
49
64
@@ -56,28 +71,31 @@ impl<B: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32> + 'static> SGD
56
71
history_blob_id : usize ,
57
72
global_lr : & f32 ,
58
73
blob_lr : & f32 ) {
59
- let history_blob = & self . history [ history_blob_id] ;
60
- let local_momentum = config. momentum ;
61
- let local_lr = global_lr * blob_lr;
74
+ :: weight:: FillerType :: Constant {
75
+ value : global_lr * blob_lr
76
+ } . fill ( & mut self . lr ) ;
77
+
78
+ :: weight:: FillerType :: Constant {
79
+ value : config. momentum
80
+ } . fill ( & mut self . momentum ) ;
62
81
63
- let native_backend = native_backend ( ) ;
64
82
let backend = ISolver :: < B , NetB > :: backend ( self ) ;
65
83
let device = IBackend :: device ( backend) ;
66
84
67
- let lr_shared = native_scalar ( local_lr) ;
68
- let momentum_shared = native_scalar ( local_momentum) ;
85
+ let history_blob = & self . history [ history_blob_id] ;
86
+
87
+ let _ = weight_gradient. write ( ) . unwrap ( ) . add_device ( device) ;
88
+ weight_gradient. write ( ) . unwrap ( ) . sync ( device) . unwrap ( ) ;
89
+ let _ = history_blob. write ( ) . unwrap ( ) . add_device ( device) ;
90
+ history_blob. write ( ) . unwrap ( ) . sync ( device) . unwrap ( ) ;
69
91
70
- let _ = weight_gradient. write ( ) . unwrap ( ) . add_device ( native_backend. device ( ) ) ;
71
- weight_gradient. write ( ) . unwrap ( ) . sync ( native_backend. device ( ) ) . unwrap ( ) ;
72
- let _ = history_blob. write ( ) . unwrap ( ) . add_device ( native_backend. device ( ) ) ;
73
- history_blob. write ( ) . unwrap ( ) . sync ( native_backend. device ( ) ) . unwrap ( ) ;
74
- Axpby :: < f32 > :: axpby_plain ( & native_backend,
75
- & lr_shared,
76
- & weight_gradient. read ( ) . unwrap ( ) ,
77
- & momentum_shared,
78
- & mut history_blob. write ( ) . unwrap ( ) ) . unwrap ( ) ;
92
+ Axpby :: axpby_plain ( backend,
93
+ & self . lr ,
94
+ & weight_gradient. read ( ) . unwrap ( ) ,
95
+ & self . momentum ,
96
+ & mut history_blob. write ( ) . unwrap ( ) ) . unwrap ( ) ;
79
97
80
- native_backend . copy_plain (
98
+ backend . copy_plain (
81
99
& history_blob. read ( ) . unwrap ( ) , & mut weight_gradient. write ( ) . unwrap ( ) ) . unwrap ( ) ;
82
100
}
83
101
}
0 commit comments