Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Stable sorting order for -0.0 and +0.0 for float and double. #218

Merged
merged 1 commit into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions cub/agent/agent_radix_sort_downsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "../block/block_store.cuh"
#include "../block/block_radix_rank.cuh"
#include "../block/block_exchange.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
Expand Down Expand Up @@ -153,6 +154,10 @@ struct AgentRadixSortDownsweep
>::Type
>::Type BlockRadixRankT;

// Digit extractor type
typedef BFEDigitExtractor<KeyT> DigitExtractorT;


enum
{
/// Number of bin-starting offsets tracked per thread
Expand Down Expand Up @@ -217,11 +222,8 @@ struct AgentRadixSortDownsweep
// The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads)
OffsetT bin_offset[BINS_TRACKED_PER_THREAD];

// The least-significant bit position of the current digit to extract
int current_bit;

// Number of bits in current digit
int num_bits;
// Digit extractor
DigitExtractorT digit_extractor;

// Whether to short-cirucit
int short_circuit;
Expand Down Expand Up @@ -253,7 +255,7 @@ struct AgentRadixSortDownsweep
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
UnsignedBits digit = BFE(key, current_bit, num_bits);
UnsignedBits digit = digit_extractor.Digit(key);
relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit];

// Un-twiddle
Expand Down Expand Up @@ -522,8 +524,7 @@ struct AgentRadixSortDownsweep
BlockRadixRankT(temp_storage.radix_rank).RankKeys(
keys,
ranks,
current_bit,
num_bits,
digit_extractor,
exclusive_digit_prefix);

