diff --git a/cub/agent/agent_scan.cuh b/cub/agent/agent_scan.cuh index 0781b3e9e9..3dec8ef9ad 100644 --- a/cub/agent/agent_scan.cuh +++ b/cub/agent/agent_scan.cuh @@ -100,12 +100,13 @@ struct AgentScan //--------------------------------------------------------------------- // The input value type - typedef typename std::iterator_traits<InputIteratorT>::value_type InputT; + using InputT = typename std::iterator_traits<InputIteratorT>::value_type; - // The output value type - typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? - typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, - typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type + // The output value type -- used as the intermediate accumulator + // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the + // input iterator's value type. + using OutputT = + typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type; // Tile status descriptor interface type typedef ScanTileState<OutputT> ScanTileStateT; diff --git a/cub/device/device_scan.cuh b/cub/device/device_scan.cuh index ae8a5902ce..e0a8e3a4ee 100644 --- a/cub/device/device_scan.cuh +++ b/cub/device/device_scan.cuh @@ -158,10 +158,9 @@ struct DeviceScan // Signed integer type for global offsets typedef int OffsetT; - // The output value type - typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? - typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, - typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... 
else the output iterator's value type + // The output value type -- used as the intermediate accumulator + // Use the input value type per https://wg21.link/P0571 + typedef typename std::iterator_traits<InputIteratorT>::value_type OutputT; // Initial value OutputT init_value = 0; diff --git a/cub/device/dispatch/dispatch_scan.cuh b/cub/device/dispatch/dispatch_scan.cuh index 24b30f102c..e3b3e3341f 100644 --- a/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/device/dispatch/dispatch_scan.cuh @@ -254,9 +254,10 @@ template < typename InitValueT, ///< The init_value element type for ScanOpT (cub::NullType for inclusive scans) typename OffsetT, ///< Signed integer type for global offsets typename SelectedPolicy = DeviceScanPolicy< - typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? - typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, - typename std::iterator_traits<OutputIteratorT>::value_type>::Type> > + // Accumulator type. + typename If<Equals<InitValueT, NullType>::VALUE, + typename std::iterator_traits<InputIteratorT>::value_type, + InitValueT>::Type>> struct DispatchScan: SelectedPolicy { @@ -269,11 +270,14 @@ struct DispatchScan: INIT_KERNEL_THREADS = 128 }; - // The output value type - typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? - typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, - typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type + // The input value type + using InputT = typename std::iterator_traits<InputIteratorT>::value_type; + // The output value type -- used as the intermediate accumulator + // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the + // input iterator's value type. + using OutputT = + typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type; void* d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. 
When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. size_t& temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation diff --git a/test/test_device_scan.cu b/test/test_device_scan.cu index 998f6b12f5..30a7b26f8a 100644 --- a/test/test_device_scan.cu +++ b/test/test_device_scan.cu @@ -544,24 +544,28 @@ void Initialize( template < typename InputIteratorT, typename OutputT, - typename ScanOpT> + typename ScanOpT, + typename InitialValueT> void Solve( InputIteratorT h_in, OutputT *h_reference, int num_items, ScanOpT scan_op, - OutputT initial_value) + InitialValueT initial_value) { + // Use the initial value type for accumulation per P0571 + using AccumT = InitialValueT; + if (num_items > 0) { - OutputT val = h_in[0]; - h_reference[0] = initial_value; - OutputT inclusive = scan_op(initial_value, val); + AccumT val = static_cast<AccumT>(h_in[0]); + h_reference[0] = initial_value; + AccumT inclusive = scan_op(initial_value, val); for (int i = 1; i < num_items; ++i) { - val = h_in[i]; - h_reference[i] = inclusive; + val = static_cast<AccumT>(h_in[i]); + h_reference[i] = static_cast<OutputT>(inclusive); inclusive = scan_op(inclusive, val); } } @@ -582,16 +586,20 @@ void Solve( ScanOpT scan_op, NullType) { + // When no initial value type is supplied, use InputT for accumulation + // per P0571 + using AccumT = typename std::iterator_traits<InputIteratorT>::value_type; + if (num_items > 0) { - OutputT inclusive = h_in[0]; - h_reference[0] = inclusive; + AccumT inclusive = h_in[0]; + h_reference[0] = static_cast<OutputT>(inclusive); for (int i = 1; i < num_items; ++i) { - OutputT val = h_in[i]; + AccumT val = h_in[i]; inclusive = scan_op(inclusive, val); - h_reference[i] = inclusive; + h_reference[i] = static_cast<OutputT>(inclusive); } } } @@ -746,7 +754,22 @@ void TestPointer( // Initialize problem and solution Initialize(gen_mode, h_in, num_items); - Solve(h_in, h_reference, num_items, scan_op, initial_value); + + // If the output type is primitive and 
the operator is cub::Sum, the test + // dispatcher throws away scan_op and initial_value for exclusive scan. + // Without an initial_value arg, the accumulator switches to the input value + // type. + // Do the same thing here: + if (Traits<OutputT>::PRIMITIVE && + Equals<ScanOpT, cub::Sum>::VALUE && + !Equals<InitialValueT, NullType>::VALUE) + { + Solve(h_in, h_reference, num_items, cub::Sum{}, InputT{}); + } + else + { + Solve(h_in, h_reference, num_items, scan_op, initial_value); + } // Allocate problem device arrays InputT *d_in = NULL;