-
Notifications
You must be signed in to change notification settings - Fork 379
/
Copy pathrtc.h
195 lines (171 loc) · 6.57 KB
/
rtc.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
namespace transformer_engine {
namespace rtc {
/*! \brief Query whether runtime compilation through NVRTC is available
 *
 * NVRTC support can be turned off by setting the environment variable
 * NVTE_DISABLE_NVRTC=1.
 *
 * \return true if NVRTC support is enabled, false otherwise
 */
auto is_enabled() -> bool;
/*! \brief Wrapper class for a runtime-compiled CUDA kernel
 *
 * Stores the kernel's mangled name and compiled code, plus one CUDA
 * module and one CUDA function slot per device. The per-device
 * function is obtained through get_function(). The class is move-only.
 */
class Kernel {
 public:
  /*! \brief Construct from a mangled kernel name and its compiled code
   *
   * \param[in] mangled_name  Mangled name of the kernel function
   * \param[in] compiled_code Compiled assembly, either PTX or cubin
   */
  Kernel(std::string mangled_name, std::string compiled_code);
  ~Kernel();
  Kernel(const Kernel&) = delete;  // move-only
  Kernel(Kernel&&) noexcept;
  /*! \brief Assignment taking its argument by value
   *
   * Because the copy constructor is deleted, only rvalues can be passed,
   * so this effectively acts as move assignment.
   */
  Kernel& operator=(Kernel) noexcept;
  friend void swap(Kernel& first, Kernel& second) noexcept;

  /*! \brief Launch CUDA kernel
   *
   * Loads the kernel into the device the first time the device is
   * accessed (via get_function()).
   *
   * NOTE(review): the CUDA driver copies each kernel parameter from the
   * object pointed to in the parameter array, using the parameter size
   * recorded in the kernel image. Argument types should therefore match
   * the kernel's parameter types exactly (e.g. pass size_t, not int, for
   * a size_t parameter) -- confirm at call sites.
   *
   * \param[in] device_id        CUDA device
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
  void launch(int device_id,
              const dim3 grid_dim,
              const dim3 block_dim,
              unsigned int shared_mem_bytes,
              cudaStream_t stream,
              ArgTs &&... args) {
    // One pointer per kernel argument. The trailing nullptr keeps the
    // array non-empty so launching a kernel with no parameters still
    // compiles (zero-length arrays are ill-formed in standard C++). The
    // driver only reads as many entries as the kernel has parameters.
    void* arg_ptrs[] = {const_cast<void*>(static_cast<const void*>(&args))...,
                        nullptr};
    NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel,
                                get_function(device_id),
                                grid_dim.x,
                                grid_dim.y,
                                grid_dim.z,
                                block_dim.x,
                                block_dim.y,
                                block_dim.z,
                                shared_mem_bytes,
                                static_cast<CUstream>(stream),
                                arg_ptrs,
                                nullptr);
  }

  /*! \brief CUDA function for given CUDA device
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   *
   * \param[in] device_id CUDA device
   * \return CUDA function handle for the kernel on that device
   */
  CUfunction get_function(int device_id);

 private:
  /*! \brief Mangled function name */
  std::string mangled_name_;
  /*! \brief Compiled assembly, either in PTX or cubin format */
  std::string compiled_code_;
  /*! CUDA module for each CUDA device */
  std::vector<CUmodule> modules_;
  /*! CUDA function for each CUDA device */
  std::vector<CUfunction> functions_;
  /*! Flags for thread-safe kernel initialization
   *
   * std::once_flag is neither copyable nor movable. Presumably the flags
   * are held behind a unique_ptr so that Kernel itself stays movable.
   */
  std::unique_ptr<std::vector<std::once_flag>> init_flags_;
  /*! \brief Uninitialized CUDA module */
  static constexpr CUmodule null_module = static_cast<CUmodule>(nullptr);
  /*! Uninitialized CUDA function */
  static constexpr CUfunction null_function = static_cast<CUfunction>(nullptr);
};
/*! \brief Singleton class to manage runtime-compiled CUDA kernels */
class KernelManager {
public:
/*! \brief Get singleton instance */
static KernelManager& instance();
/*! \brief Compile CUDA kernel for current CUDA device
*
* The compiled kernel is cached and made available for launching.
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] kernel_name Kernel name within source code
* \param[in] code Kernel source code
* \param[in] filename Path to associate with source code,
* primarily for debugging
*/
void compile(const std::string &kernel_label,
const std::string &kernel_name,
const std::string &code,
const std::string &filename);
/*! \brief Whether CUDA kernel has been compiled for CUDA device
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] device_id CUDA device (default is current device)
* \return Whether kernel has been compiled
*/
bool is_compiled(const std::string &kernel_label,
int device_id = -1) const;
/*! \brief Launch CUDA kernel on current CUDA device
*
* Assumes the kernel has already been compiled.
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] grid_dim Grid dimensions in blocks
* \param[in] block_dim Thread block dimensions
* \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
* bytes
* \param[in] stream CUDA stream
* \param[in] args Kernel arguments
*/
template <typename... ArgTs>
void launch(const std::string &kernel_label,
const dim3 grid_dim,
const dim3 block_dim,
unsigned int shared_mem_bytes,
cudaStream_t stream,
ArgTs &&... args) {
const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0,
"Attempted to launch RTC kernel before compilation");
kernel_cache_.at(key).launch(device_id,
grid_dim,
block_dim,
shared_mem_bytes,
stream,
std::forward<ArgTs>(args)...);
}
private:
/*! \brief Compiled kernels */
std::unordered_map<std::string, Kernel> kernel_cache_;
/*! \brief Mutex for thread-safe compilation */
std::mutex lock_;
KernelManager() = default;
~KernelManager() = default;
KernelManager(const KernelManager&) = delete;
KernelManager& operator=(const KernelManager&) = delete;
/*! \brief Construct key for kernel cache
*
* \param[in] kernel_label Unique identifying string for kernel
* \param[in] device_id CUDA device (default is current device)
*
* \return Key for kernel cache
*/
std::string get_kernel_cache_key(const std::string &kernel_label,
int device_id) const;
};
} // namespace rtc
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_