Skip to content

Commit 2d2d88f

Browse files
Implement an N-dimensional implicit gemm convolution algo for the native
backend. This algo needs to allocate a few index vectors, however it requires no additional space.
1 parent 19df0f2 commit 2d2d88f

File tree

2 files changed

+316
-83
lines changed

2 files changed

+316
-83
lines changed

src/frameworks/native/helper.rs

+6-63
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
use co::plugin::numeric_helpers::Float;
44
use co::memory::MemoryType;
55

6-
#[derive(Debug, Copy, Clone)]
7-
#[allow(missing_docs)]
8-
pub struct ConvolutionConfig;
96
#[derive(Debug, Copy, Clone)]
107
#[allow(missing_docs)]
118
pub struct NormalizationConfig;
@@ -278,66 +275,12 @@ macro_rules! impl_ops_tanh_for {
278275
);
279276
}
280277

281-
#[macro_export]
282-
macro_rules! impl_ops_convolution_for {
283-
($t:ident, $b:ty) => (
284-
impl ::plugin::Convolution<$t> for $b {
285-
fn new_convolution_config(
286-
&self,
287-
src: &::co::tensor::SharedTensor<$t>,
288-
dest: &::co::tensor::SharedTensor<$t>,
289-
filter: &mut ::co::tensor::SharedTensor<$t>,
290-
stride: &[i32],
291-
zero_padding: &[i32]
292-
) -> Result<Self::CC, ::co::error::Error> {
293-
unimplemented!();
294-
Ok(helper::ConvolutionConfig)
295-
}
296-
fn convolution(
297-
&self,
298-
x: &mut ::co::tensor::SharedTensor<$t>,
299-
result: &mut ::co::tensor::SharedTensor<$t>,
300-
config: &Self::CC
301-
) -> Result<(), ::co::error::Error> {
302-
unimplemented!();
303-
Ok(())
304-
}
305-
306-
fn convolution_plain(
307-
&self,
308-
x: &::co::tensor::SharedTensor<$t>,
309-
result: &mut ::co::tensor::SharedTensor<$t>,
310-
config: &Self::CC
311-
) -> Result<(), ::co::error::Error> {
312-
unimplemented!();
313-
Ok(())
314-
}
315-
316-
fn convolution_grad(
317-
&self,
318-
x: &mut ::co::tensor::SharedTensor<$t>,
319-
x_diff: &mut ::co::tensor::SharedTensor<$t>,
320-
result: &mut ::co::tensor::SharedTensor<$t>,
321-
result_diff: &mut ::co::tensor::SharedTensor<$t>,
322-
config: &Self::CC
323-
) -> Result<(), ::co::error::Error> {
324-
unimplemented!();
325-
Ok(())
326-
}
327-
328-
fn convolution_grad_plain(
329-
&self,
330-
x: &::co::tensor::SharedTensor<$t>,
331-
x_diff: &::co::tensor::SharedTensor<$t>,
332-
result: &::co::tensor::SharedTensor<$t>,
333-
result_diff: &mut ::co::tensor::SharedTensor<$t>,
334-
config: &Self::CC
335-
) -> Result<(), ::co::error::Error> {
336-
unimplemented!();
337-
Ok(())
338-
}
339-
}
340-
);
278+
#[derive(Debug, Clone)]
279+
#[allow(missing_docs)]
280+
pub struct ConvolutionConfig {
281+
pub filter_shape: Vec<usize>,
282+
pub stride: Vec<i32>,
283+
pub padding: Vec<i32>,
341284
}
342285

343286
#[macro_export]

0 commit comments

Comments
 (0)