From da87c7e2c0a3263fdc1ba8dbd3b2fa89741b0f45 Mon Sep 17 00:00:00 2001 From: Kazuki Komatsu Date: Mon, 11 Feb 2019 19:25:15 +0900 Subject: [PATCH] Add complex-typed API (#10) * Add complex version gelsd * Add complex version gesdd * Add complex version gesvd * Add unittests for complex-type API --- source/mir/lapack.d | 188 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 186 insertions(+), 2 deletions(-) diff --git a/source/mir/lapack.d b/source/mir/lapack.d index c310bac..2d63095 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -12,6 +12,7 @@ import mir.ndslice.slice; import mir.ndslice.topology: retro; import mir.ndslice.iterator; import mir.utility: min, max; +import mir.internal.utility : realType, isComplex; static import lapack; @@ -45,6 +46,8 @@ unittest { alias s = getri_wq!float; alias d = getri_wq!double; + alias c = getri_wq!cfloat; + alias z = getri_wq!cdouble; } /// @@ -73,6 +76,8 @@ unittest { alias s = getri!float; alias d = getri!double; + alias c = getri!cfloat; + alias z = getri!cdouble; } /// @@ -98,6 +103,8 @@ unittest { alias s = getrf!float; alias d = getrf!double; + alias c = getrf!cfloat; + alias z = getrf!cdouble; } /// @@ -181,6 +188,7 @@ size_t gelsd_wq(T)( Slice!(T*, 2, Canonical) b, ref size_t liwork, ) + if(!isComplex!T) { assert(b.length!1 == a.length!1); @@ -203,10 +211,45 @@ size_t gelsd_wq(T)( return cast(size_t) work; } + +/// ditto +size_t gelsd_wq(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*, 2, Canonical) b, + ref size_t lrwork, + ref size_t liwork, + ) + if(isComplex!T) +{ + assert(b.length!1 == a.length!1); + + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint nrhs = cast(lapackint) b.length; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldb = cast(lapackint) b._stride.max(1); + realType!T rcond = void; + lapackint rank = void; + T work = void; + lapackint lwork = -1; + realType!T rwork = void; + lapackint iwork = void; + lapackint info = void; + + lapack.gelsd_(m, n, nrhs, a.iterator, lda, b.iterator, ldb, null, rcond, rank, &work, lwork, &rwork, &iwork, info); + + assert(info == 0); + lrwork = cast(size_t)rwork; + liwork = iwork; + return cast(size_t) work; +} + unittest { alias s = gelsd_wq!float; alias d = gelsd_wq!double; + alias c = gelsd_wq!cfloat; + alias z = gelsd_wq!cdouble; } /// @@ -219,6 +262,7 @@ size_t gelsd(T)( Slice!(T*) work, Slice!(lapackint*) iwork, ) + if(!isComplex!T) { assert(b.length!1 == a.length!1); assert(s.length == min(a.length!0, a.length!1)); @@ -239,10 +283,44 @@ size_t gelsd(T)( return info; } +/// ditto +size_t gelsd(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*, 2, Canonical) b, + Slice!(realType!T*) s, + realType!T rcond, + ref size_t rank, + Slice!(T*) work, + Slice!(realType!T*) rwork, + Slice!(lapackint*) iwork, + ) + if(isComplex!T) +{ + assert(b.length!1 == a.length!1); + assert(s.length == min(a.length!0, a.length!1)); + + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint nrhs = cast(lapackint) b.length; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldb = cast(lapackint) b._stride.max(1); + lapackint rank_ = void; + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gelsd_(m, n, nrhs, a.iterator, lda, b.iterator, ldb, s.iterator, rcond, rank_, work.iterator, lwork, rwork.iterator, iwork.iterator, info); + + assert(info >= 0); + rank = rank_; + return info; +} + unittest { alias s = gelsd!float; alias d = gelsd!double; + alias c = gelsd!cfloat; + alias z = gelsd!cdouble; } /// `gesdd` work space query @@ -262,7 +340,14 @@ size_t gesdd_wq(T)( lapackint lwork = -1; lapackint info = void; - lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + static if(!isComplex!T) + { + lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + } + else + { + lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, null, info); + } assert(info == 0); return cast(size_t) work; @@ -272,6 +357,8 @@ unittest { alias s = gesdd_wq!float; alias d = gesdd_wq!double; + alias c = gesdd_wq!cfloat; + alias z = gesdd_wq!cdouble; } /// @@ -284,6 +371,7 @@ size_t gesdd(T)( Slice!(T*) work, Slice!(lapackint*) iwork, ) + if(!isComplex!T) { lapackint m = cast(lapackint) a.length!1; lapackint n = cast(lapackint) a.length!0; @@ -299,10 +387,39 @@ size_t gesdd(T)( return info; } +/// ditto +size_t gesdd(T)( + char jobz, + Slice!(T*, 2, Canonical) a, + Slice!(realType!T*) s, + Slice!(T*, 2, Canonical) u, + Slice!(T*, 2, Canonical) vt, + Slice!(T*) work, + Slice!(realType!T*) rwork, + Slice!(lapackint*) iwork, + ) + if(isComplex!T) +{ + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldu = cast(lapackint) u._stride.max(1); + lapackint ldvt = cast(lapackint) vt._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gesdd_(jobz, m, n, a.iterator, lda, s.iterator, u.iterator, ldu, vt.iterator, ldvt, work.iterator, lwork, rwork.iterator, iwork.iterator, info); + + assert(info >= 0); + return info; +} + unittest { alias s = gesdd!float; alias d = gesdd!double; + alias c = gesdd!cfloat; + alias z = gesdd!cdouble; } /// `gesvd` work space query @@ -323,7 +440,14 @@ size_t gesvd_wq(T)( lapackint lwork = -1; lapackint info = void; - lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, info); + static if(!isComplex!T) + { + lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, info); + } + else + { + lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + } assert(info == 0); return cast(size_t) work; @@ -333,6 +457,8 @@ unittest { alias s = gesvd_wq!float; alias d = gesvd_wq!double; + alias c = gesvd_wq!cfloat; + alias z = gesvd_wq!cdouble; } /// @@ -345,6 +471,7 @@ size_t gesvd(T)( Slice!(T*, 2, Canonical) vt, Slice!(T*) work, ) + if(!isComplex!T) { lapackint m = cast(lapackint) a.length!1; lapackint n = cast(lapackint) a.length!0; @@ -360,10 +487,39 @@ size_t gesvd(T)( return info; } +/// ditto +size_t gesvd(T)( + char jobu, + char jobvt, + Slice!(T*, 2, Canonical) a, + Slice!(realType!T*) s, + Slice!(T*, 2, Canonical) u, + Slice!(T*, 2, Canonical) vt, + Slice!(T*) work, + Slice!(realType!T*) rwork, + ) + if(isComplex!T) +{ + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldu = cast(lapackint) u._stride.max(1); + lapackint ldvt = cast(lapackint) vt._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gesvd_(jobu, jobvt, m, n, a.iterator, lda, s.iterator, u.iterator, ldu, vt.iterator, ldvt, work.iterator, lwork, rwork.iterator, info); + + assert(info >= 0); + return info; +} + unittest { alias s = gesvd!float; alias d = gesvd!double; + alias c = gesvd!cfloat; + alias z = gesvd!cdouble; } /// @@ -450,6 +606,8 @@ unittest { alias s = sytrf!float; alias d = sytrf!double; + alias c = sytrf!cfloat; + alias z = sytrf!cdouble; } /// @@ -477,6 +635,8 @@ unittest { alias s = geqrf!float; alias d = geqrf!double; + alias c = geqrf!cfloat; + alias z = geqrf!cdouble; } /// @@ -508,6 +668,8 @@ unittest { alias s = getrs!float; alias d = getrs!double; + alias c = getrs!cfloat; + alias z = getrs!cdouble; } /// @@ -537,6 +699,8 @@ unittest { alias s = potrs!float; alias d = potrs!double; + alias c = potrs!cfloat; + alias z = potrs!cdouble; } /// @@ -568,6 +732,8 @@ unittest { alias s = sytrs2!float; alias d = sytrs2!double; + alias c = sytrs2!cfloat; + alias z = sytrs2!cdouble; } /// @@ -598,6 +764,8 @@ version(none) unittest { alias s = geqrs!float; alias d = geqrs!double; + alias c = geqrs!cfloat; + alias z = geqrs!cdouble; } /// @@ -627,6 +795,8 @@ unittest { alias s = sysv_rook_wk!float; alias d = sysv_rook_wk!double; + alias c = sysv_rook_wk!cfloat; + alias z = sysv_rook_wk!cdouble; } /// @@ -659,6 +829,8 @@ unittest { alias s = sysv_rook!float; alias d = sysv_rook!double; + alias c = sysv_rook!cfloat; + alias z = sysv_rook!cdouble; } /// @@ -800,6 +972,8 @@ unittest { alias s = potrf!float; alias d = potrf!double; + alias c = potrf!cfloat; + alias z = potrf!cdouble; } /// @@ -874,6 +1048,8 @@ unittest { alias s = sptri!float; alias d = sptri!double; + alias c = sptri!cfloat; + alias z = sptri!cdouble; } /// @@ -898,6 +1074,8 @@ unittest { alias s = potri!float; alias d = potri!double; + alias c = potri!cfloat; + alias z = potri!cdouble; } /// @@ -938,6 +1116,8 @@ unittest { alias s = pptri!float; alias d = pptri!double; + alias c = pptri!cfloat; + alias z = pptri!cdouble; } /// @@ -963,6 +1143,8 @@ unittest { alias s = trtri!float; alias d = trtri!double; + alias c = trtri!cfloat; + alias z = trtri!cdouble; } /// @@ -1006,6 +1188,8 @@ unittest { alias s = tptri!float; alias d = tptri!double; + alias c = tptri!cfloat; + alias z = tptri!cdouble; } ///