diff --git a/dabnn/layers/BinConv.cpp b/dabnn/layers/BinConv.cpp
index 3004ce9..262488e 100644
--- a/dabnn/layers/BinConv.cpp
+++ b/dabnn/layers/BinConv.cpp
@@ -12,6 +12,9 @@
 
 namespace bnn {
 
+// Rounds `a` up to the nearest multiple of `b` (for a >= 0 and b > 0).
+int align_to(int a, int b) { return (a + b - 1) / b * b; }
+
 BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
                  css output, int pad_h, int pad_w, int stride_h, int stride_w)
     : Layer(net, name, "Bin Conv"),
@@ -43,9 +46,11 @@ BinConv::BinConv(NetCP net, const std::string &name, css input, css weight,
 
     const auto col_mat_name = "col_for_" + output + "_cal";
     if (mat_map.find(col_mat_name) == mat_map.end()) {
-        const auto len = output_mat->h * output_mat->w * weight_mat->h *
-                         weight_mat->w * input_mat->elem_c;
-        mat_map[col_mat_name] = std::make_shared(1, 1, len, bnn::DataType::Bit);
+        const auto len =
+            output_mat->h * output_mat->w *
+            align_to(weight_mat->h * weight_mat->w * input_mat->elem_c, 128);
+        mat_map[col_mat_name] =
+            std::make_shared(1, 1, len, bnn::DataType::Bit);
     }
     col_mat = mat(col_mat_name);
 
@@ -119,14 +124,14 @@ void BinConv::forward_impl() const {
         output_mat->fill(0.f);
         // pack_mat_64(*input_mat, *binarized_mat);
         // bnn::im2col(*binarized_mat, weight_mat->h, weight_mat->w,
-        // pad_h, pad_w, stride_h, stride_w, 1, 1,
-        // *col_mat);
+        // pad_h, pad_w, stride_h, stride_w, 1,
+        // 1, *col_mat);
 
         // const auto len = output_mat->h * output_mat->w * weight_mat->h *
         // weight_mat->w * input_mat->elem_c;
         // Mat temp(1, 1, len, bnn::DataType::Float);
-        // im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w, stride_h, stride_w, 1, 1, temp);
-        // pack_mat(temp, *col_mat);
+        // im2col(*input_mat, weight_mat->h, weight_mat->w, pad_h, pad_w,
+        // stride_h, stride_w, 1, 1, temp); pack_mat(temp, *col_mat);
 
         bnn::fused_binarize_im2col(*input_mat, weight_mat->h, weight_mat->w,
                                    pad_h, pad_w, stride_h, stride_w, 1, 1,
@@ -134,7 +139,7 @@ void BinConv::forward_impl() const {
 
         const int m = weight_mat->n;
         const int n = output_mat->h * output_mat->w;
-        const int k = weight_mat->h * weight_mat->w * weight_mat->c;
+        const int k = weight_mat->total() / weight_mat->n;
         bgemm(m, n, k, static_cast(transposed_weight_mat->data), m,
               static_cast(col_mat->data), k,
               static_cast(output_mat->data), m);