Skip to content

Commit 7c3891e

Browse files
committed
Fix cusparseDnMatGet.
1 parent ff1bc6d commit 7c3891e

File tree

1 file changed

+45
-20
lines changed

1 file changed

+45
-20
lines changed

zluda_sparse/src/lib.rs

+45-20
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ unsafe fn create_csr(
5858
let csr_row_offsets_type = index_type(csr_row_offsets_type);
5959
let csr_col_ind_type = index_type(csr_col_ind_type);
6060
let idx_base = index_base(idx_base);
61-
let value_type = data_type(value_type);
61+
let value_type = to_roc_data_type(value_type);
6262
to_cuda(rocsparse_create_csr_descr(
6363
descr as _,
6464
rows,
@@ -74,7 +74,7 @@ unsafe fn create_csr(
7474
))
7575
}
7676

77-
fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
77+
fn to_roc_data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
7878
match data_type {
7979
cudaDataType_t::CUDA_R_32F => rocsparse_datatype::rocsparse_datatype_f32_r,
8080
cudaDataType_t::CUDA_R_64F => rocsparse_datatype::rocsparse_datatype_f64_r,
@@ -88,6 +88,20 @@ fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
8888
}
8989
}
9090

91+
fn to_cuda_data_type(data_type: rocsparse_datatype) -> cudaDataType_t {
92+
match data_type {
93+
rocsparse_datatype::rocsparse_datatype_f32_r => cudaDataType_t::CUDA_R_32F,
94+
rocsparse_datatype::rocsparse_datatype_f64_r => cudaDataType_t::CUDA_R_64F,
95+
rocsparse_datatype::rocsparse_datatype_f32_c => cudaDataType_t::CUDA_C_32F,
96+
rocsparse_datatype::rocsparse_datatype_f64_c => cudaDataType_t::CUDA_C_64F,
97+
rocsparse_datatype::rocsparse_datatype_i8_r => cudaDataType_t::CUDA_R_8I,
98+
rocsparse_datatype::rocsparse_datatype_u8_r => cudaDataType_t::CUDA_R_8U,
99+
rocsparse_datatype::rocsparse_datatype_i32_r => cudaDataType_t::CUDA_R_32I,
100+
rocsparse_datatype::rocsparse_datatype_u32_r => cudaDataType_t::CUDA_R_32U,
101+
_ => panic!(),
102+
}
103+
}
104+
91105
fn index_type(index_type: cusparseIndexType_t) -> rocsparse_indextype {
92106
match index_type {
93107
cusparseIndexType_t::CUSPARSE_INDEX_16U => rocsparse_indextype::rocsparse_indextype_u16,
@@ -109,14 +123,22 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
109123
}
110124
}
111125

