[BUG] Function "all" and "any" can not work normally #931

Open
DerekLin919 opened this issue Mar 21, 2025 · 4 comments

Comments

@DerekLin919

Describe the Bug

When the input tensor has rank 3 or higher, the functions all and any do not work correctly.

To Reproduce

Expected Behavior

Code Snippets

System Details (please complete the following information):

  • OS: [e.g. Ubuntu 18.04]
  • CUDA version: [e.g CUDA 11.4]
  • g++ version: [e.g. 9.3.0]

Additional Context

@cliffburdick
Collaborator

Hi @DerekLin919, can you please give an example of this not working? Our unit tests use a 4D tensor and pass, and I just tried a 3D tensor and the result seems correct:

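    #include "matx.h"
    using namespace matx;

    int main() {
      cudaExecutor exec{};
      auto a = make_tensor<float>({10, 10, 10});
      // Rank-0 output holds the single reduced value
      auto b = make_tensor<bool>({});
      (a = random<float>({10, 10, 10}, NORMAL)).run(exec);
      (b = all(a)).run(exec);
      print(b);
      a(1,1,1) = 0.0f;
      (b = all(a)).run(exec);
      print(b);
    }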

Output:

tensor_0_b: Tensor{bool} Rank: 0, Sizes:[], Strides:[]
 1
tensor_0_b: Tensor{bool} Rank: 0, Sizes:[], Strides:[]
 0

@DerekLin919
Author

Thank you for your prompt reply! I will try the unit test you provided in a few days. Thank you again!

@DerekLin919
Author

I have tried the example you provided, and it works correctly. However, when I reproduce the example from the documentation that reduces a 4-dimensional tensor to a 2-dimensional tensor, it fails to compile. The program is as follows:
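    #include "matx.h"
    using namespace matx;

    int main() {
      auto a = make_tensor<float>({10, 10, 10, 10});
      auto b = make_tensor<bool>({10, 10});
      (a = random<float>({10, 10, 10, 10}, NORMAL)).run();
      // Reduce over dimensions 0 and 1, leaving a 10x10 result
      (b = all(a, {0, 1})).run();
      print(b);
      a(1,1,1,1) = 0.0f;
      (b = all(a, {0, 1})).run();
      print(b);
    }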

Compiling it produces the following error:
    /tmp/tmp.w1IQLiweob/MatX/include/matx/transforms/reduce.h(975): error: class "matx::detail::reduceOpAll<__nv_bool>" has no member "Reduce"
    detected during:
    instantiation of "T matx::detail::warpReduceOp(T, Op, uint32_t) [with T=__nv_bool, Op=matx::detail::reduceOpAll<__nv_bool>]"
    (1162): here
    instantiation of "void matx::detail::matxReduceKernel(OutType, InType, ReduceOp, matx::index_t) [with OutType=matx::detail::tensor_impl_t<__nv_bool, 2, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>, matx::detail::DenseTensorData<__nv_bool>>, InType=matx::detail::tensor_impl_t<float, 4, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>, matx::detail::DenseTensorData<float>>, ReduceOp=matx::detail::reduceOpAll<__nv_bool>]"
    (1391): here
    instantiation of "void matx::reduce(OutType, TensorIndexType, const InType &, ReduceOp, cudaStream_t, __nv_bool) [with OutType=matx::detail::tensor_impl_t<__nv_bool, 2, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>, matx::detail::DenseTensorData<__nv_bool>>, TensorIndexType=std::nullopt_t, InType=matx::detail::tensor_impl_t<float, 4, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>, matx::detail::DenseTensorData<float>>, ReduceOp=matx::detail::reduceOpAll<__nv_bool>, <unnamed>=true]"
    (1452): here
    instantiation of "void matx::reduce(OutType, const InType &, ReduceOp, cudaStream_t, __nv_bool) [with OutType=matx::detail::tensor_impl_t<__nv_bool, 2, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>, matx::detail::DenseTensorData<__nv_bool>>, InType=matx::detail::tensor_impl_t<float, 4, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>, matx::detail::DenseTensorData<float>>, ReduceOp=matx::detail::reduceOpAll<__nv_bool>]"
    (2460): here
    instantiation of "void matx::all_impl(OutType, const InType &, matx::cudaExecutor) [with OutType=matx::detail::tensor_impl_t<__nv_bool, 2, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>, matx::detail::DenseTensorData<__nv_bool>>, InType=matx::detail::tensor_impl_t<float, 4, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>, matx::detail::DenseTensorData<float>>]"
    /tmp/tmp.w1IQLiweob/MatX/include/matx/operators/all.h(77): here
    instantiation of "void matx::detail::AllOp<OpA, ORank>::Exec(Out &&, Executor &&) const [with OpA=matx::tensor_t<float, 4, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>>, ORank=2, Out=cuda::std::__4::tuple<matx::detail::tensor_impl_t<__nv_bool, 2, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>, matx::detail::DenseTensorData<__nv_bool>>>, Executor=matx::cudaExecutor &]"
    /tmp/tmp.w1IQLiweob/MatX/include/matx/operators/set.h(179): here
    instantiation of "void matx::detail::set<T, Op>::TransformExec(ShapeType &&, Executor &&) const noexcept [with T=matx::tensor_t<__nv_bool, 2, matx::basic_storage<matx::raw_pointer_buffer<__nv_bool, matx::matx_allocator<__nv_bool>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, Op=matx::detail::AllOp<matx::tensor_t<float, 4, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>>, 2>, ShapeType=cuda::std::__4::array<matx::index_t, 2UL>, Executor=matx::cudaExecutor &]"
    /tmp/tmp.w1IQLiweob/MatX/include/matx/operators/base_operator.h(77): here
    instantiation of "void matx::BaseOp<T>::run(Ex &&) [with T=matx::detail::set<matx::tensor_t<__nv_bool, 2, matx::basic_storage<matx::raw_pointer_buffer<__nv_bool, matx::matx_allocator<__nv_bool>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::AllOp<matx::tensor_t<float, 4, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>>, 2>>, Ex=matx::cudaExecutor]"
    /tmp/tmp.w1IQLiweob/MatX/include/matx/operators/base_operator.h(112): here
    instantiation of "void matx::BaseOp<T>::run(cudaStream_t) [with T=matx::detail::set<matx::tensor_t<__nv_bool, 2, matx::basic_storage<matx::raw_pointer_buffer<__nv_bool, matx::matx_allocator<__nv_bool>>>, matx::tensor_desc_t<cuda::std::__4::array<matx::index_t, 2UL>, cuda::std::__4::array<matx::index_t, 2UL>, 2>>, matx::detail::AllOp<matx::tensor_t<float, 4, matx::basic_storage<matx::raw_pointer_buffer<float, matx::matx_allocator<float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 4UL>, cuda::std::__4::array<long long, 4UL>, 4>>, 2>>]"
    /tmp/tmp.w1IQLiweob/test/test_matx.cu(10): here
Many similar errors follow this one.
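The log shows the failure arising while `warpReduceOp` is instantiated with `Op=matx::detail::reduceOpAll<__nv_bool>`, and the compiler reports that this functor has no `Reduce` member. The sketch below is a minimal, hypothetical illustration of that class of error in plain C++, not MatX code: `AllReduceOp` and `reduce_with` are invented names, and the sketch does not establish why MatX's functor lacks the member. A generic reduction helper written against whatever `Op` it receives only compiles when that type exposes the members the helper calls.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical logical-AND reduction functor. This is NOT MatX's
// reduceOpAll; the names only echo the ones in the error log.
template <typename T>
struct AllReduceOp {
  T Init() const { return true; }              // identity element for AND
  T Reduce(T a, T b) const { return a && b; }  // combine two partial results
};

// Hypothetical generic helper: it compiles only if Op provides Init() and
// Reduce(). Passing a functor without Reduce() fails at template
// instantiation with a "has no member Reduce" style error.
template <typename T, typename Op>
T reduce_with(const T *data, std::size_t n, Op op) {
  assert(data != nullptr || n == 0);
  T acc = op.Init();
  for (std::size_t i = 0; i < n; ++i) acc = op.Reduce(acc, data[i]);
  return acc;
}
```

With `AllReduceOp<bool>`, `reduce_with` returns true only when every element is true. Any functor passed as `Op` that lacks `Reduce` makes the call fail to compile, which matches the shape of the error reported above.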

@cliffburdick
Collaborator

Thanks, I was able to reproduce it now. We will look into it.
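For comparison while this is investigated, the result the 4D program is expected to print can be computed on the CPU. This is a minimal sketch that assumes `all` reduces to true exactly when every element along the reduced dimensions is nonzero, which is consistent with the 1/0 output of the 3D example. `N`, `idx4`, and `all_dims01` are hypothetical names, and the flat row-major `std::vector` is a stand-in for the MatX tensor, not part of its API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Extent of every dimension, matching the {10, 10, 10, 10} shape in the repro.
constexpr std::size_t N = 10;

// Hypothetical helper: flat row-major offset of element (i, j, k, l)
// in an N x N x N x N tensor stored in a std::vector.
std::size_t idx4(std::size_t i, std::size_t j, std::size_t k, std::size_t l) {
  return ((i * N + j) * N + k) * N + l;
}

// CPU reference for "all" reduced over dimensions 0 and 1:
// out[k][l] is true iff a(i, j, k, l) != 0 for every i and j.
std::array<std::array<bool, N>, N> all_dims01(const std::vector<float> &a) {
  assert(a.size() == N * N * N * N);
  std::array<std::array<bool, N>, N> out{};
  for (auto &row : out) row.fill(true);
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      for (std::size_t k = 0; k < N; ++k)
        for (std::size_t l = 0; l < N; ++l)
          if (a[idx4(i, j, k, l)] == 0.0f) out[k][l] = false;
  return out;
}
```

Filling the input with nonzero values and then zeroing element (1, 1, 1, 1) leaves only `out[1][1]` false. Since normally distributed samples are almost never exactly zero, the second `print(b)` in the repro would be expected to show a single 0 at position (1, 1).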

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants