@@ -58,7 +58,7 @@ unsafe fn create_csr(
58
58
let csr_row_offsets_type = index_type ( csr_row_offsets_type) ;
59
59
let csr_col_ind_type = index_type ( csr_col_ind_type) ;
60
60
let idx_base = index_base ( idx_base) ;
61
- let value_type = data_type ( value_type) ;
61
+ let value_type = to_roc_data_type ( value_type) ;
62
62
to_cuda ( rocsparse_create_csr_descr (
63
63
descr as _ ,
64
64
rows,
@@ -74,7 +74,7 @@ unsafe fn create_csr(
74
74
) )
75
75
}
76
76
77
- fn data_type ( data_type : cudaDataType_t ) -> rocsparse_datatype {
77
+ fn to_roc_data_type ( data_type : cudaDataType_t ) -> rocsparse_datatype {
78
78
match data_type {
79
79
cudaDataType_t:: CUDA_R_32F => rocsparse_datatype:: rocsparse_datatype_f32_r,
80
80
cudaDataType_t:: CUDA_R_64F => rocsparse_datatype:: rocsparse_datatype_f64_r,
@@ -88,6 +88,20 @@ fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
88
88
}
89
89
}
90
90
91
+ fn to_cuda_data_type ( data_type : rocsparse_datatype ) -> cudaDataType_t {
92
+ match data_type {
93
+ rocsparse_datatype:: rocsparse_datatype_f32_r => cudaDataType_t:: CUDA_R_32F ,
94
+ rocsparse_datatype:: rocsparse_datatype_f64_r => cudaDataType_t:: CUDA_R_64F ,
95
+ rocsparse_datatype:: rocsparse_datatype_f32_c => cudaDataType_t:: CUDA_C_32F ,
96
+ rocsparse_datatype:: rocsparse_datatype_f64_c => cudaDataType_t:: CUDA_C_64F ,
97
+ rocsparse_datatype:: rocsparse_datatype_i8_r => cudaDataType_t:: CUDA_R_8I ,
98
+ rocsparse_datatype:: rocsparse_datatype_u8_r => cudaDataType_t:: CUDA_R_8U ,
99
+ rocsparse_datatype:: rocsparse_datatype_i32_r => cudaDataType_t:: CUDA_R_32I ,
100
+ rocsparse_datatype:: rocsparse_datatype_u32_r => cudaDataType_t:: CUDA_R_32U ,
101
+ _ => panic ! ( ) ,
102
+ }
103
+ }
104
+
91
105
fn index_type ( index_type : cusparseIndexType_t ) -> rocsparse_indextype {
92
106
match index_type {
93
107
cusparseIndexType_t:: CUSPARSE_INDEX_16U => rocsparse_indextype:: rocsparse_indextype_u16,
@@ -109,14 +123,22 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
109
123
}
110
124
}
111
125
112
- fn order ( order : cusparseOrder_t ) -> rocsparse_order {
126
+ fn to_roc_order ( order : cusparseOrder_t ) -> rocsparse_order {
113
127
match order {
114
128
cusparseOrder_t:: CUSPARSE_ORDER_COL => rocsparse_order:: rocsparse_order_column,
115
129
cusparseOrder_t:: CUSPARSE_ORDER_ROW => rocsparse_order:: rocsparse_order_row,
116
130
_ => panic ! ( ) ,
117
131
}
118
132
}
119
133
134
+ fn to_cuda_order ( order : rocsparse_order ) -> cusparseOrder_t {
135
+ match order {
136
+ rocsparse_order:: rocsparse_order_column => cusparseOrder_t:: CUSPARSE_ORDER_COL ,
137
+ rocsparse_order:: rocsparse_order_row => cusparseOrder_t:: CUSPARSE_ORDER_ROW ,
138
+ _ => panic ! ( ) ,
139
+ }
140
+ }
141
+
120
142
unsafe fn create_csrsv2_info ( info : * mut * mut csrsv2Info ) -> cusparseStatus_t {
121
143
to_cuda ( rocsparse_create_mat_info ( info. cast ( ) ) )
122
144
}
@@ -127,7 +149,7 @@ unsafe fn create_dn_vec(
127
149
values : * mut std:: ffi:: c_void ,
128
150
value_type : cudaDataType_t ,
129
151
) -> cusparseStatus_t {
130
- let value_type = data_type ( value_type) ;
152
+ let value_type = to_roc_data_type ( value_type) ;
131
153
to_cuda ( rocsparse_create_dnvec_descr (
132
154
dn_vec_descr. cast ( ) ,
133
155
size,
@@ -466,7 +488,7 @@ unsafe fn spmv(
466
488
external_buffer : * mut std:: ffi:: c_void ,
467
489
) -> cusparseStatus_t {
468
490
let op_a = operation ( op_a) ;
469
- let compute_type = data_type ( compute_type) ;
491
+ let compute_type = to_roc_data_type ( compute_type) ;
470
492
let alg = to_spmv_alg ( alg) ;
471
493
// divide by 2 in case there's any arithmetic done on it
472
494
let mut size = usize:: MAX / 2 ;
@@ -498,7 +520,7 @@ unsafe fn spmv_buffersize(
498
520
buffer_size : * mut usize ,
499
521
) -> cusparseStatus_t {
500
522
let op_a = operation ( op_a) ;
501
- let compute_type = data_type ( compute_type) ;
523
+ let compute_type = to_roc_data_type ( compute_type) ;
502
524
let alg = to_spmv_alg ( alg) ;
503
525
to_cuda ( rocsparse_spmv (
504
526
handle. cast ( ) ,
@@ -1130,7 +1152,7 @@ unsafe fn spsv_buffersize(
1130
1152
buffer_size : * mut usize ,
1131
1153
) -> cusparseStatus_t {
1132
1154
let op_a = operation ( op_a) ;
1133
- let compute_type = data_type ( compute_type) ;
1155
+ let compute_type = to_roc_data_type ( compute_type) ;
1134
1156
let alg = to_spsv_alg ( alg) ;
1135
1157
to_cuda ( rocsparse_spsv (
1136
1158
handle. cast ( ) ,
@@ -1169,7 +1191,7 @@ unsafe fn spsv_analysis(
1169
1191
external_buffer : * mut c_void ,
1170
1192
) -> cusparseStatus_t {
1171
1193
let op_a = operation ( op_a) ;
1172
- let compute_type = data_type ( compute_type) ;
1194
+ let compute_type = to_roc_data_type ( compute_type) ;
1173
1195
let alg = to_spsv_alg ( alg) ;
1174
1196
let spsv_descr = spsv_descr. cast :: < SpSvDescr > ( ) . as_mut ( ) . unwrap ( ) ;
1175
1197
spsv_descr. external_buffer = external_buffer;
@@ -1200,7 +1222,7 @@ unsafe fn spsv_solve(
1200
1222
spsv_descr : * mut cusparseSpSVDescr ,
1201
1223
) -> cusparseStatus_t {
1202
1224
let op_a = operation ( op_a) ;
1203
- let compute_type = data_type ( compute_type) ;
1225
+ let compute_type = to_roc_data_type ( compute_type) ;
1204
1226
let alg = to_spsv_alg ( alg) ;
1205
1227
let spsv_descr = spsv_descr. cast :: < SpSvDescr > ( ) . as_ref ( ) . unwrap ( ) ;
1206
1228
to_cuda ( rocsparse_spsv (
@@ -1322,18 +1344,18 @@ unsafe fn create_dn_mat(
1322
1344
ld : i64 ,
1323
1345
values : * mut :: std:: os:: raw:: c_void ,
1324
1346
value_type : cudaDataType ,
1325
- o : cusparseOrder_t ,
1347
+ order : cusparseOrder_t ,
1326
1348
) -> cusparseStatus_t {
1327
- let value_type = data_type ( value_type) ;
1328
- let o = order ( o ) ;
1349
+ let value_type = to_roc_data_type ( value_type) ;
1350
+ let order = to_roc_order ( order ) ;
1329
1351
to_cuda ( rocsparse_create_dnmat_descr (
1330
1352
dn_mat_descr. cast ( ) ,
1331
1353
rows,
1332
1354
cols,
1333
1355
ld,
1334
1356
values,
1335
1357
value_type,
1336
- o ,
1358
+ order ,
1337
1359
) )
1338
1360
}
1339
1361
@@ -1350,19 +1372,22 @@ unsafe fn dn_mat_get(
1350
1372
ld : * mut i64 ,
1351
1373
values : * mut * mut :: std:: os:: raw:: c_void ,
1352
1374
type_ : * mut cudaDataType ,
1353
- o : * mut cusparseOrder_t ,
1375
+ order : * mut cusparseOrder_t ,
1354
1376
) -> cusparseStatus_t {
1355
- let mut type_ = data_type ( * type_) ;
1356
- let mut o = order ( * o ) ;
1357
- to_cuda ( rocsparse_dnmat_get (
1377
+ let mut out_type = to_roc_data_type ( * type_) ;
1378
+ let mut out_order = to_roc_order ( * order ) ;
1379
+ let status = to_cuda ( rocsparse_dnmat_get (
1358
1380
dn_mat_descr. cast ( ) ,
1359
1381
rows,
1360
1382
cols,
1361
1383
ld,
1362
1384
values,
1363
- & mut type_,
1364
- & mut o,
1365
- ) )
1385
+ & mut out_type,
1386
+ & mut out_order,
1387
+ ) ) ;
1388
+ * type_ = to_cuda_data_type ( out_type) ;
1389
+ * order = to_cuda_order ( out_order) ;
1390
+ status
1366
1391
}
1367
1392
1368
1393
unsafe fn dn_mat_get_values (
0 commit comments