Commit cd1e0a3 1 parent 605254b commit cd1e0a3 Copy full SHA for cd1e0a3
File tree 2 files changed +44
-1
lines changed
2 files changed +44
-1
lines changed Original file line number Diff line number Diff line change @@ -4366,7 +4366,19 @@ pub unsafe extern "system" fn cublasDgetrsBatched(
4366
4366
info : * mut :: std:: os:: raw:: c_int ,
4367
4367
batchSize : :: std:: os:: raw:: c_int ,
4368
4368
) -> cublasStatus_t {
4369
- crate :: unsupported ( )
4369
+ crate :: dgetrs_batched (
4370
+ handle,
4371
+ trans,
4372
+ n,
4373
+ nrhs,
4374
+ Aarray ,
4375
+ lda,
4376
+ devIpiv,
4377
+ Barray ,
4378
+ ldb,
4379
+ info,
4380
+ batchSize,
4381
+ )
4370
4382
}
4371
4383
4372
4384
#[ no_mangle]
Original file line number Diff line number Diff line change @@ -9,6 +9,7 @@ use rocsolver_sys::{
9
9
rocsolver_cgetrf_batched,
10
10
rocsolver_cgetri_outofplace_batched,
11
11
rocsolver_sgetrs_batched,
12
+ rocsolver_dgetrs_batched,
12
13
rocsolver_zgetrf_batched,
13
14
rocsolver_zgetri_outofplace_batched,
14
15
} ;
@@ -742,6 +743,36 @@ unsafe fn sgetrs_batched(
742
743
) )
743
744
}
744
745
746
+ unsafe fn dgetrs_batched (
747
+ handle : * mut cublasContext ,
748
+ trans : cublasOperation_t ,
749
+ n : i32 ,
750
+ nrhs : i32 ,
751
+ a : * const * const f64 ,
752
+ lda : i32 ,
753
+ dev_ipiv : * const i32 ,
754
+ b : * const * mut f64 ,
755
+ ldb : i32 ,
756
+ info : * mut i32 ,
757
+ batch_size : i32 ,
758
+ ) -> cublasStatus_t {
759
+ let trans = op_from_cuda_for_solver ( trans) ;
760
+ let stride = n * nrhs;
761
+ to_cuda_solver ( rocsolver_dgetrs_batched (
762
+ handle. cast ( ) ,
763
+ trans,
764
+ n,
765
+ nrhs,
766
+ a. cast ( ) ,
767
+ lda,
768
+ dev_ipiv,
769
+ stride as _ ,
770
+ b,
771
+ ldb,
772
+ batch_size,
773
+ ) )
774
+ }
775
+
745
776
unsafe fn dtrmm_v2 (
746
777
handle : * mut cublasContext ,
747
778
side : cublasSideMode_t ,
You can’t perform that action at this time.
0 commit comments