diff --git a/Project.toml b/Project.toml index 048e597f..6ca22ef9 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10" LinearAlgebra = "1" -Mooncake = "0.4.195" +Mooncake = "0.5" ParallelTestRunner = "2" Random = "1" SafeTestsets = "0.1" diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 0ea1b018..dc15dde9 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -51,14 +51,10 @@ MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.c MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem -make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA -make_mooncake_tangent(ΔA::AbstractVector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::AbstractMatrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::AbstractVector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) +make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem +make_mooncake_tangent(ΔA::Matrix) = ΔA +make_mooncake_tangent(ΔA::Vector) = ΔA +make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...)