CTA_SYNC();
Expand Down Expand Up @@ -670,8 +671,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down Expand Up @@ -710,8 +710,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ struct AgentRadixSortHistogram
current_bit < end_bit; current_bit += RADIX_BITS, ++pass)
{
int num_bits = CUB_MIN(RADIX_BITS, end_bit - current_bit);
DigitExtractor<UnsignedBits> digit_extractor(current_bit, num_bits);
ShiftDigitExtractor<KeyT> digit_extractor(current_bit, num_bits);
#pragma unroll
for (int u = 0; u < ITEMS_PER_THREAD; ++u)
{
Expand Down
7 changes: 2 additions & 5 deletions cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ struct AgentRadixSortOnesweep
ValueT* d_values_out;
const ValueT* d_values_in;
OffsetT num_items;
int current_bit, num_bits;
DigitExtractor<UnsignedBits> digit_extractor;
ShiftDigitExtractor<KeyT> digit_extractor;

// other thread variables
int warp;
Expand Down Expand Up @@ -605,7 +604,7 @@ struct AgentRadixSortOnesweep
int exclusive_digit_prefix[BINS_PER_THREAD];
int bins[BINS_PER_THREAD];
BlockRadixRankT(s.rank_temp_storage).RankKeys(
keys, ranks, current_bit, num_bits, exclusive_digit_prefix,
keys, ranks, digit_extractor, exclusive_digit_prefix,
CountsCallback(*this, bins, keys));

// scatter keys in shared memory
Expand Down Expand Up @@ -648,8 +647,6 @@ struct AgentRadixSortOnesweep
, d_values_out(d_values_out)
, d_values_in(d_values_in)
, num_items(num_items)
, current_bit(current_bit)
, num_bits(num_bits)
, digit_extractor(current_bit, num_bits)
, warp(threadIdx.x / WARP_THREADS)
, lane(LaneId())
Expand Down
17 changes: 8 additions & 9 deletions cub/agent/agent_radix_sort_upsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "../thread/thread_load.cuh"
#include "../warp/warp_reduce.cuh"
#include "../block/block_load.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
Expand Down Expand Up @@ -139,6 +140,9 @@ struct AgentRadixSortUpsweep
// Input iterator wrapper type (for applying cache modifier)s
typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT> KeysItr;

// Digit extractor type
typedef BFEDigitExtractor<KeyT> DigitExtractorT;

/**
* Shared memory storage layout
*/
Expand Down Expand Up @@ -167,12 +171,8 @@ struct AgentRadixSortUpsweep
// Input and output device pointers
KeysItr d_keys_in;

// The least-significant bit position of the current digit to extract
int current_bit;

// Number of bits in current digit
int num_bits;

// Digit extractor
DigitExtractorT digit_extractor;


//---------------------------------------------------------------------
Expand Down Expand Up @@ -217,7 +217,7 @@ struct AgentRadixSortUpsweep
UnsignedBits converted_key = Traits<KeyT>::TwiddleIn(key);

// Extract current digit bits
UnsignedBits digit = BFE(converted_key, current_bit, num_bits);
UnsignedBits digit = digit_extractor.Digit(converted_key);

// Get sub-counter offset
UnsignedBits sub_counter = digit & (PACKING_RATIO - 1);
Expand Down Expand Up @@ -342,8 +342,7 @@ struct AgentRadixSortUpsweep
:
temp_storage(temp_storage.Alias()),
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
current_bit(current_bit),
num_bits(num_bits)
digit_extractor(current_bit, num_bits)
{}


Expand Down
86 changes: 43 additions & 43 deletions cub/block/block_radix_rank.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ CUB_NS_PREFIX
/// CUB namespace
namespace cub {


/**
* \brief Radix ranking algorithm, the algorithm used to implement stable ranking of the
* keys from a single tile. Note that different ranking algorithms require different
Expand Down Expand Up @@ -392,12 +393,12 @@ public:
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD>
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile
int current_bit, ///< [in] The least-significant bit position of the current digit to extract
int num_bits) ///< [in] The number of bits in the current digit
DigitExtractorT digit_extractor) ///< [in] The digit extractor
{
DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the same digit
DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, the byte-offset of its corresponding digit counter in smem
Expand All @@ -409,7 +410,7 @@ public:
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
{
// Get digit
unsigned int digit = BFE(keys[ITEM], current_bit, num_bits);
unsigned int digit = digit_extractor.Digit(keys[ITEM]);

// Get sub-counter
unsigned int sub_counter = digit >> LOG_COUNTER_LANES;
Expand Down Expand Up @@ -455,16 +456,16 @@ public:
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD>
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter)
int current_bit, ///< [in] The least-significant bit position of the current digit to extract
int num_bits, ///< [in] The number of bits in the current digit
DigitExtractorT digit_extractor, ///< [in] The digit extractor
int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
{
// Rank keys
RankKeys(keys, ranks, current_bit, num_bits);
RankKeys(keys, ranks, digit_extractor);

// Get the inclusive and exclusive digit totals corresponding to the calling thread.
#pragma unroll
Expand Down Expand Up @@ -662,12 +663,12 @@ public:
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename DigitExtractorT,
typename CountsCallback>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile
int current_bit, ///< [in] The least-significant bit position of the current digit to extract
int num_bits, ///< [in] The number of bits in the current digit
DigitExtractorT digit_extractor, ///< [in] The digit extractor
CountsCallback callback)
{
// Initialize shared digit counters
Expand All @@ -688,7 +689,7 @@ public:
for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
{
// My digit
uint32_t digit = BFE(keys[ITEM], current_bit, num_bits);
uint32_t digit = digit_extractor.Digit(keys[ITEM]);

if (IS_DESCENDING)
digit = RADIX_DIGITS - digit - 1;
Expand Down Expand Up @@ -752,33 +753,34 @@ public:
ranks[ITEM] += *digit_counters[ITEM];
}

template <
template <
typename UnsignedBits,
int KEYS_PER_THREAD>
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
int current_bit, int num_bits)
{
RankKeys(keys, ranks, current_bit, num_bits,
BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
}
DigitExtractorT digit_extractor)
{
RankKeys(keys, ranks, digit_extractor,
BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
}

/**
* \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
*/
template <
typename UnsignedBits,
int KEYS_PER_THREAD,
typename CountsCallback>
typename DigitExtractorT,
typename CountsCallback>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter)
int current_bit, ///< [in] The least-significant bit position of the current digit to extract
int num_bits, ///< [in] The number of bits in the current digit
DigitExtractorT digit_extractor, ///< [in] The digit extractor
int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD], ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
CountsCallback callback)
{
RankKeys(keys, ranks, current_bit, num_bits, callback);
RankKeys(keys, ranks, digit_extractor, callback);

// Get exclusive count for each digit
#pragma unroll
Expand All @@ -798,15 +800,15 @@ public:

template <
typename UnsignedBits,
int KEYS_PER_THREAD>
int KEYS_PER_THREAD,
typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile
int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter)
int current_bit, ///< [in] The least-significant bit position of the current digit to extract
int num_bits, ///< [in] The number of bits in the current digit
DigitExtractorT digit_extractor,
int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
{
RankKeys(keys, ranks, current_bit, num_bits, exclusive_digit_prefix,
RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix,
BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
}
};
Expand Down Expand Up @@ -866,11 +868,12 @@ struct BlockRadixRankMatchEarlyCounts
TempStorage& temp_storage;

