Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Remove some problematic inline annotations. #259

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ __global__ void EmptyKernel(void) { }
/**
* \brief Returns the current device or -1 if an error occurred.
*/
CUB_RUNTIME_FUNCTION __forceinline__ int CurrentDevice()
CUB_RUNTIME_FUNCTION int CurrentDevice()
{
#if defined(CUB_RUNTIME_ENABLED) // Host code or device code with the CUDA runtime.

Expand All @@ -147,14 +147,14 @@ private:
int const old_device;
bool const needs_reset;
public:
/**
* \brief Host-only constructor: remembers the device reported by
* CurrentDevice() and makes \p new_device current if it differs.
*
* \param new_device Device passed to cudaSetDevice() when a switch is needed.
*/
__host__ __forceinline__ SwitchDevice(int new_device)
__host__ SwitchDevice(int new_device)
// needs_reset reads old_device, which is declared (and therefore initialized) first.
: old_device(CurrentDevice()), needs_reset(old_device != new_device)
{
// Skip the runtime call when the requested device is already current.
// The value produced by CubDebug(...) is not inspected here.
// NOTE(review): CurrentDevice() is documented to return -1 on error; in that
// case needs_reset is true for any valid new_device -- confirm this is intended.
if (needs_reset)
CubDebug(cudaSetDevice(new_device));
}

__host__ __forceinline__ ~SwitchDevice()
__host__ ~SwitchDevice()
{
if (needs_reset)
CubDebug(cudaSetDevice(old_device));
Expand All @@ -165,7 +165,7 @@ public:
* \brief Returns the number of CUDA devices available or -1 if an error
* occurred.
*/
CUB_RUNTIME_FUNCTION __forceinline__ int DeviceCountUncached()
CUB_RUNTIME_FUNCTION int DeviceCountUncached()
{
#if defined(CUB_RUNTIME_ENABLED) // Host code or device code with the CUDA runtime.

Expand Down Expand Up @@ -198,7 +198,7 @@ struct ValueCache
* \brief Call the nullary function to produce the value and construct the
* cache.
*/
// Host-only default constructor: `value` is initialized from the result of
// calling `Function` with no arguments.
__host__ __forceinline__ ValueCache() : value(Function()) {}
__host__ ValueCache() : value(Function()) {}
};

#endif
Expand All @@ -207,7 +207,7 @@ struct ValueCache
// Host code, only safely usable in C++11 or newer, where thread-safe
// initialization of static locals is guaranteed. This is a separate function
// to avoid defining a local static in a host/device function.
__host__ __forceinline__ int DeviceCountCachedValue()
__host__ int DeviceCountCachedValue()
{
static ValueCache<int, DeviceCountUncached> cache;
return cache.value;
Expand All @@ -221,7 +221,7 @@ __host__ __forceinline__ int DeviceCountCachedValue()
*
* \note This function is thread safe.
*/
CUB_RUNTIME_FUNCTION __forceinline__ int DeviceCount()
CUB_RUNTIME_FUNCTION int DeviceCount()
{
int result = -1;
if (CUB_IS_HOST_CODE) {
Expand Down Expand Up @@ -281,7 +281,7 @@ public:
/**
* \brief Construct the cache.
*/
// Host-only constructor. `entries_()` value-initializes the entries_ member.
__host__ __forceinline__ PerDeviceAttributeCache() : entries_()
__host__ PerDeviceAttributeCache() : entries_()
{
// Checks that the number of devices reported by DeviceCount() does not exceed
// CUB_MAX_DEVICES.
// NOTE(review): if this is the standard assert macro, the check is compiled
// out when NDEBUG is defined -- confirm that is acceptable.
assert(DeviceCount() <= CUB_MAX_DEVICES);
}
Expand Down Expand Up @@ -359,7 +359,7 @@ public:
/**
* \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10).
*/
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersionUncached(int& ptx_version)
CUB_RUNTIME_FUNCTION cudaError_t PtxVersionUncached(int& ptx_version)
{
// Instantiate `EmptyKernel<void>` in both host and device code to ensure
// it can be called.
Expand Down Expand Up @@ -399,15 +399,15 @@ CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersionUncached(int& ptx_ver
/**
* \brief Retrieves the PTX version that will be used on \p device (major * 100 + minor * 10).
*/
// Host-only overload taking an explicit device.
//
// ptx_version: forwarded by reference to the single-argument overload.
// device:      device made current for the duration of the call.
// Returns whatever the single-argument PtxVersionUncached returns.
__host__ __forceinline__ cudaError_t PtxVersionUncached(int& ptx_version, int device)
__host__ cudaError_t PtxVersionUncached(int& ptx_version, int device)
{
// Scoped device switch: the SwitchDevice constructor selects `device` if it is
// not already current; the object is destroyed when this function returns.
SwitchDevice sd(device);
// Query the PTX version for whichever device is current at this point.
return PtxVersionUncached(ptx_version);
}

#if CUB_CPP_DIALECT >= 2011 // C++11 and later.
template <typename Tag>
__host__ __forceinline__ PerDeviceAttributeCache& GetPerDeviceAttributeCache()
__host__ PerDeviceAttributeCache& GetPerDeviceAttributeCache()
{
// C++11 guarantees that initialization of static locals is thread safe.
static PerDeviceAttributeCache cache;
Expand All @@ -425,7 +425,7 @@ struct SmVersionCacheTag {};
*
* \note This function is thread safe.
*/
__host__ __forceinline__ cudaError_t PtxVersion(int& ptx_version, int device)
__host__ cudaError_t PtxVersion(int& ptx_version, int device)
{
#if CUB_CPP_DIALECT >= 2011 // C++11 and later.

Expand Down Expand Up @@ -454,7 +454,7 @@ __host__ __forceinline__ cudaError_t PtxVersion(int& ptx_version, int device)
*
* \note This function is thread safe.
*/
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersion(int& ptx_version)
CUB_RUNTIME_FUNCTION cudaError_t PtxVersion(int& ptx_version)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_HOST_CODE) {
Expand Down Expand Up @@ -490,7 +490,7 @@ CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t PtxVersion(int& ptx_version)
/**
* \brief Retrieves the SM version of \p device (major * 100 + minor * 10).
*/
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SmVersionUncached(int& sm_version, int device = CurrentDevice())
CUB_RUNTIME_FUNCTION cudaError_t SmVersionUncached(int& sm_version, int device = CurrentDevice())
{
#if defined(CUB_RUNTIME_ENABLED) // Host code or device code with the CUDA runtime.

Expand Down Expand Up @@ -524,7 +524,7 @@ CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SmVersionUncached(int& sm_versi
*
* \note This function is thread safe.
*/
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SmVersion(int& sm_version, int device = CurrentDevice())
CUB_RUNTIME_FUNCTION cudaError_t SmVersion(int& sm_version, int device = CurrentDevice())
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_HOST_CODE) {
Expand Down Expand Up @@ -557,7 +557,7 @@ CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SmVersion(int& sm_version, int
/**
* \brief Synchronizes the specified \p stream.
*/
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SyncStream(cudaStream_t stream)
CUB_RUNTIME_FUNCTION cudaError_t SyncStream(cudaStream_t stream)
{
cudaError_t result = cudaErrorUnknown;
if (CUB_IS_HOST_CODE) {
Expand Down Expand Up @@ -613,7 +613,7 @@ CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t SyncStream(cudaStream_t stream)
*
*/
template <typename KernelPtr>
CUB_RUNTIME_FUNCTION __forceinline__
CUB_RUNTIME_FUNCTION
cudaError_t MaxSmOccupancy(
int& max_sm_occupancy, ///< [out] maximum number of thread blocks that can reside on a single SM
KernelPtr kernel_ptr, ///< [in] Kernel pointer for which to compute SM occupancy
Expand Down