Skip to content

Commit 98dea08

Browse files
committedJan 23, 2025
Add tests AD of matrixfunctions
1 parent f8ca1fd commit 98dea08

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed
 

‎test/ad.jl

+22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ChainRulesTestUtils
33
using FiniteDifferences: FiniteDifferences
44
using Random
55
using LinearAlgebra
6+
using Zygote
67

78
const _repartition = @static if isdefined(Base, :get_extension)
89
Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
@@ -220,6 +221,27 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
220221
test_rrule(LinearAlgebra.dot, A, B)
221222
end
222223

224+
@timedtestset "Matrix functions ($T)" for T in (Float64, ComplexF64)
225+
for f in (sqrt, exp)
226+
check_inferred = false # !(T <: Real) # not type-stable for real functions
227+
t1 = randn(T, V[1] V[1])
228+
t2 = randn(T, V[2] V[2])
229+
d = DiagonalTensorMap{T}(undef, V[1])
230+
randn!(d.data)
231+
if T <: Real
232+
d.data .= abs.(d.data)
233+
end
234+
d2 = DiagonalTensorMap{T}(undef, V[1])
235+
randn!(d2.data)
236+
if T <: Real
237+
d2.data .= abs.(d2.data)
238+
end
239+
test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred)
240+
test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred)
241+
test_rrule(f, d; check_inferred, output_tangent=d2)
242+
end
243+
end
244+
223245
@timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)
224246
atol = precision(T)
225247
rtol = precision(T)

‎test/runtests.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ include("spaces.jl")
6060
include("tensors.jl")
6161
include("diagonal.jl")
6262
include("planar.jl")
63-
if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS
64-
include("ad.jl")
63+
# TODO: remove once we know AD is slow on macOS CI
64+
test_ad = try
65+
!(Sys.isapple() && ENV["CI"] == true)
66+
catch
67+
true
6568
end
69+
test_ad && include("ad.jl")
6670
include("bugfixes.jl")
6771
Tf = time()
6872
printstyled("Finished all tests in ",

0 commit comments

Comments
 (0)