// internal ranking implementation
template <typename UnsignedBits, int KEYS_PER_THREAD, typename CountsCallback>
template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT,
typename CountsCallback>
struct BlockRadixRankMatchInternal
{
TempStorage& s;
DigitExtractor<UnsignedBits> digit_extractor;
DigitExtractorT digit_extractor;
CountsCallback callback;
int warp;
int lane;
Expand Down Expand Up @@ -1066,8 +1069,7 @@ struct BlockRadixRankMatchEarlyCounts
}

__device__ __forceinline__ BlockRadixRankMatchInternal
(TempStorage& temp_storage, DigitExtractor<UnsignedBits> digit_extractor,
CountsCallback callback)
(TempStorage& temp_storage, DigitExtractorT digit_extractor, CountsCallback callback)
: s(temp_storage), digit_extractor(digit_extractor),
callback(callback), warp(threadIdx.x / WARP_THREADS), lane(LaneId())
{}
Expand All @@ -1079,44 +1081,42 @@ struct BlockRadixRankMatchEarlyCounts
/**
* \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
*/
template <typename UnsignedBits, int KEYS_PER_THREAD, typename CountsCallback>
template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT,
typename CountsCallback>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD],
int (&ranks)[KEYS_PER_THREAD],
int current_bit, int num_bits,
DigitExtractorT digit_extractor,
int (&exclusive_digit_prefix)[BINS_PER_THREAD],
CountsCallback callback)
CountsCallback callback)
{
DigitExtractor<UnsignedBits> digit_extractor(current_bit, num_bits);
BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, CountsCallback>
BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
internal(temp_storage, digit_extractor, callback);
internal.RankKeys(keys, ranks, exclusive_digit_prefix);
}

template <typename UnsignedBits, int KEYS_PER_THREAD>
template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD],
int (&ranks)[KEYS_PER_THREAD],
int current_bit, int num_bits,
DigitExtractorT digit_extractor,
int (&exclusive_digit_prefix)[BINS_PER_THREAD])
{
DigitExtractor<UnsignedBits> digit_extractor(current_bit, num_bits);
typedef BlockRadixRankEmptyCallback<BINS_PER_THREAD> CountsCallback;
BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, CountsCallback>
BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
internal(temp_storage, digit_extractor, CountsCallback());
internal.RankKeys(keys, ranks, exclusive_digit_prefix);
}

template <typename UnsignedBits, int KEYS_PER_THREAD>
template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
__device__ __forceinline__ void RankKeys(
UnsignedBits (&keys)[KEYS_PER_THREAD],
int (&ranks)[KEYS_PER_THREAD],
int current_bit, int num_bits)
DigitExtractorT digit_extractor)
{
int exclusive_digit_prefix[BINS_PER_THREAD];
RankKeys(keys, ranks, current_bit, num_bits, exclusive_digit_prefix);
RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix);
}

};


Expand Down
Loading