112-
fn order(order: cusparseOrder_t) -> rocsparse_order {
126+
fn to_roc_order(order: cusparseOrder_t) -> rocsparse_order {
113127
match order {
114128
cusparseOrder_t::CUSPARSE_ORDER_COL => rocsparse_order::rocsparse_order_column,
115129
cusparseOrder_t::CUSPARSE_ORDER_ROW => rocsparse_order::rocsparse_order_row,
116130
_ => panic!(),
117131
}
118132
}
119133

134+
fn to_cuda_order(order: rocsparse_order) -> cusparseOrder_t {
135+
match order {
136+
rocsparse_order::rocsparse_order_column => cusparseOrder_t::CUSPARSE_ORDER_COL,
137+
rocsparse_order::rocsparse_order_row => cusparseOrder_t::CUSPARSE_ORDER_ROW,
138+
_ => panic!(),
139+
}
140+
}
141+
120142
unsafe fn create_csrsv2_info(info: *mut *mut csrsv2Info) -> cusparseStatus_t {
121143
to_cuda(rocsparse_create_mat_info(info.cast()))
122144
}
@@ -127,7 +149,7 @@ unsafe fn create_dn_vec(
127149
values: *mut std::ffi::c_void,
128150
value_type: cudaDataType_t,
129151
) -> cusparseStatus_t {
130-
let value_type = data_type(value_type);
152+
let value_type = to_roc_data_type(value_type);
131153
to_cuda(rocsparse_create_dnvec_descr(
132154
dn_vec_descr.cast(),
133155
size,
@@ -466,7 +488,7 @@ unsafe fn spmv(
466488
external_buffer: *mut std::ffi::c_void,
467489
) -> cusparseStatus_t {
468490
let op_a = operation(op_a);
469-
let compute_type = data_type(compute_type);
491+
let compute_type = to_roc_data_type(compute_type);
470492
let alg = to_spmv_alg(alg);
471493
// divide by 2 in case there's any arithmetic done on it
472494
let mut size = usize::MAX / 2;
@@ -498,7 +520,7 @@ unsafe fn spmv_buffersize(
498520
buffer_size: *mut usize,
499521
) -> cusparseStatus_t {
500522
let op_a = operation(op_a);
501-
let compute_type = data_type(compute_type);
523+
let compute_type = to_roc_data_type(compute_type);
502524
let alg = to_spmv_alg(alg);
503525
to_cuda(rocsparse_spmv(
504526
handle.cast(),
@@ -1130,7 +1152,7 @@ unsafe fn spsv_buffersize(
11301152
buffer_size: *mut usize,
11311153
) -> cusparseStatus_t {
11321154
let op_a = operation(op_a);
1133-
let compute_type = data_type(compute_type);
1155+
let compute_type = to_roc_data_type(compute_type);
11341156
let alg = to_spsv_alg(alg);
11351157
to_cuda(rocsparse_spsv(
11361158
handle.cast(),
@@ -1169,7 +1191,7 @@ unsafe fn spsv_analysis(
11691191
external_buffer: *mut c_void,
11701192
) -> cusparseStatus_t {
11711193
let op_a = operation(op_a);
1172-
let compute_type = data_type(compute_type);
1194+
let compute_type = to_roc_data_type(compute_type);
11731195
let alg = to_spsv_alg(alg);
11741196
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_mut().unwrap();
11751197
spsv_descr.external_buffer = external_buffer;
@@ -1200,7 +1222,7 @@ unsafe fn spsv_solve(
12001222
spsv_descr: *mut cusparseSpSVDescr,
12011223
) -> cusparseStatus_t {
12021224
let op_a = operation(op_a);
1203-
let compute_type = data_type(compute_type);
1225+
let compute_type = to_roc_data_type(compute_type);
12041226
let alg = to_spsv_alg(alg);
12051227
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_ref().unwrap();
12061228
to_cuda(rocsparse_spsv(
@@ -1322,18 +1344,18 @@ unsafe fn create_dn_mat(
13221344
ld: i64,
13231345
values: *mut ::std::os::raw::c_void,
13241346
value_type: cudaDataType,
1325-
o: cusparseOrder_t,
1347+
order: cusparseOrder_t,
13261348
) -> cusparseStatus_t {
1327-
let value_type = data_type(value_type);
1328-
let o = order(o);
1349+
let value_type = to_roc_data_type(value_type);
1350+
let order = to_roc_order(order);
13291351
to_cuda(rocsparse_create_dnmat_descr(
13301352
dn_mat_descr.cast(),
13311353
rows,
13321354
cols,
13331355
ld,
13341356
values,
13351357
value_type,
1336-
o,
1358+
order,
13371359
))
13381360
}
13391361

@@ -1350,19 +1372,22 @@ unsafe fn dn_mat_get(
13501372
ld: *mut i64,
13511373
values: *mut *mut ::std::os::raw::c_void,
13521374
type_: *mut cudaDataType,
1353-
o: *mut cusparseOrder_t,
1375+
order: *mut cusparseOrder_t,
13541376
) -> cusparseStatus_t {
1355-
let mut type_ = data_type(*type_);
1356-
let mut o = order(*o);
1357-
to_cuda(rocsparse_dnmat_get(
1377+
let mut out_type = to_roc_data_type(*type_);
1378+
let mut out_order = to_roc_order(*order);
1379+
let status = to_cuda(rocsparse_dnmat_get(
13581380
dn_mat_descr.cast(),
13591381
rows,
13601382
cols,
13611383
ld,
13621384
values,
1363-
&mut type_,
1364-
&mut o,
1365-
))
1385+
&mut out_type,
1386+
&mut out_order,
1387+
));
1388+
*type_ = to_cuda_data_type(out_type);
1389+
*order = to_cuda_order(out_order);
1390+
status
13661391
}
13671392

13681393
unsafe fn dn_mat_get_values(

0 commit comments

Comments
 (0)