|
3 | 3 | use co::plugin::numeric_helpers::Float;
|
4 | 4 | use co::memory::MemoryType;
|
5 | 5 |
|
6 |
| -#[derive(Debug, Copy, Clone)] |
7 |
| -#[allow(missing_docs)] |
8 |
| -pub struct ConvolutionConfig; |
9 | 6 | #[derive(Debug, Copy, Clone)]
|
10 | 7 | #[allow(missing_docs)]
|
11 | 8 | pub struct NormalizationConfig;
|
@@ -278,66 +275,12 @@ macro_rules! impl_ops_tanh_for {
|
278 | 275 | );
|
279 | 276 | }
|
280 | 277 |
|
281 |
| -#[macro_export] |
282 |
| -macro_rules! impl_ops_convolution_for { |
283 |
| - ($t:ident, $b:ty) => ( |
284 |
| - impl ::plugin::Convolution<$t> for $b { |
285 |
| - fn new_convolution_config( |
286 |
| - &self, |
287 |
| - src: &::co::tensor::SharedTensor<$t>, |
288 |
| - dest: &::co::tensor::SharedTensor<$t>, |
289 |
| - filter: &mut ::co::tensor::SharedTensor<$t>, |
290 |
| - stride: &[i32], |
291 |
| - zero_padding: &[i32] |
292 |
| - ) -> Result<Self::CC, ::co::error::Error> { |
293 |
| - unimplemented!(); |
294 |
| - Ok(helper::ConvolutionConfig) |
295 |
| - } |
296 |
| - fn convolution( |
297 |
| - &self, |
298 |
| - x: &mut ::co::tensor::SharedTensor<$t>, |
299 |
| - result: &mut ::co::tensor::SharedTensor<$t>, |
300 |
| - config: &Self::CC |
301 |
| - ) -> Result<(), ::co::error::Error> { |
302 |
| - unimplemented!(); |
303 |
| - Ok(()) |
304 |
| - } |
305 |
| - |
306 |
| - fn convolution_plain( |
307 |
| - &self, |
308 |
| - x: &::co::tensor::SharedTensor<$t>, |
309 |
| - result: &mut ::co::tensor::SharedTensor<$t>, |
310 |
| - config: &Self::CC |
311 |
| - ) -> Result<(), ::co::error::Error> { |
312 |
| - unimplemented!(); |
313 |
| - Ok(()) |
314 |
| - } |
315 |
| - |
316 |
| - fn convolution_grad( |
317 |
| - &self, |
318 |
| - x: &mut ::co::tensor::SharedTensor<$t>, |
319 |
| - x_diff: &mut ::co::tensor::SharedTensor<$t>, |
320 |
| - result: &mut ::co::tensor::SharedTensor<$t>, |
321 |
| - result_diff: &mut ::co::tensor::SharedTensor<$t>, |
322 |
| - config: &Self::CC |
323 |
| - ) -> Result<(), ::co::error::Error> { |
324 |
| - unimplemented!(); |
325 |
| - Ok(()) |
326 |
| - } |
327 |
| - |
328 |
| - fn convolution_grad_plain( |
329 |
| - &self, |
330 |
| - x: &::co::tensor::SharedTensor<$t>, |
331 |
| - x_diff: &::co::tensor::SharedTensor<$t>, |
332 |
| - result: &::co::tensor::SharedTensor<$t>, |
333 |
| - result_diff: &mut ::co::tensor::SharedTensor<$t>, |
334 |
| - config: &Self::CC |
335 |
| - ) -> Result<(), ::co::error::Error> { |
336 |
| - unimplemented!(); |
337 |
| - Ok(()) |
338 |
| - } |
339 |
| - } |
340 |
| - ); |
| 278 | +#[derive(Debug, Clone)] |
| 279 | +#[allow(missing_docs)] |
| 280 | +pub struct ConvolutionConfig { |
| 281 | + pub filter_shape: Vec<usize>, |
| 282 | + pub stride: Vec<i32>, |
| 283 | + pub padding: Vec<i32>, |
341 | 284 | }
|
342 | 285 |
|
343 | 286 | #[macro_export]
|
|
0 commit comments