diff --git a/cub/agent/agent_radix_sort_downsweep.cuh b/cub/agent/agent_radix_sort_downsweep.cuh index 6fbb092603..6055bf652a 100644 --- a/cub/agent/agent_radix_sort_downsweep.cuh +++ b/cub/agent/agent_radix_sort_downsweep.cuh @@ -41,6 +41,7 @@ #include "../block/block_store.cuh" #include "../block/block_radix_rank.cuh" #include "../block/block_exchange.cuh" +#include "../block/radix_rank_sort_operations.cuh" #include "../config.cuh" #include "../util_type.cuh" #include "../iterator/cache_modified_input_iterator.cuh" @@ -153,6 +154,10 @@ struct AgentRadixSortDownsweep >::Type >::Type BlockRadixRankT; + // Digit extractor type + typedef BFEDigitExtractor DigitExtractorT; + + enum { /// Number of bin-starting offsets tracked per thread @@ -217,11 +222,8 @@ struct AgentRadixSortDownsweep // The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads) OffsetT bin_offset[BINS_TRACKED_PER_THREAD]; - // The least-significant bit position of the current digit to extract - int current_bit; - - // Number of bits in current digit - int num_bits; + // Digit extractor + DigitExtractorT digit_extractor; // Whether to short-cirucit int short_circuit; @@ -253,7 +255,7 @@ struct AgentRadixSortDownsweep for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)]; - UnsignedBits digit = BFE(key, current_bit, num_bits); + UnsignedBits digit = digit_extractor.Digit(key); relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit]; // Un-twiddle @@ -522,8 +524,7 @@ struct AgentRadixSortDownsweep BlockRadixRankT(temp_storage.radix_rank).RankKeys( keys, ranks, - current_bit, - num_bits, + digit_extractor, exclusive_digit_prefix); CTA_SYNC(); @@ -670,8 +671,7 @@ struct AgentRadixSortDownsweep d_values_in(d_values_in), d_keys_out(reinterpret_cast(d_keys_out)), d_values_out(d_values_out), - current_bit(current_bit), - num_bits(num_bits), + digit_extractor(current_bit, num_bits), short_circuit(1) { #pragma unroll @@ -710,8 +710,7 @@ struct AgentRadixSortDownsweep d_values_in(d_values_in), d_keys_out(reinterpret_cast(d_keys_out)), d_values_out(d_values_out), - current_bit(current_bit), - num_bits(num_bits), + digit_extractor(current_bit, num_bits), short_circuit(1) { #pragma unroll diff --git a/cub/agent/agent_radix_sort_histogram.cuh b/cub/agent/agent_radix_sort_histogram.cuh index 13278a80ef..c8178dd797 100644 --- a/cub/agent/agent_radix_sort_histogram.cuh +++ b/cub/agent/agent_radix_sort_histogram.cuh @@ -187,7 +187,7 @@ struct AgentRadixSortHistogram current_bit < end_bit; current_bit += RADIX_BITS, ++pass) { int num_bits = CUB_MIN(RADIX_BITS, end_bit - current_bit); - DigitExtractor digit_extractor(current_bit, num_bits); + ShiftDigitExtractor digit_extractor(current_bit, num_bits); #pragma unroll for (int u = 0; u < ITEMS_PER_THREAD; ++u) { diff --git a/cub/agent/agent_radix_sort_onesweep.cuh b/cub/agent/agent_radix_sort_onesweep.cuh index 6d43b30626..641f35a708 100644 --- a/cub/agent/agent_radix_sort_onesweep.cuh +++ b/cub/agent/agent_radix_sort_onesweep.cuh @@ -180,8 +180,7 @@ struct AgentRadixSortOnesweep ValueT* d_values_out; const ValueT* d_values_in; OffsetT num_items; - int current_bit, num_bits; - DigitExtractor digit_extractor; + ShiftDigitExtractor digit_extractor; // other thread variables int warp; @@ -605,7 +604,7 @@ struct AgentRadixSortOnesweep int exclusive_digit_prefix[BINS_PER_THREAD]; int bins[BINS_PER_THREAD]; BlockRadixRankT(s.rank_temp_storage).RankKeys( - keys, ranks, current_bit, num_bits, exclusive_digit_prefix, + keys, ranks, digit_extractor, exclusive_digit_prefix, CountsCallback(*this, bins, keys)); // scatter keys in shared memory @@ -648,8 +647,6 @@ struct AgentRadixSortOnesweep , d_values_out(d_values_out) , d_values_in(d_values_in) , num_items(num_items) - , current_bit(current_bit) - , num_bits(num_bits) , digit_extractor(current_bit, num_bits) , warp(threadIdx.x / WARP_THREADS) , lane(LaneId()) diff --git a/cub/agent/agent_radix_sort_upsweep.cuh b/cub/agent/agent_radix_sort_upsweep.cuh index 89060027f5..5865a60a2f 100644 --- a/cub/agent/agent_radix_sort_upsweep.cuh +++ b/cub/agent/agent_radix_sort_upsweep.cuh @@ -37,6 +37,7 @@ #include "../thread/thread_load.cuh" #include "../warp/warp_reduce.cuh" #include "../block/block_load.cuh" +#include "../block/radix_rank_sort_operations.cuh" #include "../config.cuh" #include "../util_type.cuh" #include "../iterator/cache_modified_input_iterator.cuh" @@ -139,6 +140,9 @@ struct AgentRadixSortUpsweep // Input iterator wrapper type (for applying cache modifier)s typedef CacheModifiedInputIterator KeysItr; + // Digit extractor type + typedef BFEDigitExtractor DigitExtractorT; + /** * Shared memory storage layout */ @@ -167,12 +171,8 @@ struct AgentRadixSortUpsweep // Input and output device pointers KeysItr d_keys_in; - // The least-significant bit position of the current digit to extract - int current_bit; - - // Number of bits in current digit - int num_bits; - + // Digit extractor + DigitExtractorT digit_extractor; //--------------------------------------------------------------------- @@ -217,7 +217,7 @@ struct AgentRadixSortUpsweep UnsignedBits converted_key = Traits::TwiddleIn(key); // Extract current digit bits - UnsignedBits digit = BFE(converted_key, current_bit, num_bits); + UnsignedBits digit = digit_extractor.Digit(converted_key); // Get sub-counter offset UnsignedBits sub_counter = digit & (PACKING_RATIO - 1); @@ -342,8 +342,7 @@ struct AgentRadixSortUpsweep : temp_storage(temp_storage.Alias()), d_keys_in(reinterpret_cast(d_keys_in)), - current_bit(current_bit), - num_bits(num_bits) + digit_extractor(current_bit, num_bits) {} diff --git a/cub/block/block_radix_rank.cuh b/cub/block/block_radix_rank.cuh index 9429cdb42d..ffcc7280b0 100644 --- a/cub/block/block_radix_rank.cuh +++ b/cub/block/block_radix_rank.cuh @@ -50,6 +50,7 @@ CUB_NS_PREFIX /// CUB namespace namespace cub { + /** * \brief Radix ranking algorithm, the algorithm used to implement stable ranking of the * keys from a single tile. Note that different ranking algorithms require different @@ -392,12 +393,12 @@ public: */ template < typename UnsignedBits, - int KEYS_PER_THREAD> + int KEYS_PER_THREAD, + typename DigitExtractorT> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile - int current_bit, ///< [in] The least-significant bit position of the current digit to extract - int num_bits) ///< [in] The number of bits in the current digit + DigitExtractorT digit_extractor) ///< [in] The digit extractor { DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the same digit DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, the byte-offset of its corresponding digit counter in smem @@ -409,7 +410,7 @@ public: for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) { // Get digit - unsigned int digit = BFE(keys[ITEM], current_bit, num_bits); + unsigned int digit = digit_extractor.Digit(keys[ITEM]); // Get sub-counter unsigned int sub_counter = digit >> LOG_COUNTER_LANES; @@ -455,16 +456,16 @@ public: */ template < typename UnsignedBits, - int KEYS_PER_THREAD> + int KEYS_PER_THREAD, + typename DigitExtractorT> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) - int current_bit, ///< [in] The least-significant bit position of the current digit to extract - int num_bits, ///< [in] The number of bits in the current digit + DigitExtractorT digit_extractor, ///< [in] The digit extractor int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] { // Rank keys - RankKeys(keys, ranks, current_bit, num_bits); + RankKeys(keys, ranks, digit_extractor); // Get the inclusive and exclusive digit totals corresponding to the calling thread. #pragma unroll @@ -662,12 +663,12 @@ public: template < typename UnsignedBits, int KEYS_PER_THREAD, + typename DigitExtractorT, typename CountsCallback> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile - int current_bit, ///< [in] The least-significant bit position of the current digit to extract - int num_bits, ///< [in] The number of bits in the current digit + DigitExtractorT digit_extractor, ///< [in] The digit extractor CountsCallback callback) { // Initialize shared digit counters @@ -688,7 +689,7 @@ public: for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) { // My digit - uint32_t digit = BFE(keys[ITEM], current_bit, num_bits); + uint32_t digit = digit_extractor.Digit(keys[ITEM]); if (IS_DESCENDING) digit = RADIX_DIGITS - digit - 1; @@ -752,16 +753,17 @@ public: ranks[ITEM] += *digit_counters[ITEM]; } - template < + template < typename UnsignedBits, - int KEYS_PER_THREAD> + int KEYS_PER_THREAD, + typename DigitExtractorT> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], - int current_bit, int num_bits) - { - RankKeys(keys, ranks, current_bit, num_bits, - BlockRadixRankEmptyCallback()); - } + DigitExtractorT digit_extractor) + { + RankKeys(keys, ranks, digit_extractor, + BlockRadixRankEmptyCallback()); + } /** * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. @@ -769,16 +771,16 @@ public: template < typename UnsignedBits, int KEYS_PER_THREAD, - typename CountsCallback> + typename DigitExtractorT, + typename CountsCallback> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) - int current_bit, ///< [in] The least-significant bit position of the current digit to extract - int num_bits, ///< [in] The number of bits in the current digit + DigitExtractorT digit_extractor, ///< [in] The digit extractor int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD], ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] CountsCallback callback) { - RankKeys(keys, ranks, current_bit, num_bits, callback); + RankKeys(keys, ranks, digit_extractor, callback); // Get exclusive count for each digit #pragma unroll @@ -798,15 +800,15 @@ public: template < typename UnsignedBits, - int KEYS_PER_THREAD> + int KEYS_PER_THREAD, + typename DigitExtractorT> __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) - int current_bit, ///< [in] The least-significant bit position of the current digit to extract - int num_bits, ///< [in] The number of bits in the current digit + DigitExtractorT digit_extractor, int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] { - RankKeys(keys, ranks, current_bit, num_bits, exclusive_digit_prefix, + RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix, BlockRadixRankEmptyCallback()); } }; @@ -866,11 +868,12 @@ struct BlockRadixRankMatchEarlyCounts TempStorage& temp_storage; // internal ranking implementation - template + template struct BlockRadixRankMatchInternal { TempStorage& s; - DigitExtractor digit_extractor; + DigitExtractorT digit_extractor; CountsCallback callback; int warp; int lane; @@ -1066,8 +1069,7 @@ struct BlockRadixRankMatchEarlyCounts } __device__ __forceinline__ BlockRadixRankMatchInternal - (TempStorage& temp_storage, DigitExtractor digit_extractor, - CountsCallback callback) + (TempStorage& temp_storage, DigitExtractorT digit_extractor, CountsCallback callback) : s(temp_storage), digit_extractor(digit_extractor), callback(callback), warp(threadIdx.x / WARP_THREADS), lane(LaneId()) {} @@ -1079,44 +1081,42 @@ struct BlockRadixRankMatchEarlyCounts /** * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. */ - template + template __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], - int current_bit, int num_bits, + DigitExtractorT digit_extractor, int (&exclusive_digit_prefix)[BINS_PER_THREAD], - CountsCallback callback) + CountsCallback callback) { - DigitExtractor digit_extractor(current_bit, num_bits); - BlockRadixRankMatchInternal + BlockRadixRankMatchInternal internal(temp_storage, digit_extractor, callback); internal.RankKeys(keys, ranks, exclusive_digit_prefix); } - template + template __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], - int current_bit, int num_bits, + DigitExtractorT digit_extractor, int (&exclusive_digit_prefix)[BINS_PER_THREAD]) { - DigitExtractor digit_extractor(current_bit, num_bits); typedef BlockRadixRankEmptyCallback CountsCallback; - BlockRadixRankMatchInternal + BlockRadixRankMatchInternal internal(temp_storage, digit_extractor, CountsCallback()); internal.RankKeys(keys, ranks, exclusive_digit_prefix); } - template + template __device__ __forceinline__ void RankKeys( UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD], - int current_bit, int num_bits) + DigitExtractorT digit_extractor) { int exclusive_digit_prefix[BINS_PER_THREAD]; - RankKeys(keys, ranks, current_bit, num_bits, exclusive_digit_prefix); + RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix); } - }; diff --git a/cub/block/block_radix_sort.cuh b/cub/block/block_radix_sort.cuh index e666902156..b27148b85c 100644 --- a/cub/block/block_radix_sort.cuh +++ b/cub/block/block_radix_sort.cuh @@ -36,6 +36,7 @@ #include "block_exchange.cuh" #include "block_radix_rank.cuh" +#include "radix_rank_sort_operations.cuh" #include "../config.cuh" #include "../util_ptx.cuh" #include "../util_type.cuh" @@ -76,7 +77,9 @@ namespace cub { * bit-sequences of \p RADIX_BITS as radix digit places. Although the direct radix sorting * method can only be applied to unsigned integral types, BlockRadixSort * is able to sort signed and floating-point types via simple bit-wise transformations - * that ensure lexicographic key ordering. + * that ensure lexicographic key ordering. For floating-point types -0.0 and +0.0 are + * considered equal and appear in the result in the same order as they appear in + * the input. * - \rowmajor * * \par Performance Considerations @@ -175,6 +178,9 @@ private: PTX_ARCH> DescendingBlockRadixRank; + /// Digit extractor type + typedef BFEDigitExtractor DigitExtractorT; + /// BlockExchange utility type for keys typedef BlockExchange BlockExchangeKeys; @@ -216,30 +222,26 @@ private: __device__ __forceinline__ void RankKeys( UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], int (&ranks)[ITEMS_PER_THREAD], - int begin_bit, - int pass_bits, + DigitExtractorT digit_extractor, Int2Type /*is_descending*/) { AscendingBlockRadixRank(temp_storage.asending_ranking_storage).RankKeys( - unsigned_keys, - ranks, - begin_bit, - pass_bits); + unsigned_keys, + ranks, + digit_extractor); } /// Rank keys (specialized for descending sort) __device__ __forceinline__ void RankKeys( UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], int (&ranks)[ITEMS_PER_THREAD], - int begin_bit, - int pass_bits, + DigitExtractorT digit_extractor, Int2Type /*is_descending*/) { DescendingBlockRadixRank(temp_storage.descending_ranking_storage).RankKeys( - unsigned_keys, - ranks, - begin_bit, - pass_bits); + unsigned_keys, + ranks, + digit_extractor); } /// ExchangeValues (specialized for key-value sort, to-blocked arrangement) @@ -301,10 +303,11 @@ private: while (true) { int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit); + DigitExtractorT digit_extractor(begin_bit, pass_bits); // Rank the blocked keys int ranks[ITEMS_PER_THREAD]; - RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending); + RankKeys(unsigned_keys, ranks, digit_extractor, is_descending); begin_bit += RADIX_BITS; CTA_SYNC(); @@ -357,10 +360,11 @@ public: while (true) { int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit); + DigitExtractorT digit_extractor(begin_bit, pass_bits); // Rank the blocked keys int ranks[ITEMS_PER_THREAD]; - RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending); + RankKeys(unsigned_keys, ranks, digit_extractor, is_descending); begin_bit += RADIX_BITS; CTA_SYNC(); diff --git a/cub/block/radix_rank_sort_operations.cuh b/cub/block/radix_rank_sort_operations.cuh index 7b3a05002a..17a399efb6 100644 --- a/cub/block/radix_rank_sort_operations.cuh +++ b/cub/block/radix_rank_sort_operations.cuh @@ -34,6 +34,7 @@ #pragma once #include "../config.cuh" +#include "../util_ptx.cuh" #include "../util_type.cuh" @@ -67,20 +68,79 @@ struct RadixSortTwiddle } }; +/** \brief Base struct for digit extractor. Contains common code to provide + special handling for floating-point -0.0. -/** \brief Stateful abstraction to extract digits. */ -template -struct DigitExtractor + \note This handles correctly both the case when the keys are + bitwise-complemented after twiddling for descending sort (in onesweep) as + well as when the keys are not bit-negated, but the implementation handles + descending sort separately (in other implementations in CUB). Twiddling + alone maps -0.0f to 0x7fffffff and +0.0f to 0x80000000 for float, which are + subsequent bit patterns and bitwise complements of each other. For onesweep, + both -0.0f and +0.0f are mapped to the bit pattern of +0.0f (0x80000000) for + ascending sort, and to the pattern of -0.0f (0x7fffffff) for descending + sort. For all other sorting implementations in CUB, both are always mapped + to +0.0f. Since bit patterns for both -0.0f and +0.0f are next to each other + and only one of them is used, the sorting works correctly. For double, the + same applies, but with 64-bit patterns. +*/ +template +struct BaseDigitExtractor { - int current_bit, mask; - __device__ __forceinline__ DigitExtractor() : current_bit(0), mask(0) {} - __device__ __forceinline__ DigitExtractor(int current_bit, int num_bits) - : current_bit(current_bit), mask((1 << num_bits) - 1) + typedef Traits TraitsT; + typedef typename TraitsT::UnsignedBits UnsignedBits; + + enum + { + FLOAT_KEY = TraitsT::CATEGORY == FLOATING_POINT, + }; + + static __device__ __forceinline__ UnsignedBits ProcessFloatMinusZero(UnsignedBits key) + { + if (!FLOAT_KEY) return key; + + UnsignedBits TWIDDLED_MINUS_ZERO_BITS = + TraitsT::TwiddleIn(UnsignedBits(1) << UnsignedBits(8 * sizeof(UnsignedBits) - 1)); + UnsignedBits TWIDDLED_ZERO_BITS = TraitsT::TwiddleIn(0); + return key == TWIDDLED_MINUS_ZERO_BITS ? TWIDDLED_ZERO_BITS : key; + } +}; + +/** \brief A wrapper type to extract digits. Uses the BFE intrinsic to extract a + * key from a digit. */ +template +struct BFEDigitExtractor : BaseDigitExtractor +{ + using typename BaseDigitExtractor::UnsignedBits; + + uint32_t bit_start, num_bits; + explicit __device__ __forceinline__ BFEDigitExtractor( + uint32_t bit_start = 0, uint32_t num_bits = 0) + : bit_start(bit_start), num_bits(num_bits) + { } + + __device__ __forceinline__ uint32_t Digit(UnsignedBits key) + { + return BFE(ProcessFloatMinusZero(key), bit_start, num_bits); + } +}; + +/** \brief A wrapper type to extract digits. Uses a combination of shift and + * bitwise and to extract digits. */ +template +struct ShiftDigitExtractor : BaseDigitExtractor +{ + using typename BaseDigitExtractor::UnsignedBits; + + uint32_t bit_start, mask; + explicit __device__ __forceinline__ ShiftDigitExtractor( + uint32_t bit_start = 0, uint32_t num_bits = 0) + : bit_start(bit_start), mask((1 << num_bits) - 1) { } - __device__ __forceinline__ int Digit(UnsignedBits key) + __device__ __forceinline__ uint32_t Digit(UnsignedBits key) { - return int(key >> UnsignedBits(current_bit)) & mask; + return uint32_t(ProcessFloatMinusZero(key) >> UnsignedBits(bit_start)) & mask; } }; diff --git a/cub/device/device_radix_sort.cuh b/cub/device/device_radix_sort.cuh index df218a7c35..7534c508cc 100644 --- a/cub/device/device_radix_sort.cuh +++ b/cub/device/device_radix_sort.cuh @@ -66,7 +66,9 @@ namespace cub { * half-precision floating-point type. Although the direct radix sorting * method can only be applied to unsigned integral types, DeviceRadixSort * is able to sort signed and floating-point types via simple bit-wise transformations - * that ensure lexicographic key ordering. + * that ensure lexicographic key ordering. For floating-point types -0.0 and +0.0 are + * considered equal and appear in the result in the same order as they appear in + * the input. * * \par Usage Considerations * \cdp_class{DeviceRadixSort} diff --git a/cub/device/device_segmented_radix_sort.cuh b/cub/device/device_segmented_radix_sort.cuh index 2ab2a7dde2..b03a2bafb6 100644 --- a/cub/device/device_segmented_radix_sort.cuh +++ b/cub/device/device_segmented_radix_sort.cuh @@ -66,7 +66,9 @@ namespace cub { * half-precision floating-point type. Although the direct radix sorting * method can only be applied to unsigned integral types, DeviceSegmentedRadixSort * is able to sort signed and floating-point types via simple bit-wise transformations - * that ensure lexicographic key ordering. + * that ensure lexicographic key ordering. For floating-point types -0.0 and +0.0 are + * considered equal and appear in the result in the same order as they appear in + * the input. * * \par Usage Considerations * \cdp_class{DeviceSegmentedRadixSort} diff --git a/test/test_block_radix_sort.cu b/test/test_block_radix_sort.cu index 6929dcdf5b..ae23956eec 100644 --- a/test/test_block_radix_sort.cu +++ b/test/test_block_radix_sort.cu @@ -209,8 +209,7 @@ __global__ void Kernel( */ template < typename Key, - typename Value, - bool IS_FLOAT = (Traits::CATEGORY == FLOATING_POINT)> + typename Value> struct Pair { Key key; @@ -222,35 +221,6 @@ struct Pair } }; -/** - * Simple key-value pairing (specialized for floating point types) - */ -template -struct Pair -{ - Key key; - Value value; - - bool operator<(const Pair &b) const - { - if (key < b.key) - return true; - - if (key > b.key) - return false; - - // Key in unsigned bits - typedef typename Traits::UnsignedBits UnsignedBits; - - // Return true if key is negative zero and b.key is positive zero - UnsignedBits key_bits = SafeBitCast(key); - UnsignedBits b_key_bits = SafeBitCast(b.key); - UnsignedBits HIGH_BIT = Traits::HIGH_BIT; - - return ((key_bits & HIGH_BIT) != 0) && ((b_key_bits & HIGH_BIT) == 0); - } -}; - /** * Initialize key-value sorting problem. @@ -468,6 +438,14 @@ void TestValid(Int2Type /*fits_smem_capacity*/) TestDriver( RANDOM, entropy_reduction, begin_bit, end_bit); } + + // For floating-point keys, test random keys mixed with -0.0 and +0.0 + if (cub::Traits::CATEGORY == cub::FLOATING_POINT) + { + TestDriver( + RANDOM_MINUS_PLUS_ZERO, 0, begin_bit, end_bit); + } + } } } diff --git a/test/test_device_radix_sort.cu b/test/test_device_radix_sort.cu index 45324aa077..29c4515bf4 100644 --- a/test/test_device_radix_sort.cu +++ b/test/test_device_radix_sort.cu @@ -571,8 +571,7 @@ cudaError_t Dispatch( */ template < typename KeyT, - typename ValueT, - bool IS_FLOAT = (Traits::CATEGORY == FLOATING_POINT)> + typename ValueT> struct Pair { KeyT key; @@ -589,7 +588,7 @@ struct Pair * Simple key-value pairing (specialized for bool types) */ template -struct Pair +struct Pair { bool key; ValueT value; @@ -601,36 +600,6 @@ struct Pair }; -/** - * Simple key-value pairing (specialized for floating point types) - */ -template -struct Pair -{ - KeyT key; - ValueT value; - - bool operator<(const Pair &b) const - { - if (key < b.key) - return true; - - if (key > b.key) - return false; - - // KeyT in unsigned bits - typedef typename Traits::UnsignedBits UnsignedBits; - - // Return true if key is negative zero and b.key is positive zero - UnsignedBits key_bits = SafeBitCast(key); - UnsignedBits b_key_bits = SafeBitCast(b.key); - UnsignedBits HIGH_BIT = Traits::HIGH_BIT; - - return ((key_bits & HIGH_BIT) != 0) && ((b_key_bits & HIGH_BIT) == 0); - } -}; - - /** * Initialize key data */ @@ -1098,6 +1067,14 @@ void TestGen( TestSizes(h_keys, max_items, max_segments); } + if (cub::Traits::CATEGORY == cub::FLOATING_POINT) + { + printf("\nTesting random %s keys with some replaced with -0.0 or +0.0 \n", typeid(KeyT).name()); + fflush(stdout); + InitializeKeyBits(RANDOM_MINUS_PLUS_ZERO, h_keys, max_items, 0); + TestSizes(h_keys, max_items, max_segments); + } + printf("\nTesting uniform %s keys\n", typeid(KeyT).name()); fflush(stdout); InitializeKeyBits(UNIFORM, h_keys, max_items, 0); TestSizes(h_keys, max_items, max_segments); diff --git a/test/test_util.h b/test/test_util.h index eb0cbf705e..d51bfe4fc2 100644 --- a/test/test_util.h +++ b/test/test_util.h @@ -521,10 +521,11 @@ int CoutCast(signed char val) { return val; } */ enum GenMode { - UNIFORM, // Assign to '2', regardless of integer seed - INTEGER_SEED, // Assign to integer seed - RANDOM, // Assign to random, regardless of integer seed - RANDOM_BIT, // Assign to randomly chosen 0 or 1, regardless of integer seed + UNIFORM, // Assign to '2', regardless of integer seed + INTEGER_SEED, // Assign to integer seed + RANDOM, // Assign to random, regardless of integer seed + RANDOM_BIT, // Assign to randomly chosen 0 or 1, regardless of integer seed + RANDOM_MINUS_PLUS_ZERO, // Assign to random, with some values being -0.0 or +0.0 patterns }; /** @@ -540,10 +541,36 @@ __host__ __device__ __forceinline__ void InitValue(GenMode gen_mode, T &value, i RandomBits(value); break; case RANDOM_BIT: + { char c; RandomBits(c, 0, 0, 1); value = (c > 0) ? (T) 1 : (T) -1; break; + } + case RANDOM_MINUS_PLUS_ZERO: + { + // Replace roughly 1/128 of values with -0.0 or +0.0, and generate the rest randomly + typedef typename cub::Traits::UnsignedBits UnsignedBits; + char c; + RandomBits(c); + if (c == 0) + { + // Replace 1/256 of values with +0.0 bit pattern + value = SafeBitCast(UnsignedBits(0)); + } + else if (c == 1) + { + // Replace 1/256 of values with -0.0 bit pattern + value = SafeBitCast(UnsignedBits(UnsignedBits(1) << + (sizeof(UnsignedBits) * 8) - 1)); + } + else + { + // 127/128 of values are random + RandomBits(value); + } + break; + } #endif case UNIFORM: value = 2;