diff --git a/src/frameworks/cuda/mod.rs b/src/frameworks/cuda/mod.rs index 16c95f9..baa5459 100644 --- a/src/frameworks/cuda/mod.rs +++ b/src/frameworks/cuda/mod.rs @@ -154,19 +154,6 @@ impl ConvForwardAlgo { }; Ok(ConvForwardAlgo::from_cudnn(&algo)) } - - /// Check if the algo needs a cudnn workspace. - fn needs_cudnn_workspace(&self) -> Result { - Ok(match *self { - ConvForwardAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvForwardAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))), - ConvForwardAlgo::GEMM => true, - ConvForwardAlgo::ImplicitGEMM => false, - ConvForwardAlgo::ImplicitPrecompiledGEMM => true, - ConvForwardAlgo::FFT => true, - ConvForwardAlgo::FFTTiling => true, - ConvForwardAlgo::Direct => true, - }) - } } impl ConvBackwardFilterAlgo { @@ -209,17 +196,6 @@ impl ConvBackwardFilterAlgo { }; Ok(ConvBackwardFilterAlgo::from_cudnn(&algo)) } - - /// Check if the algo needs a cudnn workspace. - fn needs_cudnn_workspace(&self) -> Result { - Ok(match *self { - ConvBackwardFilterAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvBackwardFilterAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))), - ConvBackwardFilterAlgo::ImplicitGEMM => false, - ConvBackwardFilterAlgo::ImplicitGEMMSum => true, - ConvBackwardFilterAlgo::ImplicitPrecompiledGEMMSum => true, - ConvBackwardFilterAlgo::FFT => true, - }) - } } impl ConvBackwardDataAlgo { @@ -262,17 +238,6 @@ impl ConvBackwardDataAlgo { }; Ok(ConvBackwardDataAlgo::from_cudnn(&algo)) } - - /// Check if the algo needs a cudnn workspace. - fn needs_cudnn_workspace(&self) -> Result { - Ok(match *self { - ConvBackwardDataAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvBackwardDataAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))), - ConvBackwardDataAlgo::ImplicitGEMM => false, - ConvBackwardDataAlgo::ImplicitGEMMSum => false, - ConvBackwardDataAlgo::FFT => true, - ConvBackwardDataAlgo::FFTTiling => true, - }) - } } macro_rules! impl_convolution_for_cuda_backend { @@ -304,13 +269,19 @@ macro_rules! impl_convolution_for_cuda_backend { let useable_algo_bwd_filter = try!(algo_bwd_filter.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc)); let useable_algo_bwd_data = try!(algo_bwd_data.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc)); - let workspace_size_fwd = API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); - let workspace_size_bwd_filter = API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); - // let workspace_size_bwd_data = API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); - let workspace_size_bwd_data = match try!(useable_algo_bwd_data.needs_cudnn_workspace()) { - false => 1, - true => API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(), - }; + let mut workspace_size_fwd = API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); + let mut workspace_size_bwd_filter = API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); + let mut workspace_size_bwd_data = API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(); + + if workspace_size_fwd == 0 { + workspace_size_fwd = 8; + } + if workspace_size_bwd_filter == 0 { + workspace_size_bwd_filter = 8; + } + if workspace_size_bwd_data == 0 { + workspace_size_bwd_data = 8; + } Ok( ::cudnn::utils::ConvolutionConfig::new(