diff --git a/cmake/musa.cmake b/cmake/musa.cmake
index 82694001e04f55..b6192ca60a26e8 100644
--- a/cmake/musa.cmake
+++ b/cmake/musa.cmake
@@ -47,6 +47,9 @@ endif()
 list(APPEND MUSA_MCC_FLAGS --cuda-gpu-arch=mp_21)
 list(APPEND MUSA_MCC_FLAGS -U__CUDA__)
+# MUSA has compile conflicts with float16.h because platform::float16 overloads std::is_floating_point and std::is_integer
+list(APPEND MUSA_MCC_FLAGS -D__MUSA_NO_HALF_CONVERSIONS__)
+
 #set(MUSA_VERBOSE_BUILD ON)
 if(CMAKE_BUILD_TYPE MATCHES Debug)
   list(APPEND MUSA_MCC_FLAGS -g2)
diff --git a/paddle/phi/backends/gpu/gpu_primitives.h b/paddle/phi/backends/gpu/gpu_primitives.h
index b891644679264d..e5d252d4ff89bc 100644
--- a/paddle/phi/backends/gpu/gpu_primitives.h
+++ b/paddle/phi/backends/gpu/gpu_primitives.h
@@ -266,54 +266,54 @@ CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) {
                                     PDBF16ToCUDABF16(val)));
 }
 #else
-//CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) {
-//  // concrete packed bfloat16 value may exsits in lower or higher 16bits
-//  // of the 32bits address.
-//  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
-//      reinterpret_cast<char *>(address) -
-//      (reinterpret_cast<uintptr_t>(address) & 0x02));
-//  float val_f = static_cast<float>(val);
-//  uint32_t old = *address_as_ui;
-//  uint32_t sum;
-//  uint32_t newval;
-//  uint32_t assumed;
-//  if (((uintptr_t)address & 0x02) == 0) {
-//    // the bfloat16 value stay at lower 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old & 0xFFFFu;
-//    return ret;
-//  } else {
-//    // the bfloat16 value stay at higher 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old >> 16;
-//    return ret;
-//  }
-//}
+CUDA_ATOMIC_WRAPPER(Add, phi::dtype::bfloat16) {
+  // The packed bfloat16 value may exist in the lower or the higher 16 bits
+  // of the aligned 32-bit word.
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t sum;
+  uint32_t newval;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The bfloat16 value stays at the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_add_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The bfloat16 value stays at the higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_add_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
 #endif
 
-//CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
-//  float *real = reinterpret_cast<float *>(address);
-//  float *imag = real + 1;
-//  return complex<float>(CudaAtomicAdd(real, val.real),
-//                        CudaAtomicAdd(imag, val.imag));
-//}
-//
-//CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
-//  double *real = reinterpret_cast<double *>(address);
-//  double *imag = real + 1;
-//  return complex<double>(CudaAtomicAdd(real, val.real),
-//                         CudaAtomicAdd(imag, val.imag));
-//}
+CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
+  float *real = reinterpret_cast<float *>(address);
+  float *imag = real + 1;
+  return complex<float>(CudaAtomicAdd(real, val.real),
+                        CudaAtomicAdd(imag, val.imag));
+}
+
+CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
+  double *real = reinterpret_cast<double *>(address);
+  double *imag = real + 1;
+  return complex<double>(CudaAtomicAdd(real, val.real),
+                         CudaAtomicAdd(imag, val.imag));
+}
 
 // For atomicMax
 USE_CUDA_ATOMIC(Max, int);
@@ -470,38 +470,38 @@ inline static __device__ uint32_t bf16_max_to_high_half(uint32_t val,
                                                         float x) {
   //return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
 }
-//CUDA_ATOMIC_WRAPPER(Max, phi::dtype::bfloat16) {
-//  if (*address >= val) {
-//    return *address;
-//  }
-//  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
-//      reinterpret_cast<char *>(address) -
-//      (reinterpret_cast<uintptr_t>(address) & 0x02));
-//  float val_f = static_cast<float>(val);
-//  uint32_t old = *address_as_ui;
-//  uint32_t assumed;
-//  if (((uintptr_t)address & 0x02) == 0) {
-//    // The bfloat16 value stay at lower 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_max_to_low_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old & 0xFFFFu;
-//    return ret;
-//  } else {
-//    // The bfloat16 value stay at higher 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_max_to_high_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old >> 16;
-//    return ret;
-//  }
-//}
+CUDA_ATOMIC_WRAPPER(Max, phi::dtype::bfloat16) {
+  if (*address >= val) {
+    return *address;
+  }
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The bfloat16 value stays at the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_max_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The bfloat16 value stays at the higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_max_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
 
 // For atomicMin
 USE_CUDA_ATOMIC(Min, int);
@@ -658,38 +658,38 @@ inline static __device__ uint32_t bf16_min_to_high_half(uint32_t val,
                                                         float x) {
   //return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
 }
-//CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
-//  if (*address <= val) {
-//    return *address;
-//  }
-//  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
-//      reinterpret_cast<char *>(address) -
-//      (reinterpret_cast<uintptr_t>(address) & 0x02));
-//  float val_f = static_cast<float>(val);
-//  uint32_t old = *address_as_ui;
-//  uint32_t assumed;
-//  if (((uintptr_t)address & 0x02) == 0) {
-//    // The bfloat16 value stay at lower 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_min_to_low_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old & 0xFFFFu;
-//    return ret;
-//  } else {
-//    // The bfloat16 value stay at higher 16 bits of the address.
-//    do {
-//      assumed = old;
-//      old = atomicCAS(
-//          address_as_ui, assumed, bf16_min_to_high_half(assumed, val_f));
-//    } while (old != assumed);
-//    phi::dtype::bfloat16 ret;
-//    ret.x = old >> 16;
-//    return ret;
-//  }
-//}
+CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
+  if (*address <= val) {
+    return *address;
+  }
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The bfloat16 value stays at the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_min_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The bfloat16 value stays at the higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_min_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
 
 #ifdef PADDLE_WITH_CUDA
 /*
diff --git a/paddle/phi/common/float16.h b/paddle/phi/common/float16.h
index 00de1bf605157b..5b53828251e40c 100644
--- a/paddle/phi/common/float16.h
+++ b/paddle/phi/common/float16.h
@@ -1019,13 +1019,14 @@ struct is_pod<phi::dtype::float16> {
                             is_standard_layout<phi::dtype::float16>::value;
 };
 
-//template <>
-//struct is_floating_point<phi::dtype::float16>
-//    : std::integral_constant<
-//          bool,
-//          std::is_same<
-//              phi::dtype::float16,
-//              typename std::remove_cv<phi::dtype::float16>::type>::value> {};
+template <>
+struct is_floating_point<phi::dtype::float16>
+    : std::integral_constant<
+          bool,
+          std::is_same<
+              phi::dtype::float16,
+              typename std::remove_cv<phi::dtype::float16>::type>::value> {};
+
 template <>
 struct is_signed<phi::dtype::float16> {
   static const bool value = true;
diff --git a/paddle/phi/common/scalar.h b/paddle/phi/common/scalar.h
index c8ced345a637a1..4286dfcc1d0fac 100644
--- a/paddle/phi/common/scalar.h
+++ b/paddle/phi/common/scalar.h
@@ -140,10 +140,10 @@ class ScalarBase {
         return static_cast<RT>(data_.f32);
       case DataType::FLOAT64:
         return static_cast<RT>(data_.f64);
-      //case DataType::FLOAT16:
-      //  return static_cast<RT>(data_.f16);
-      //case DataType::BFLOAT16:
-      //  return static_cast<RT>(data_.bf16);
+      case DataType::FLOAT16:
+        return static_cast<RT>(data_.f16);
+      case DataType::BFLOAT16:
+        return static_cast<RT>(data_.bf16);
       case DataType::INT32:
         return static_cast<RT>(data_.i32);
       case DataType::INT64:
@@ -162,10 +162,10 @@ class ScalarBase {
         return static_cast<RT>(data_.ui8);
       case DataType::BOOL:
         return static_cast<RT>(data_.b);
-      //case DataType::COMPLEX64:
-      //  return static_cast<RT>(data_.c64);
-      //case DataType::COMPLEX128:
-      //  return static_cast<RT>(data_.c128);
+      case DataType::COMPLEX64:
+        return static_cast<RT>(data_.c64);
+      case DataType::COMPLEX128:
+        return static_cast<RT>(data_.c128);
       default:
         PD_THROW("Invalid enum scalar data type `", dtype_, "`.");
     }
diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h
index d72046a82e0cb5..f96fdb1f28b63c 100644
--- a/paddle/phi/core/visit_type.h
+++ b/paddle/phi/core/visit_type.h
@@ -281,9 +281,17 @@ namespace phi {
       PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT16, int16_t, __VA_ARGS__) \
       PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \
       PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(                                                    \
+          NAME, ::phi::DataType::BFLOAT16, phi::bfloat16, __VA_ARGS__)         \
+      PD_PRIVATE_CASE_TYPE(                                                    \
+          NAME, ::phi::DataType::FLOAT16, phi::float16, __VA_ARGS__)           \
       PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \
       PD_PRIVATE_CASE_TYPE(                                                    \
          NAME, ::phi::DataType::FLOAT64, double, __VA_ARGS__)                  \
+      PD_PRIVATE_CASE_TYPE(                                                    \
+          NAME, ::phi::DataType::COMPLEX64, phi::complex64, __VA_ARGS__)       \
+      PD_PRIVATE_CASE_TYPE(                                                    \
+          NAME, ::phi::DataType::COMPLEX128, phi::complex128, __VA_ARGS__)     \
       default:                                                                 \
         PADDLE_THROW(phi::errors::InvalidArgument(                             \
             "Invalid enum data type `%d`.", static_cast<int>(__dtype__)));     \
diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
index edc647a968f635..b53de3beef9aa4 100644
--- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu
+++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -34,8 +34,7 @@ class ReduceAdd {
       typename tensor_t,
       std::enable_if_t<!std::is_same<tensor_t, uint8_t>::value>* = nullptr>
   __device__ void operator()(tensor_t* self_data, tensor_t* src_data) const {
-    // TODO(@caizhi): enable cudaAtomicAdd
-    //phi::CudaAtomicAdd(self_data, *src_data);
+    phi::CudaAtomicAdd(self_data, *src_data);
   }
   template <typename tensor_t,
             std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
diff --git a/paddle/phi/kernels/funcs/im2col.cu b/paddle/phi/kernels/funcs/im2col.cu
index a14d9886bb821b..87c82adbb7fbe8 100644
--- a/paddle/phi/kernels/funcs/im2col.cu
+++ b/paddle/phi/kernels/funcs/im2col.cu
@@ -472,8 +472,7 @@ __global__ void col2imOCF(const T* col_data,
       if (height_offset >= 0 && height_offset < im_height &&
           width_offset >= 0 && width_offset < im_width) {
-        // TODO(@caizhi): compile CudaAtomicAdd
-        //phi::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]);
+        phi::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]);
       }
     }
   }
diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h
index 9194b6dcc24d64..19a391ea150b6f 100644
--- a/paddle/phi/kernels/funcs/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/scatter.cu.h
@@ -76,8 +76,7 @@ __global__ void ScatterCUDAKernel(const T* params,
     if (overwrite) {
       *(output + out_i) = *(params + i);
     } else {
-      // TODO(@caizhi): enable compiling cudaAtomicAdd
-      //phi::CudaAtomicAdd(output + out_i, *(params + i));
+      phi::CudaAtomicAdd(output + out_i, *(params + i));
     }
   }
 }
@@ -111,8 +110,7 @@ __global__ void ScatterNdCUDAKernel(const T* update,
       temp *= output_dims[j];
     }
     int64_t output_i = gather_i + slice_i;
-    // TODO(@caizhi): enable compiling cudaAtomicAdd
-    //phi::CudaAtomicAdd(output + output_i, *(update + i));
+    phi::CudaAtomicAdd(output + output_i, *(update + i));
   }
 }
 
diff --git a/paddle/phi/kernels/funcs/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu
index ef13af5b4eff55..0b6df55bdeff19 100644
--- a/paddle/phi/kernels/funcs/segment_pooling.cu
+++ b/paddle/phi/kernels/funcs/segment_pooling.cu
@@ -61,8 +61,7 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids,
       }
       if (j > 0) {
         if (last_segment_id == first_segment_id) {
-          // TODO(@caizhi): enable compiling CudaAtomicAdd
-          //phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
+          phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
         } else {
           *(summed_ids + last_segment_id) = sum;
         }
@@ -72,7 +71,7 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids,
       sum += T(1);
       last_segment_id = current_segment_id;
     }
-    //phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
+    phi::CudaAtomicAdd(summed_ids + last_segment_id, sum);
   }
 }
 
@@ -113,9 +112,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids,
            last_segment_id * inner_dim_size + segment_offset;
        if (last_segment_id == first_segment_id) {
-         // TODO(@caizhi): enable compiling CudaAtomicAdd
-         //phi::CudaAtomicAdd(output + output_index,
-         //                   sum / *(summed_ids + last_segment_id));
+         phi::CudaAtomicAdd(output + output_index,
+                            sum / *(summed_ids + last_segment_id));
        } else {
          *(output + output_index) = sum / *(summed_ids + last_segment_id);
        }
@@ -126,9 +124,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids,
      last_segment_id = current_segment_id;
    }
    Index output_index = last_segment_id * inner_dim_size + segment_offset;
-   // TODO(@caizhi): enable compiling CudaAtomicAdd
-   //phi::CudaAtomicAdd(output + output_index,
-   //                   sum / *(summed_ids + last_segment_id));
+   phi::CudaAtomicAdd(output + output_index,
+                      sum / *(summed_ids + last_segment_id));
   }
 }
 
@@ -219,9 +216,7 @@ class MaxPool {
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
   DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
   DEVICE inline T atomic(T* address, const T val) {
-    // TODO(@caizhi): enable compiling CudaAtomicAdd
-    //return phi::CudaAtomicMax(address, val);
-    return val;
+    return phi::CudaAtomicMax(address, val);
   }
 };
 
@@ -231,9 +226,7 @@ class MinPool {
   DEVICE inline T initial() { return static_cast<T>(FLT_MAX); }
   DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; }
   DEVICE inline T atomic(T* address, const T val) {
-    // TODO(@caizhi): enable compiling CudaAtomicAdd
-    //return phi::CudaAtomicMin(address, val);
-    return val;
+    return phi::CudaAtomicMin(address, val);
   }
 };
 
@@ -243,9 +236,7 @@ class SumPool {
   DEVICE inline T initial() { return static_cast<T>(0); }
   DEVICE inline void compute(const T& x, T* y) { *y = *y + x; }
   DEVICE inline T atomic(T* address, const T val) {
-    // TODO(@caizhi): enable compiling CudaAtomicAdd
-    //return phi::CudaAtomicAdd(address, val);
-    return val;
+    return phi::CudaAtomicAdd(address, val);
   }
 };
 
diff --git a/paddle/phi/kernels/funcs/selected_rows_functor.cu b/paddle/phi/kernels/funcs/selected_rows_functor.cu
index 416c6b18b4c483..2947701befcc70 100644
--- a/paddle/phi/kernels/funcs/selected_rows_functor.cu
+++ b/paddle/phi/kernels/funcs/selected_rows_functor.cu
@@ -129,8 +129,7 @@ __global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
     // Since index in rows of SelectedRows can be duplicate, we can not use
     // tensor_out[index] += selected_rows[index]; Instead, we have to use
    // AtomicAdd to avoid concurrent write error.
-    // TODO(@caizhi): enable it
-    // phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
+    phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
   }
 }
 }  // namespace
@@ -282,8 +281,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
   for (int index = tid; index < row_numel; index += block_size) {
     // Since index in rows of SelectedRows can be duplicate, we have to use
     // Atomic Operation to avoid concurrent write error.
-    // TODO(@caizhi): enable it
-    // phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
+    phi::CudaAtomicAdd(tensor_out + index, selected_rows[index]);
   }
 }
 }  // namespace
@@ -364,8 +362,7 @@ __global__ void MergeAddKernel(const T* input,
   input += ty * row_numel;
   out += out_idx * row_numel;
   for (int index = tid; index < row_numel; index += block_size) {
-    // TODO(@caizhi): enable it
-    // phi::CudaAtomicAdd(out + index, input[index]);
+    phi::CudaAtomicAdd(out + index, input[index]);
   }
 }
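
Editor's note: the sketch below is illustrative only and is not part of the patch. It shows, as standalone CUDA, the atomicCAS-loop pattern that the re-enabled bfloat16 wrappers above rely on: a 16-bit value is updated by repeatedly CAS-ing the aligned 32-bit word that contains it. All helper names are hypothetical, and the float-to-bfloat16 conversion truncates rather than rounds, for brevity.

// Hypothetical, self-contained sketch of a bfloat16 atomic add via atomicCAS.
#include <cstdint>

__device__ inline float bf16_bits_to_float(unsigned short bits) {
  // bfloat16 occupies the upper 16 bits of an IEEE-754 binary32 value.
  return __uint_as_float(static_cast<unsigned int>(bits) << 16);
}

__device__ inline unsigned short float_to_bf16_bits(float f) {
  // Truncating conversion; production code would round to nearest even.
  return static_cast<unsigned short>(__float_as_uint(f) >> 16);
}

__device__ void atomic_add_bf16_sketch(unsigned short *address, float val) {
  // Align down to the 32-bit word that contains the 16-bit slot.
  std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(address);
  unsigned int *word =
      reinterpret_cast<unsigned int *>(addr & ~static_cast<std::uintptr_t>(0x2));
  bool low_half = (addr & 0x2) == 0;

  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the current bfloat16 bits, add in float, and repack them.
    unsigned short cur = static_cast<unsigned short>(
        low_half ? (assumed & 0xFFFFu) : (assumed >> 16));
    unsigned short next = float_to_bf16_bits(bf16_bits_to_float(cur) + val);
    unsigned int replacement =
        low_half ? ((assumed & 0xFFFF0000u) | next)
                 : ((assumed & 0x0000FFFFu) |
                    (static_cast<unsigned int>(next) << 16));
    // Retry if another thread modified the word between the read and the CAS.
    old = atomicCAS(word, assumed, replacement);
  } while (old != assumed);
}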