Skip to content

Commit 7d33957

Browse files
committed
Implement runtime compiler. (nvrtc)
1 parent 3adbe3e commit 7d33957

File tree

13 files changed

+680
-102
lines changed

13 files changed

+680
-102
lines changed

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ members = [
1414
"hipblaslt-sys",
1515
"hipfft-sys",
1616
"hiprt-sys",
17+
"hiprtc-sys",
1718
"miopen-sys",
1819
"offline_compiler",
1920
"optix_base",
@@ -39,6 +40,7 @@ members = [
3940
"zluda_ml",
4041
"zluda_redirect",
4142
"zluda_rt",
43+
"zluda_rtc",
4244
"zluda_sparse",
4345
]
4446

Makefile.toml

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ args = [
2222
"-p", "zluda_lib",
2323
"-p", "zluda_ml",
2424
"-p", "zluda_sparse",
25+
"-p", "zluda_rtc",
2526
"-p", "zluda_redirect",
2627
]
2728

@@ -38,6 +39,7 @@ args = [
3839
"-p", "zluda_fft",
3940
"-p", "zluda_lib",
4041
"-p", "zluda_ml",
42+
"-p", "zluda_rtc",
4143
"-p", "zluda_sparse",
4244
]
4345

@@ -55,6 +57,7 @@ args = [
5557
"-p", "zluda_fft",
5658
"-p", "zluda_lib",
5759
"-p", "zluda_ml",
60+
"-p", "zluda_rtc",
5861
"-p", "zluda_sparse",
5962
]
6063

hiprtc-sys/Cargo.toml

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[package]
2+
name = "hiprtc-sys"
3+
version = "0.0.0"
4+
authors = ["Seunghoon Lee <op@lsh.sh>"]
5+
edition = "2018"
6+
links = "hiprtc"
7+
8+
[lib]

hiprtc-sys/README

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bindgen $Env:HIP_PATH/include/hip/hiprtc.h -o src/hiprtc.rs --no-layout-tests --default-enum-style=newtype --no-derive-debug --allowlist-function "hiprtc.*" --must-use-type hiprtcResult_t -- -I$Env:HIP_PATH/include -D__HIP_PLATFORM_AMD__

hiprtc-sys/build.rs

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use std::env::VarError;
2+
use std::{env, path::PathBuf};
3+
4+
fn main() -> Result<(), VarError> {
5+
println!("cargo:rustc-link-lib=dylib=hiprtc");
6+
if cfg!(windows) {
7+
let mut path = PathBuf::from(env::var("HIP_PATH")?);
8+
path.push("lib");
9+
println!("cargo:rustc-link-search=native={}", path.display());
10+
} else {
11+
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
12+
}
13+
Ok(())
14+
}

hiprtc-sys/src/hiprtc.rs

+360
Large diffs are not rendered by default.

hiprtc-sys/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#[allow(warnings)]
2+
mod hiprtc;
3+
pub use hiprtc::*;

zluda_inject/src/bin.rs

+17-68
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,8 @@ use winapi::um::{
2323
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
2424

2525
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
26-
static CUBLAS_DLL: &'static str = "cublas.dll";
27-
static CUDNN_DLL: &'static str = "cudnn.dll";
28-
static CUFFT_DLL: &'static str = "cufft.dll";
29-
static CUSPARSE_DLL: &'static str = "cusparse.dll";
3026
static NCCL_DLL: &'static str = "nccl.dll";
27+
static NVRTC_DLL: &'static str = "nvrtc.dll";
3128
static NVCUDA_DLL: &'static str = "nvcuda.dll";
3229
static NVML_DLL: &'static str = "nvml.dll";
3330
static NVAPI_DLL: &'static str = "nvapi64.dll";
@@ -38,26 +35,14 @@ include!("../../zluda_redirect/src/payload_guid.rs");
3835
#[derive(FromArgs)]
3936
/// Launch application with custom CUDA libraries
4037
struct ProgramArguments {
41-
/// DLL to be injected instead of system cublas.dll. If not provided {0}, will use cublas.dll from its own directory
42-
#[argh(option)]
43-
cublas: Option<PathBuf>,
44-
45-
/// DLL to be injected instead of system cudnn.dll. If not provided {0}, will use cudnn.dll from its own directory
46-
#[argh(option)]
47-
cudnn: Option<PathBuf>,
48-
49-
/// DLL to be injected instead of system cufft.dll. If not provided {0}, will use cufft.dll from its own directory
50-
#[argh(option)]
51-
cufft: Option<PathBuf>,
52-
53-
/// DLL to be injected instead of system cusparse.dll. If not provided {0}, will use cusparse.dll from its own directory
54-
#[argh(option)]
55-
cusparse: Option<PathBuf>,
56-
5738
/// DLL to be injected instead of system nccl.dll. If not provided {0}, will use nccl.dll from its own directory
5839
#[argh(option)]
5940
nccl: Option<PathBuf>,
6041

42+
/// DLL to be injected instead of system nvrtc.dll. If not provided {0}, will use nvrtc.dll from its own directory
43+
#[argh(option)]
44+
nvrtc: Option<PathBuf>,
45+
6146
/// DLL to be injected instead of system nvcuda.dll. If not provided {0}, will use nvcuda.dll from its own directory
6247
#[argh(option)]
6348
nvcuda: Option<PathBuf>,
@@ -90,11 +75,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
9075
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
9176
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
9277
let mut dlls_to_inject = vec![
93-
environment.cublas_path_zero_terminated.as_ptr() as _,
94-
//environment.cudnn_path_zero_terminated.as_ptr() as _,
95-
environment.cufft_path_zero_terminated.as_ptr() as _,
96-
environment.cusparse_path_zero_terminated.as_ptr() as _,
9778
environment.nccl_path_zero_terminated.as_ptr() as _,
79+
environment.nvrtc_path_zero_terminated.as_ptr() as _,
9880
environment.nvcuda_path_zero_terminated.as_ptr() as _,
9981
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
10082
environment.redirect_path_zero_terminated.as_ptr() as _,
@@ -176,11 +158,8 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
176158
}
177159

178160
struct NormalizedArguments {
179-
cublas_path: PathBuf,
180-
cudnn_path: PathBuf,
181-
cufft_path: PathBuf,
182-
cusparse_path: PathBuf,
183161
nccl_path: PathBuf,
162+
nvrtc_path: PathBuf,
184163
nvcuda_path: PathBuf,
185164
nvml_path: PathBuf,
186165
nvapi_path: Option<PathBuf>,
@@ -192,16 +171,10 @@ struct NormalizedArguments {
192171
impl NormalizedArguments {
193172
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
194173
let current_exe = env::current_exe()?;
195-
let cublas_path =
196-
Self::get_absolute_path_or_default(&current_exe, prog_args.cublas, CUBLAS_DLL)?;
197-
let cudnn_path =
198-
Self::get_absolute_path_or_default(&current_exe, prog_args.cudnn, CUDNN_DLL)?;
199-
let cufft_path =
200-
Self::get_absolute_path_or_default(&current_exe, prog_args.cufft, CUFFT_DLL)?;
201-
let cusparse_path =
202-
Self::get_absolute_path_or_default(&current_exe, prog_args.cusparse, CUSPARSE_DLL)?;
203174
let nccl_path =
204175
Self::get_absolute_path_or_default(&current_exe, prog_args.nccl, NCCL_DLL)?;
176+
let nvrtc_path =
177+
Self::get_absolute_path_or_default(&current_exe, prog_args.nvrtc, NVRTC_DLL)?;
205178
let nvcuda_path =
206179
Self::get_absolute_path_or_default(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
207180
let nvml_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvml, NVML_DLL)?;
@@ -212,11 +185,8 @@ impl NormalizedArguments {
212185
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
213186
redirect_path.push(REDIRECT_DLL);
214187
Ok(Self {
215-
cublas_path,
216-
cudnn_path,
217-
cufft_path,
218-
cusparse_path,
219188
nccl_path,
189+
nvrtc_path,
220190
nvcuda_path,
221191
nvml_path,
222192
nvapi_path,
@@ -274,11 +244,8 @@ impl NormalizedArguments {
274244
}
275245

276246
struct Environment {
277-
cublas_path_zero_terminated: String,
278-
cudnn_path_zero_terminated: String,
279-
cufft_path_zero_terminated: String,
280-
cusparse_path_zero_terminated: String,
281247
nccl_path_zero_terminated: String,
248+
nvrtc_path_zero_terminated: String,
282249
nvcuda_path_zero_terminated: String,
283250
nvml_path_zero_terminated: String,
284251
nvapi_path_zero_terminated: Option<String>,
@@ -294,31 +261,16 @@ struct Environment {
294261
impl Environment {
295262
fn setup(args: NormalizedArguments) -> io::Result<Self> {
296263
let _temp_dir = TempDir::new()?;
297-
let cublas_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
298-
args.cublas_path,
299-
&_temp_dir,
300-
CUBLAS_DLL,
301-
)?);
302-
let cudnn_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
303-
args.cudnn_path,
304-
&_temp_dir,
305-
CUDNN_DLL,
306-
)?);
307-
let cufft_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
308-
args.cufft_path,
309-
&_temp_dir,
310-
CUFFT_DLL,
311-
)?);
312-
let cusparse_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
313-
args.cusparse_path,
314-
&_temp_dir,
315-
CUSPARSE_DLL,
316-
)?);
317264
let nccl_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
318265
args.nccl_path,
319266
&_temp_dir,
320267
NCCL_DLL,
321268
)?);
269+
let nvrtc_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
270+
args.nvrtc_path,
271+
&_temp_dir,
272+
NVRTC_DLL,
273+
)?);
322274
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
323275
args.nvcuda_path,
324276
&_temp_dir,
@@ -349,11 +301,8 @@ impl Environment {
349301
.transpose()?;
350302
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
351303
Ok(Self {
352-
cublas_path_zero_terminated,
353-
cudnn_path_zero_terminated,
354-
cufft_path_zero_terminated,
355-
cusparse_path_zero_terminated,
356304
nccl_path_zero_terminated,
305+
nvrtc_path_zero_terminated,
357306
nvcuda_path_zero_terminated,
358307
nvml_path_zero_terminated,
359308
nvapi_path_zero_terminated,

zluda_redirect/src/lib.rs

+2-34
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ use winapi::{
5252
include!("payload_guid.rs");
5353

5454
const WIN_MAX_PATH: usize = 260;
55-
const CUBLAS_UTF8: &'static str = "CUBLAS.DLL";
56-
const CUBLAS_UTF16: &[u16] = wch!("CUBLAS.DLL");
57-
const CUDNN_UTF8: &'static str = "CUDNN.DLL";
58-
const CUDNN_UTF16: &[u16] = wch!("CUDNN.DLL");
5955
const NVCUDA1_UTF8: &'static str = "NVCUDA.DLL";
6056
const NVCUDA1_UTF16: &[u16] = wch!("NVCUDA.DLL");
6157
const NVCUDA2_UTF8: &'static str = "NVCUDA.DLL";
@@ -68,10 +64,6 @@ const NVOPTIX_UTF8: &'static str = "OPTIX.6.6.0.DLL";
6864
const NVOPTIX_UTF16: &[u16] = wch!("OPTIX.6.6.0.DLL");
6965
static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None;
7066
static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new();
71-
static mut ZLUDA_BLAS_PATH_UTF8: Option<&'static [u8]> = None;
72-
static mut ZLUDA_BLAS_PATH_UTF16: Vec<u16> = Vec::new();
73-
static mut ZLUDA_DNN_PATH_UTF8: Option<&'static [u8]> = None;
74-
static mut ZLUDA_DNN_PATH_UTF16: Vec<u16> = Vec::new();
7567
static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None;
7668
static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new();
7769
static mut ZLUDA_API_PATH_UTF8: Option<&'static [u8]> = None;
@@ -207,11 +199,7 @@ unsafe fn get_library_name_utf8(raw_library_name: *const u8) -> *const u8 {
207199
}
208200
}
209201
}
210-
if is_cublas_dll_utf8(library_name) {
211-
return ZLUDA_BLAS_PATH_UTF8.unwrap().as_ptr();
212-
} /*else if is_cudnn_dll_utf8(library_name) {
213-
return ZLUDA_DNN_PATH_UTF8.unwrap().as_ptr();
214-
}*/ else if is_nvcuda_dll_utf8(library_name) {
202+
if is_nvcuda_dll_utf8(library_name) {
215203
return ZLUDA_PATH_UTF8.unwrap().as_ptr();
216204
} else if is_nvml_dll_utf8(library_name) {
217205
return ZLUDA_ML_PATH_UTF8.unwrap().as_ptr();
@@ -249,11 +237,7 @@ unsafe fn get_library_name_utf16(raw_library_name: *const u16) -> *const u16 {
249237
}
250238
}
251239
}
252-
if is_cublas_dll_utf16(library_name) {
253-
return ZLUDA_BLAS_PATH_UTF16.as_ptr();
254-
} /*else if is_cudnn_dll_utf16(library_name) {
255-
return ZLUDA_DNN_PATH_UTF16.as_ptr();
256-
}*/ else if is_nvcuda_dll_utf16(library_name) {
240+
if is_nvcuda_dll_utf16(library_name) {
257241
return ZLUDA_PATH_UTF16.as_ptr();
258242
} else if is_nvml_dll_utf16(library_name) {
259243
return ZLUDA_ML_PATH_UTF16.as_ptr();
@@ -329,22 +313,6 @@ unsafe fn is_driverstore_utf16(lib: &[u16]) -> bool {
329313
starts_with_ignore_case(lib, &DRIVERSTORE_UTF16, utf16_to_ascii_uppercase)
330314
}
331315

332-
fn is_cublas_dll_utf8(lib: &[u8]) -> bool {
333-
is_dll_utf8(lib, CUBLAS_UTF8.as_bytes())
334-
}
335-
336-
fn is_cublas_dll_utf16(lib: &[u16]) -> bool {
337-
is_dll_utf16(lib, CUBLAS_UTF16)
338-
}
339-
340-
fn is_cudnn_dll_utf8(lib: &[u8]) -> bool {
341-
is_dll_utf8(lib, CUDNN_UTF8.as_bytes())
342-
}
343-
344-
fn is_cudnn_dll_utf16(lib: &[u16]) -> bool {
345-
is_dll_utf16(lib, CUDNN_UTF16)
346-
}
347-
348316
fn is_nvcuda_dll_utf8(lib: &[u8]) -> bool {
349317
is_dll_utf8(lib, NVCUDA1_UTF8.as_bytes()) || is_dll_utf8(lib, NVCUDA2_UTF8.as_bytes())
350318
}

zluda_rtc/Cargo.toml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "zluda_rtc"
3+
version = "0.0.0"
4+
authors = ["Seunghoon Lee <op@lsh.sh>"]
5+
edition = "2018"
6+
7+
[lib]
8+
name = "nvrtc"
9+
crate-type = ["cdylib"]
10+
11+
[dependencies]
12+
hip_common = { path = "../hip_common" }
13+
hiprtc-sys = { path = "../hiprtc-sys" }
14+
15+
[package.metadata.zluda]
16+
linux_names = ["libnvrtc.so.10", "libnvrtc.so.11"]
17+
dump_names = ["libnvrtc.so"]

zluda_rtc/README

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
bindgen include/nvrtc.h -o src/nvrtc.rs --allowlist-function="^nvrtc.*" --default-enum-style=newtype --no-layout-tests --no-derive-debug -- -Iinclude
2+
sed -i -e 's/extern "C" {//g' -e 's/-> nvrtcResult;/-> nvrtcResult { crate::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "system" fn /g' src/nvrtc.rs
3+
rustfmt src/nvrtc.rs

zluda_rtc/src/lib.rs

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
mod nvrtc;
2+
pub use nvrtc::*;
3+
4+
use hiprtc_sys::*;
5+
6+
#[cfg(debug_assertions)]
7+
fn unsupported() -> nvrtcResult {
8+
unimplemented!()
9+
}
10+
11+
#[cfg(not(debug_assertions))]
12+
fn unsupported() -> nvrtcResult {
13+
nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR
14+
}
15+
16+
fn to_nvrtc(status: hiprtc_sys::hiprtcResult) -> nvrtcResult {
17+
match status {
18+
hiprtc_sys::hiprtcResult::HIPRTC_SUCCESS => nvrtcResult::NVRTC_SUCCESS,
19+
err => panic!("[ZLUDA] HIPRTC failed: {}", err.0),
20+
}
21+
}
22+
23+
unsafe fn create_program(
24+
prog: *mut nvrtcProgram,
25+
src: *const std::ffi::c_char,
26+
name: *const std::ffi::c_char,
27+
num_headers: i32,
28+
headers: *const *const std::ffi::c_char,
29+
include_names: *const *const std::ffi::c_char,
30+
) -> nvrtcResult {
31+
to_nvrtc(hiprtcCreateProgram(
32+
prog.cast(),
33+
src,
34+
name,
35+
num_headers,
36+
headers.cast_mut(),
37+
include_names.cast_mut(),
38+
))
39+
}
40+
41+
unsafe fn destroy_program(
42+
prog: *mut nvrtcProgram,
43+
) -> nvrtcResult {
44+
to_nvrtc(hiprtcDestroyProgram(prog.cast()))
45+
}
46+
47+
unsafe fn compile_program(
48+
prog: nvrtcProgram,
49+
num_options: i32,
50+
options: *const *const std::ffi::c_char,
51+
) -> nvrtcResult {
52+
to_nvrtc(hiprtcCompileProgram(prog.cast(), num_options, options.cast_mut()))
53+
}

0 commit comments

Comments
 (0)