This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix RLE when items[0] is NaN #598

Merged · 4 commits · Nov 30, 2022
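
Context for the change (a summary, not the PR's own description): `AgentReduceByKey::ConsumeTile` flags segment heads by comparing each key with its predecessor, and for the very first tile the first key is compared against itself, so its flag is expected to come out false. IEEE-754 NaN breaks that expectation: NaN compares unequal to everything, including itself, so a NaN in `items[0]` was flagged as the head of a spurious extra run. A minimal host-only illustration of the comparison semantics (file name and scaffolding are hypothetical, not from the PR):

```cpp
// nan_self_compare.cpp -- host-only illustration of why (key[0] == key[0])
// can be false.
#include <cmath>
#include <cstdio>

int main()
{
  const float k = NAN;
  // IEEE-754: NaN compares unequal to everything, including itself, so an
  // inequality-based head flag marks a NaN first key as the start of a
  // "new" run even though it is already the first item of the input.
  std::printf("k != k: %d\n", k != k ? 1 : 0); // prints 1
  return 0;
}
```

The fix in the diff below forces `head_flags[0]` to zero for the first item of the first tile, so the first key is never treated as the start of a new run (it already starts the first one).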
141 changes: 58 additions & 83 deletions cub/agent/agent_reduce_by_key.cuh
@@ -13,9 +13,9 @@
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
@@ -33,8 +33,6 @@

#pragma once

#include <iterator>

#include <cub/agent/single_pass_scan_operators.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
@@ -44,6 +42,8 @@
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/iterator/constant_input_iterator.cuh>

#include <iterator>

CUB_NAMESPACE_BEGIN

/******************************************************************************
@@ -146,8 +146,7 @@ struct AgentReduceByKey
using KeyInputT = cub::detail::value_t<KeysInputIteratorT>;

// The output keys type
using KeyOutputT =
cub::detail::non_void_value_t<UniqueOutputIteratorT, KeyInputT>;
using KeyOutputT = cub::detail::non_void_value_t<UniqueOutputIteratorT, KeyInputT>;

// The input values type
using ValueInputT = cub::detail::value_t<ValuesInputIteratorT>;
@@ -173,17 +172,14 @@
int num_remaining;

/// Constructor
__host__ __device__ __forceinline__
GuardedInequalityWrapper(_EqualityOpT op, int num_remaining)
__host__ __device__ __forceinline__ GuardedInequalityWrapper(_EqualityOpT op, int num_remaining)
: op(op)
, num_remaining(num_remaining)
{}

/// Boolean inequality operator, returns <tt>(a != b)</tt>
template <typename T>
__host__ __device__ __forceinline__ bool operator()(const T &a,
const T &b,
int idx) const
__host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const
{
if (idx < num_remaining)
{
@@ -196,27 +192,23 @@
};

// Constants
static constexpr int BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD =
AgentReduceByKeyPolicyT::ITEMS_PER_THREAD;
static constexpr int BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD;
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
static constexpr int TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1);

// Whether or not the scan operation has a zero-valued identity value (true
// if we're performing addition on a primitive type)
static constexpr int HAS_IDENTITY_ZERO =
(std::is_same<ReductionOpT, cub::Sum>::value) &&
(Traits<AccumT>::PRIMITIVE);
static constexpr int HAS_IDENTITY_ZERO = (std::is_same<ReductionOpT, cub::Sum>::value) &&
(Traits<AccumT>::PRIMITIVE);

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for keys Wrap the native input pointer with
// CacheModifiedValuesInputIterator or directly use the supplied input
// iterator type
using WrappedKeysInputIteratorT = cub::detail::conditional_t<
std::is_pointer<KeysInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER,
KeyInputT,
OffsetT>,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
KeysInputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier)
@@ -225,9 +217,7 @@
// iterator type
using WrappedValuesInputIteratorT = cub::detail::conditional_t<
std::is_pointer<ValuesInputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER,
ValueInputT,
OffsetT>,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
ValuesInputIteratorT>;

// Cache-modified Input iterator wrapper type (for applying cache modifier)
@@ -236,33 +226,26 @@
// iterator type
using WrappedFixupInputIteratorT = cub::detail::conditional_t<
std::is_pointer<AggregatesOutputIteratorT>::value,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER,
ValueInputT,
OffsetT>,
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
AggregatesOutputIteratorT>;

// Reduce-value-by-segment scan operator
using ReduceBySegmentOpT = ReduceBySegmentOp<ReductionOpT>;

// Parameterized BlockLoad type for keys
using BlockLoadKeysT = BlockLoad<KeyOutputT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;
using BlockLoadKeysT =
BlockLoad<KeyOutputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockLoad type for values
using BlockLoadValuesT = BlockLoad<AccumT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;
using BlockLoadValuesT =
BlockLoad<AccumT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockDiscontinuity type for keys
using BlockDiscontinuityKeys = BlockDiscontinuity<KeyOutputT, BLOCK_THREADS>;

// Parameterized BlockScan type
using BlockScanT = BlockScan<OffsetValuePairT,
BLOCK_THREADS,
AgentReduceByKeyPolicyT::SCAN_ALGORITHM>;
using BlockScanT =
BlockScan<OffsetValuePairT, BLOCK_THREADS, AgentReduceByKeyPolicyT::SCAN_ALGORITHM>;

// Callback type for obtaining tile prefix during block scan
using TilePrefixCallbackOpT =
@@ -362,15 +345,14 @@ struct AgentReduceByKey
* @param reduction_op
* ValueT reduction operator
*/
__device__ __forceinline__
AgentReduceByKey(TempStorage &temp_storage,
KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out,
EqualityOpT equality_op,
ReductionOpT reduction_op)
__device__ __forceinline__ AgentReduceByKey(TempStorage &temp_storage,
KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
NumRunsOutputIteratorT d_num_runs_out,
EqualityOpT equality_op,
ReductionOpT reduction_op)
: temp_storage(temp_storage.Alias())
, d_keys_in(d_keys_in)
, d_unique_out(d_unique_out)
@@ -389,10 +371,9 @@
/**
* Directly scatter flagged items to output offsets
*/
__device__ __forceinline__ void
ScatterDirect(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD])
__device__ __forceinline__ void ScatterDirect(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD])
{
// Scatter flagged keys and values
#pragma unroll
@@ -413,12 +394,11 @@
* value aggregate: the scatter offsets must be decremented for value
* aggregates
*/
__device__ __forceinline__ void
ScatterTwoPhase(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD],
OffsetT num_tile_segments,
OffsetT num_tile_segments_prefix)
__device__ __forceinline__ void ScatterTwoPhase(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD],
OffsetT num_tile_segments,
OffsetT num_tile_segments_prefix)
{
CTA_SYNC();

@@ -428,18 +408,16 @@
{
if (segment_flags[ITEM])
{
temp_storage.raw_exchange
.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] =
temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] =
scatter_items[ITEM];
}
}

CTA_SYNC();

for (int item = threadIdx.x; item < num_tile_segments;
item += BLOCK_THREADS)
for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS)
{
KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item];
KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item];
d_unique_out[num_tile_segments_prefix + item] = pair.key;
d_aggregates_out[num_tile_segments_prefix + item] = pair.value;
}
@@ -448,12 +426,11 @@
/**
* Scatter flagged items
*/
__device__ __forceinline__ void
Scatter(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD],
OffsetT num_tile_segments,
OffsetT num_tile_segments_prefix)
__device__ __forceinline__ void Scatter(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
OffsetT (&segment_flags)[ITEMS_PER_THREAD],
OffsetT (&segment_indices)[ITEMS_PER_THREAD],
OffsetT num_tile_segments,
OffsetT num_tile_segments_prefix)
{
// Do a one-phase scatter if (a) two-phase is disabled or (b) the average
// number of selected items per thread is less than one
@@ -494,10 +471,8 @@
* Global tile state descriptor
*/
template <bool IS_LAST_TILE>
__device__ __forceinline__ void ConsumeTile(OffsetT num_remaining,
int tile_idx,
OffsetT tile_offset,
ScanTileStateT &tile_state)
__device__ __forceinline__ void
ConsumeTile(OffsetT num_remaining, int tile_idx, OffsetT tile_offset, ScanTileStateT &tile_state)
{
// Tile keys
KeyOutputT keys[ITEMS_PER_THREAD];
@@ -523,8 +498,7 @@
// Load keys
if (IS_LAST_TILE)
{
BlockLoadKeysT(temp_storage.load_keys)
.Load(d_keys_in + tile_offset, keys, num_remaining);
BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining);
}
else
{
@@ -553,8 +527,7 @@
}
else
{
BlockLoadValuesT(temp_storage.load_values)
.Load(d_values_in + tile_offset, values);
BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values);
}

CTA_SYNC();
@@ -575,7 +548,14 @@
.FlagHeads(head_flags, keys, prev_keys, flag_op, tile_predecessor);
}

// Zip values and head flags
// Reset head-flag on the very first item to make sure we don't start a new run for data where
// (key[0] == key[0]) is false (e.g., when key[0] is NaN)
if (threadIdx.x == 0 && tile_idx == 0)
{
head_flags[0] = 0;
}

// Zip values and head flags
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
@@ -638,11 +618,7 @@

// Scatter flagged keys and values
OffsetT num_tile_segments = block_aggregate.key;
Scatter(scatter_items,
head_flags,
segment_indices,
num_tile_segments,
num_segments_prefix);
Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix);

// Last thread in last tile will output final count (and last pair, if
// necessary)
@@ -705,4 +681,3 @@
};

CUB_NAMESPACE_END
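
For reference, a minimal repro sketch (assumes CUDA and CUB are available; file name, input values, and build line are illustrative, not from the PR). Run-length encoding compares adjacent items for equality, so each NaN forms its own run; before this fix, the very first NaN could additionally be flagged as a run head, yielding a bogus extra run:

```cpp
// rle_nan_repro.cu -- illustrative repro sketch, not part of the PR.
// Build (assumed): nvcc -O2 rle_nan_repro.cu -o rle_nan_repro
#include <cub/device/device_run_length_encode.cuh>

#include <cuda_runtime.h>

#include <cmath>
#include <cstdio>

int main()
{
  constexpr int num_items = 4;
  const float h_in[num_items] = {NAN, NAN, 1.0f, 1.0f};

  float *d_in = nullptr;
  float *d_unique = nullptr;
  int *d_counts = nullptr;
  int *d_num_runs = nullptr;
  cudaMalloc(&d_in, num_items * sizeof(float));
  cudaMalloc(&d_unique, num_items * sizeof(float));
  cudaMalloc(&d_counts, num_items * sizeof(int));
  cudaMalloc(&d_num_runs, sizeof(int));
  cudaMemcpy(d_in, h_in, num_items * sizeof(float), cudaMemcpyHostToDevice);

  // First call sizes the temporary storage; second call runs the encode.
  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(d_temp, temp_bytes, d_in, d_unique,
                                     d_counts, d_num_runs, num_items);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceRunLengthEncode::Encode(d_temp, temp_bytes, d_in, d_unique,
                                     d_counts, d_num_runs, num_items);

  int h_num_runs = 0;
  cudaMemcpy(&h_num_runs, d_num_runs, sizeof(int), cudaMemcpyDeviceToHost);

  // NaN != NaN, so {NaN, NaN, 1, 1} should encode as three runs:
  // {NaN}, {NaN}, {1, 1}.
  std::printf("num_runs = %d (expected 3)\n", h_num_runs);

  cudaFree(d_temp);
  cudaFree(d_num_runs);
  cudaFree(d_counts);
  cudaFree(d_unique);
  cudaFree(d_in);
  return 0;
}
```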
