Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update the scan implementation to follow P0571's guidance. #201

Merged
merged 1 commit into from
Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ struct AgentScan
//---------------------------------------------------------------------

// The input value type
typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;
using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The output value type -- used as the intermediate accumulator
// Per https://wg21.link/P0571, use InitValueT if provided, otherwise the
// input iterator's value type.
using OutputT =
typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type;

// Tile status descriptor interface type
typedef ScanTileState<OutputT> ScanTileStateT;
Expand Down
7 changes: 3 additions & 4 deletions cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,9 @@ struct DeviceScan
// Signed integer type for global offsets
typedef int OffsetT;

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The output value type -- used as the intermediate accumulator
// Use the input value type per https://wg21.link/P0571
typedef typename std::iterator_traits<InputIteratorT>::value_type OutputT;

// Initial value
OutputT init_value = 0;
Expand Down
18 changes: 11 additions & 7 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,10 @@ template <
typename InitValueT, ///< The init_value element type for ScanOpT (cub::NullType for inclusive scans)
typename OffsetT, ///< Signed integer type for global offsets
typename SelectedPolicy = DeviceScanPolicy<
typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type> >
// Accumulator type.
typename If<Equals<InitValueT, NullType>::VALUE,
typename std::iterator_traits<InputIteratorT>::value_type,
InitValueT>::Type>>
struct DispatchScan:
SelectedPolicy
{
Expand All @@ -269,11 +270,14 @@ struct DispatchScan:
INIT_KERNEL_THREADS = 128
};

// The output value type
typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
// The input value type
using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

// The output value type -- used as the intermediate accumulator
// Per https://wg21.link/P0571, use InitValueT if provided, otherwise the
// input iterator's value type.
using OutputT =
typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type;

void* d_temp_storage; ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t& temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
Expand Down
47 changes: 35 additions & 12 deletions test/test_device_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -544,24 +544,28 @@ void Initialize(
template <
typename InputIteratorT,
typename OutputT,
typename ScanOpT>
typename ScanOpT,
typename InitialValueT>
void Solve(
InputIteratorT h_in,
OutputT *h_reference,
int num_items,
ScanOpT scan_op,
OutputT initial_value)
InitialValueT initial_value)
{
// Use the initial value type for accumulation per P0571
using AccumT = InitialValueT;

if (num_items > 0)
{
OutputT val = h_in[0];
h_reference[0] = initial_value;
OutputT inclusive = scan_op(initial_value, val);
AccumT val = static_cast<AccumT>(h_in[0]);
h_reference[0] = initial_value;
AccumT inclusive = scan_op(initial_value, val);

for (int i = 1; i < num_items; ++i)
{
val = h_in[i];
h_reference[i] = inclusive;
val = static_cast<AccumT>(h_in[i]);
h_reference[i] = static_cast<OutputT>(inclusive);
inclusive = scan_op(inclusive, val);
}
}
Expand All @@ -582,16 +586,20 @@ void Solve(
ScanOpT scan_op,
NullType)
{
// When no initial value type is supplied, use InputT for accumulation
// per P0571
using AccumT = typename std::iterator_traits<InputIteratorT>::value_type;

if (num_items > 0)
{
OutputT inclusive = h_in[0];
h_reference[0] = inclusive;
AccumT inclusive = h_in[0];
h_reference[0] = static_cast<OutputT>(inclusive);

for (int i = 1; i < num_items; ++i)
{
OutputT val = h_in[i];
AccumT val = h_in[i];
inclusive = scan_op(inclusive, val);
h_reference[i] = inclusive;
h_reference[i] = static_cast<OutputT>(inclusive);
}
}
}
Expand Down Expand Up @@ -746,7 +754,22 @@ void TestPointer(

// Initialize problem and solution
Initialize(gen_mode, h_in, num_items);
Solve(h_in, h_reference, num_items, scan_op, initial_value);

// If the output type is primitive and the operator is cub::Sum, the test
// dispatcher throws away scan_op and initial_value for exclusive scan.
// Without an initial_value arg, the accumulator switches to the input value
// type.
// Do the same thing here:
if (Traits<OutputT>::PRIMITIVE &&
Equals<ScanOpT, cub::Sum>::VALUE &&
!Equals<InitialValueT, NullType>::VALUE)
{
Solve(h_in, h_reference, num_items, cub::Sum{}, InputT{});
}
else
{
Solve(h_in, h_reference, num_items, scan_op, initial_value);
}

// Allocate problem device arrays
InputT *d_in = NULL;
Expand Down