3
3
//! Does this convolution with a set of learnable filters, each producing one
4
4
//! feature map in the output tensor.
5
5
use std:: rc:: Rc ;
6
+ use std:: sync:: { Arc , RwLock } ;
6
7
use co:: prelude:: * ;
7
8
use conn;
9
+ use conn:: ConvolutionConfig as connConvolutionConfig;
8
10
use layer:: * ;
9
- use util:: { ArcLock , native_backend , cast_vec_usize_to_i32} ;
11
+ use util:: { ArcLock , cast_vec_usize_to_i32} ;
10
12
use weight:: FillerType ;
11
13
use super :: FilterLayer ;
12
14
@@ -19,7 +21,8 @@ pub struct Convolution<B: conn::Convolution<f32>> {
19
21
stride : Vec < usize > ,
20
22
padding : Vec < usize > ,
21
23
22
- convolution_configs : Option < Rc < B :: CC > > ,
24
+ workspace : Option < ArcLock < SharedTensor < u8 > > > ,
25
+ convolution_config : Option < Rc < B :: CC > > ,
23
26
}
24
27
25
28
impl < B : conn:: Convolution < f32 > > Convolution < B > {
@@ -34,7 +37,8 @@ impl<B: conn::Convolution<f32>> Convolution<B> {
34
37
35
38
axis : config. axis ( ) ,
36
39
37
- convolution_configs : None ,
40
+ workspace : None ,
41
+ convolution_config : None ,
38
42
}
39
43
}
40
44
@@ -103,7 +107,7 @@ impl<B: IBackend + conn::Convolution<f32>> ILayer<B> for Convolution<B> {
103
107
}
104
108
105
109
fn reshape ( & mut self ,
106
- backend : :: std :: rc :: Rc < B > ,
110
+ backend : Rc < B > ,
107
111
input_data : & mut Vec < ArcLock < SharedTensor < f32 > > > ,
108
112
input_gradient : & mut Vec < ArcLock < SharedTensor < f32 > > > ,
109
113
weights_data : & mut Vec < ArcLock < SharedTensor < f32 > > > ,
@@ -125,12 +129,10 @@ impl<B: IBackend + conn::Convolution<f32>> ILayer<B> for Convolution<B> {
125
129
let stride = cast_vec_usize_to_i32 ( self . stride_dims ( num_spatial_dims) ) ;
126
130
let padding = cast_vec_usize_to_i32 ( self . padding_dims ( num_spatial_dims) ) ;
127
131
128
- // add copy on native as workaround for bug in new_convolution_config
129
- let native = native_backend ( ) ;
130
- let _ = filter. add_device ( native. device ( ) ) ;
131
132
let config = backend. new_convolution_config ( & inp, & output_data, & mut filter,
132
133
conn:: ConvForwardAlgo :: Auto , conn:: ConvBackwardFilterAlgo :: Auto , conn:: ConvBackwardDataAlgo :: Auto ,
133
134
& stride, & padding) . unwrap ( ) ;
135
+
134
136
// resize and fill weights
135
137
weights_data[ 0 ] . write ( ) . unwrap ( ) . resize ( filter. desc ( ) ) . unwrap ( ) ;
136
138
let filler = FillerType :: Glorot {
@@ -139,9 +141,27 @@ impl<B: IBackend + conn::Convolution<f32>> ILayer<B> for Convolution<B> {
139
141
} ;
140
142
filler. fill ( & mut weights_data[ 0 ] . write ( ) . unwrap ( ) ) ;
141
143
weights_gradient[ 0 ] . write ( ) . unwrap ( ) . resize ( filter. desc ( ) ) . unwrap ( ) ;
142
- self . convolution_configs = Some ( Rc :: new ( config) ) ;
144
+ self . convolution_config = Some ( Rc :: new ( config) ) ;
143
145
}
144
146
}
147
+
148
+ fn resize_shared_workspace ( & mut self , backend : Rc < B > , workspace : Option < ArcLock < SharedTensor < u8 > > > ) -> Option < ArcLock < SharedTensor < u8 > > > {
149
+ let required_size = self . convolution_config . as_ref ( ) . unwrap ( ) . workspace_size ( ) ;
150
+ let new_workspace = if workspace. is_none ( ) {
151
+ Arc :: new ( RwLock :: new ( SharedTensor :: < u8 > :: new ( IBackend :: device ( & * backend) , & ( required_size) ) . unwrap ( ) ) )
152
+ } else {
153
+ let old_workspace = workspace. as_ref ( ) . unwrap ( ) . clone ( ) ;
154
+ let old_workspace_size = old_workspace. read ( ) . unwrap ( ) . capacity ( ) ;
155
+ if old_workspace_size < required_size {
156
+ Arc :: new ( RwLock :: new ( SharedTensor :: < u8 > :: new ( IBackend :: device ( & * backend) , & ( required_size) ) . unwrap ( ) ) )
157
+ } else {
158
+ workspace. unwrap ( )
159
+ }
160
+ } ;
161
+
162
+ self . workspace = Some ( new_workspace. clone ( ) ) ;
163
+ Some ( new_workspace)
164
+ }
145
165
}
146
166
147
167
impl < B : IBackend + conn:: Convolution < f32 > > ComputeOutput < f32 , B > for Convolution < B > {
@@ -151,8 +171,9 @@ impl<B: IBackend + conn::Convolution<f32>> ComputeOutput<f32, B> for Convolution
151
171
input_data : & [ & SharedTensor < f32 > ] ,
152
172
output_data : & mut [ & mut SharedTensor < f32 > ] ) {
153
173
let filter_data = weights[ 0 ] ;
154
- let conv_config = self . convolution_configs . as_ref ( ) . unwrap ( ) ;
155
- backend. convolution_plain ( filter_data, input_data[ 0 ] , output_data[ 0 ] , conv_config) . unwrap ( ) ;
174
+ let conv_config = self . convolution_config . as_ref ( ) . unwrap ( ) ;
175
+ let mut workspace = self . workspace . as_ref ( ) . unwrap ( ) . write ( ) . unwrap ( ) ;
176
+ backend. convolution_plain ( filter_data, input_data[ 0 ] , output_data[ 0 ] , & mut workspace, conv_config) . unwrap ( ) ;
156
177
}
157
178
}
158
179
@@ -165,9 +186,10 @@ impl<B: IBackend + conn::Convolution<f32>> ComputeInputGradient<f32, B> for Conv
165
186
input_data : & [ & SharedTensor < f32 > ] ,
166
187
input_gradients : & mut [ & mut SharedTensor < f32 > ] ) {
167
188
let filter_data = weights_data[ 0 ] ;
168
- let conv_config = self . convolution_configs . as_ref ( ) . unwrap ( ) ;
189
+ let conv_config = self . convolution_config . as_ref ( ) . unwrap ( ) ;
190
+ let mut workspace = self . workspace . as_ref ( ) . unwrap ( ) . write ( ) . unwrap ( ) ;
169
191
// compute gradient w.r.t. input
170
- backend. convolution_grad_data_plain ( filter_data, output_gradients[ 0 ] , input_gradients[ 0 ] , conv_config) . unwrap ( ) ;
192
+ backend. convolution_grad_data_plain ( filter_data, output_gradients[ 0 ] , input_gradients[ 0 ] , & mut workspace , conv_config) . unwrap ( ) ;
171
193
}
172
194
}
173
195
@@ -180,9 +202,10 @@ impl<B: IBackend + conn::Convolution<f32>> ComputeParametersGradient<f32, B> for
180
202
parameters_gradients : & mut [ & mut SharedTensor < f32 > ] ) {
181
203
// TODO: compute gradient w.r.t. bias
182
204
let filter_gradient = & mut parameters_gradients[ 0 ] ;
183
- let conv_config = self . convolution_configs . as_ref ( ) . unwrap ( ) ;
205
+ let conv_config = self . convolution_config . as_ref ( ) . unwrap ( ) ;
206
+ let mut workspace = self . workspace . as_ref ( ) . unwrap ( ) . write ( ) . unwrap ( ) ;
184
207
// compute gradient w.r.t. filter
185
- backend. convolution_grad_filter_plain ( input_data[ 0 ] , output_gradients[ 0 ] , filter_gradient, conv_config) . unwrap ( ) ;
208
+ backend. convolution_grad_filter_plain ( input_data[ 0 ] , output_gradients[ 0 ] , filter_gradient, & mut workspace , conv_config) . unwrap ( ) ;
186
209
}
187
210
}
188
211
0 commit comments