
Error when differentiating complex-valued ODE with Mooncake #1321

@albertomercurio

Description

Describe the bug 🐞

I get an error when differentiating a complex-valued ODE using either MooncakeVJP or MooncakeAdjoint.

Expected behavior

The gradient should be computed without error, as it is when T = Float64.

Minimal Reproducible Example 👇

using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using DifferentiationInterface
using Mooncake
using SciMLSensitivity
using SciMLSensitivity: MooncakeVJP

# %%

const T = ComplexF64

function dudt!(du, u, p, t)
    du[1] = -p[1]*u[2]
    du[2] = p[2]*u[1]

    return nothing
end

function my_fun2(p)
    x = T[3.0, 4.0]
    prob = ODEProblem{true}(dudt!, x, (0.0, 100.0), T.(p))
    sol = solve(prob, Tsit5(), save_everystep=false,
        sensealg=BacksolveAdjoint(autojacvec=MooncakeVJP()),
        # sensealg=MooncakeAdjoint(),
    )
    return real(sol.u[end][end])
end

p = [1.0, 2.0];

my_fun2(p)

# %%

backend = AutoMooncake(; config=nothing)
prep = prepare_gradient(my_fun2, backend, p)
DifferentiationInterface.gradient(my_fun2, prep, backend, p)

Error & Stacktrace ⚠️

Details
ERROR: MethodError: Cannot `convert` an object of type 
  MooncakeFunctionWrappersExt.FunctionWrapperTangent{Core.OpaqueClosure{Tuple{Mooncake.CoDual{Array{Complex{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag,Complex{Float64}},Complex{Float64},1}},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}},Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}}}}},1}},Mooncake.CoDual{Array{Complex{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag,Complex{Float64}},Complex{Float64},1}},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}},Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}}}}},1}},Mooncake.CoDual{Array{Complex{Float64},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},1}},Mooncake.CoDual{Float64,Mooncake.NoFData}},Union{}}} to an object of type 
  MooncakeFunctionWrappersExt.FunctionWrapperTangent{Core.OpaqueClosure{Tuple{Mooncake.CoDual{Array{Complex{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag,Complex{Float64}},Complex{Float64},1}},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}},Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}}}}},1}},Mooncake.CoDual{Array{Complex{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag,Complex{Float64}},Complex{Float64},1}},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}},Mooncake.Tangent{NamedTuple{(:value, :partials),Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},Mooncake.Tangent{NamedTuple{(:values,),Tuple{Tuple{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}}}}}}}}}}}},1}},Mooncake.CoDual{Array{Complex{Float64},1},Array{Mooncake.Tangent{NamedTuple{(:re, :im),Tuple{Float64,Float64}}},1}},Mooncake.CoDual{Float64,Mooncake.NoFData}},Tuple{Mooncake.CoDual{Nothing, Mooncake.NoFData}, Core.OpaqueClosure{Tuple{Mooncake.NoRData}, Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, Float64}}}}}
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::Type{MooncakeFunctionWrappersExt.FunctionWrapperTangent{Tfwds_oc}} where Tfwds_oc)(::Any, ::Any)
   @ MooncakeFunctionWrappersExt ~/.julia/packages/Mooncake/pjlQh/ext/MooncakeFunctionWrappersExt.jl:44
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126

Stacktrace:
  [1] cvt1
    @ ./essentials.jl:612 [inlined]
  [2] macro expansion
    @ ./ntuple.jl:72 [inlined]
  [3] ntuple
    @ ./ntuple.jl:69 [inlined]
  [4] convert(::Type{…}, x::Tuple{…})
    @ Base ./essentials.jl:614
  [5] cvt1
    @ ./essentials.jl:612 [inlined]
  [6] ntuple
    @ ./ntuple.jl:48 [inlined]
  [7] convert(::Type{Tuple{…}}, x::Tuple{Tuple{…}})
    @ Base ./essentials.jl:614
  [8] Tuple{Tuple{…}}(x::Tuple{Tuple{…}})
    @ Base ./tuple.jl:450
  [9] @NamedTuple{fw::Tuple{…}}(args::Tuple{Tuple{…}})
    @ Base ./namedtuple.jl:121
 [10] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [11] macro expansion
    @ ./none:0 [inlined]
 [12] zero_tangent_internal(x::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{…}, false}, d::IdDict{Any, Any})
    @ Mooncake ./none:0
 [13] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [14] macro expansion
    @ ./none:0 [inlined]
 [15] zero_tangent_internal(x::SciMLBase.Void{FunctionWrappersWrappers.FunctionWrappersWrapper{…}}, d::IdDict{Any, Any})
    @ Mooncake ./none:0
 [16] zero_tangent_internal(p::FunctionWrappers.FunctionWrapper{Nothing, Tuple{…}}, dict::IdDict{Any, Any})
    @ MooncakeFunctionWrappersExt ~/.julia/packages/Mooncake/pjlQh/ext/MooncakeFunctionWrappersExt.jl:138
 [17] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:544 [inlined]
 [18] zero_tangent_internal
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:540 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [20] macro expansion
    @ ./none:0 [inlined]
 [21] zero_tangent_internal(x::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{…}, false}, d::IdDict{Any, Any})
    @ Mooncake ./none:0
 [22] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [23] macro expansion
    @ ./none:0 [inlined]
 [24] zero_tangent_internal
    @ ./none:0 [inlined]
 [25] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [26] macro expansion
    @ ./none:0 [inlined]
 [27] zero_tangent_internal(x::ODEProblem{…}, d::IdDict{…})
    @ Mooncake ./none:0
 [28] macro expansion
    @ ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:592 [inlined]
 [29] macro expansion
    @ ./none:0 [inlined]
 [30] zero_tangent_internal(x::ODESolution{…}, d::IdDict{…})
    @ Mooncake ./none:0
 [31] zero_tangent(x::ODESolution{…})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/tangents/tangents.jl:516
 [32] rrule_wrapper(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/tools_for_rules.jl:484
 [33] rrule!!
    @ ./array.jl:0 [inlined]
 [34] #solve#37
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:575 [inlined]
 [35] (::Tuple{…})(_2::Mooncake.CoDual{…}, _3::Mooncake.CoDual{…}, _4::Mooncake.CoDual{…}, _5::Mooncake.CoDual{…}, _6::Mooncake.CoDual{…}, _7::Mooncake.CoDual{…}, _8::Mooncake.CoDual{…}, _9::Mooncake.CoDual{…}, _10::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [36] (::MistyClosures.MistyClosure{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [37] DerivedRule
    @ ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:990 [inlined]
 [38] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:1893
 [39] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:1887 [inlined]
 [40] my_fun2
    @ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/issue_example.jl:27 [inlined]
 [41] (::Tuple{…})(_2::Mooncake.CoDual{…}, _3::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [42] (::MistyClosures.MistyClosure{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [43] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:990
 [44] prepare_gradient_cache(::Function, ::Vararg{…}; friendly_tangents::Bool, kwargs::@Kwargs{})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/interface.jl:584
 [45] prepare_gradient_cache
    @ ~/.julia/packages/Mooncake/pjlQh/src/interface.jl:581 [inlined]
 [46] prepare_gradient_nokwarg(::Val{true}, ::typeof(my_fun2), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/6H4dc/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:124
 [47] #prepare_gradient#68
    @ ~/.julia/packages/DifferentiationInterface/6H4dc/src/first_order/gradient.jl:11 [inlined]
 [48] prepare_gradient(::typeof(my_fun2), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/6H4dc/src/first_order/gradient.jl:8
 [49] top-level scope
    @ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/issue_example.jl:43
Some type information was truncated. Use `show(err)` to see complete types.

Everything works if I instead use T = Float64.
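Since the parameters and the final observable are real here, one workaround that seems to sidestep the complex-tangent machinery entirely is to split the complex state into real and imaginary parts and integrate a purely real ODE of doubled size. A sketch (untested against this exact setup; dudt_real! and my_fun2_real are hypothetical names):

```julia
# Sketch of a possible workaround: represent u = ur + im*ui as the real
# vector [ur; ui] and apply the (real-parameter) linear RHS blockwise.
function dudt_real!(du, u, p, t)
    # u = [re(u1), re(u2), im(u1), im(u2)]
    du[1] = -p[1] * u[2]
    du[2] =  p[2] * u[1]
    du[3] = -p[1] * u[4]
    du[4] =  p[2] * u[3]
    return nothing
end

function my_fun2_real(p)
    x = [3.0, 4.0, 0.0, 0.0]  # real parts of T[3.0, 4.0], imaginary parts zero
    prob = ODEProblem{true}(dudt_real!, x, (0.0, 100.0), p)
    sol = solve(prob, Tsit5(), save_everystep=false,
        sensealg=BacksolveAdjoint(autojacvec=MooncakeVJP()))
    return sol.u[end][2]  # corresponds to real(u[2]) in the complex formulation
end
```

This reproduces the same real observable as my_fun2, but of course it does not address the underlying complex-valued support in Mooncake/SciMLSensitivity.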

Moreover, if I use a sparse-matrix version of the RHS

const A1 = sparse(T[0.0 1.0; 0.0 0.0])
const A2 = sparse(T[0.0 0.0; 1.0 0.0])

function dudt!(du, u, p, t)
    mul!(du, A1, u, -p[1], zero(T))
    mul!(du, A2, u, p[2], one(T))
    return nothing
end

I get a different error:

Error & Stacktrace ⚠️

Details
ERROR: MethodError: Cannot `convert` an object of type ChainRulesCore.Tangent{Any, @NamedTuple{re::Float64, im::Float64}} to an object of type ComplexF64
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::Type{Complex{T}} where T<:Real)(::Any, ::Any)
   @ Base complex.jl:14
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126
  convert(::Type{T}, ::Union{Static.StaticBool{N}, Static.StaticFloat64{N}, Static.StaticInt{N}} where N) where T<:Number
   @ Static ~/.julia/packages/Static/d7YOk/src/Static.jl:435
  ...

Stacktrace:
  [1] setindex!(A::Vector{ComplexF64}, x::ChainRulesCore.Tangent{Any, @NamedTuple{re::Float64, im::Float64}}, i::Int64)
    @ Base ./array.jl:987
  [2] setindex!
    @ ./subarray.jl:384 [inlined]
  [3] copyto_unaliased!
    @ ./abstractarray.jl:1081 [inlined]
  [4] copyto!
    @ ./abstractarray.jl:1061 [inlined]
  [5] copyto!
    @ ./broadcast.jl:966 [inlined]
  [6] copyto!
    @ ./broadcast.jl:925 [inlined]
  [7] materialize!
    @ ./broadcast.jl:883 [inlined]
  [8] materialize!
    @ ./broadcast.jl:880 [inlined]
  [9] (::SciMLSensitivity.var"#df_iip#305"{})(_out::SubArray{…}, u::Vector{…}, p::Vector{…}, t::Float64, i::Int64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/iaNlp/src/concrete_solve.jl:603
 [10] (::SciMLSensitivity.ReverseLossCallback{…})(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/iaNlp/src/adjoint_common.jl:589
 [11] (::DiffEqCallbacks.PresetTimeFunction{…})(c::DiscreteCallback{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/uOnCI/src/preset_time.jl:30
 [12] initialize!
    @ ~/.julia/packages/DiffEqBase/kizyx/src/callbacks.jl:24 [inlined]
 [13] initialize!
    @ ~/.julia/packages/DiffEqBase/kizyx/src/callbacks.jl:18 [inlined]
 [14] initialize!
    @ ~/.julia/packages/DiffEqBase/kizyx/src/callbacks.jl:7 [inlined]
 [15] initialize_callbacks!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, initialize_save::Bool)
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/5Ctu8/src/solve.jl:724
 [16] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_discretes::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::ODEAliasSpecifier, initializealg::BrownFullBasicInit{…}, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/5Ctu8/src/solve.jl:580
 [17] __init (repeats 2 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/5Ctu8/src/solve.jl:11 [inlined]
 [18] #__solve#60
    @ ~/.julia/packages/OrdinaryDiffEqCore/5Ctu8/src/solve.jl:6 [inlined]
 [19] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/5Ctu8/src/solve.jl:1 [inlined]
 [20] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:172
 [21] solve_call
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:137 [inlined]
 [22] #solve_up#38
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:626 [inlined]
 [23] solve_up
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:599 [inlined]
 [24] #solve#37
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:583 [inlined]
 [25] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::BacksolveAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, no_start::Bool, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/iaNlp/src/sensitivity_interface.jl:459
 [26] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/iaNlp/src/sensitivity_interface.jl:415 [inlined]
 [27] #adjoint_sensitivities#67
    @ ~/.julia/packages/SciMLSensitivity/iaNlp/src/sensitivity_interface.jl:411 [inlined]
 [28] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#304"{})(Δ::ChainRulesCore.Tangent{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/iaNlp/src/concrete_solve.jl:696
 [29] (::Mooncake.var"#pb!!#324"{})(y_rdata::Mooncake.NoRData)
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/tools_for_rules.jl:492
 [30] #solve#37
    @ ~/.julia/packages/DiffEqBase/kizyx/src/solve.jl:575 [inlined]
 [31] (::Tuple{…})(_2::Any)
    @ Base.Experimental ./<missing>:0
 [32] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, NTuple{9, Mooncake.NoRData}}})(x::Mooncake.NoRData)
    @ MistyClosures ~/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [33] Pullback
    @ ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:957 [inlined]
 [34] my_fun2
    @ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/issue_example.jl:27 [inlined]
 [35] (::Tuple{…})(_2::Any)
    @ Base.Experimental ./<missing>:0
 [36] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, Tuple{Mooncake.NoRData, Mooncake.NoRData}}})(x::Float64)
    @ MistyClosures ~/.julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
 [37] (::Mooncake.Pullback{Tuple{…}, Tuple{…}, Tuple{…}, false, 2})(dy::Float64)
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/interpreter/reverse_mode.jl:957
 [38] prepare_gradient_cache(::Function, ::Vararg{…}; friendly_tangents::Bool, kwargs::@Kwargs{})
    @ Mooncake ~/.julia/packages/Mooncake/pjlQh/src/interface.jl:586
 [39] prepare_gradient_cache
    @ ~/.julia/packages/Mooncake/pjlQh/src/interface.jl:581 [inlined]
 [40] prepare_gradient_nokwarg(::Val{true}, ::typeof(my_fun2), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/6H4dc/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:124
 [41] #prepare_gradient#68
    @ ~/.julia/packages/DifferentiationInterface/6H4dc/src/first_order/gradient.jl:11 [inlined]
 [42] prepare_gradient(::typeof(my_fun2), ::AutoMooncake{Nothing}, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/6H4dc/src/first_order/gradient.jl:8
 [43] top-level scope
    @ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/issue_example.jl:43
Some type information was truncated. Use `show(err)` to see complete types.

Differentiating just the RHS works, so I think this issue is related more to SciMLSensitivity than to Mooncake itself:

function test_ad(p)
    u = T[3.0, 4.0]
    du = similar(u)
    dudt!(du, u, p, 0.0)
    return sum(real, du)
end

test_ad(p)

# %%

backend = AutoMooncake(; config=nothing)
prep2 = prepare_gradient(test_ad, backend, p)
DifferentiationInterface.gradient(test_ad, prep2, backend, p) # [-4.0, 3.0]
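For reference, the expected RHS gradient can be derived by hand: with u = [3, 4], test_ad(p) = real(-4*p[1]) + real(3*p[2]) = -4*p[1] + 3*p[2], so the gradient is [-4, 3]. A quick cross-check with FiniteDifferences (already in the environment) would look like this sketch:

```julia
using FiniteDifferences

# Cross-check the Mooncake result with a 5th-order central finite-difference
# stencil; this should agree with [-4.0, 3.0] up to discretization error.
fd_grad = FiniteDifferences.grad(central_fdm(5, 1), test_ad, p)[1]
```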

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
(Autodiff QuantumToolbox) pkg> st
Status `~/GitHub/Research/Undef/Autodiff QuantumToolbox/Project.toml`
  [6e4b80f9] BenchmarkTools v1.6.3
  [d360d2e6] ChainRulesCore v1.26.0
  [a0c0ee7d] DifferentiationInterface v0.7.13
  [7da242da] Enzyme v0.13.114
  [26cc04aa] FiniteDifferences v0.12.33
  [f6369f11] ForwardDiff v1.3.1
  [b27dd330] GTPSA v1.4.10
  [7ed4a6bd] LinearSolve v3.57.0
  [da2b9cff] Mooncake v0.4.198
  [1dea7af3] OrdinaryDiffEq v6.105.0
  [6c2fb7c5] QuantumToolbox v0.41.0
  [c0aeaf25] SciMLOperators v1.14.1
  [1ed8b502] SciMLSensitivity v7.91.0
  [e88e6eb3] Zygote v0.7.10
  • Output of versioninfo()
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 16 default, 0 interactive, 8 GC (on 32 virtual cores)
Environment:
  JULIA_NUM_THREADS = 16
