Skip to content

Commit f6f959c

Browse files
committed
bugfix
1 parent 20faba4 commit f6f959c

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

ext/TensorKitMooncakeExt/tangent.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ Mooncake.frule!!(::Dual{typeof(getfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::
198198

199199
# rrules
200200
function _rrule_getfield_common(t_dt::CoDual{<:DiagOrTensorMap}, field_sym::Symbol, n_args::Int)
201-
t = primal(t)
202-
dt = tangent(t)
201+
t = primal(t_dt)
202+
dt = tangent(t_dt)
203203

204204
value_primal = getfield(t, field_sym)
205205
value_dvalue = Mooncake.CoDual(
@@ -224,13 +224,13 @@ function _rrule_getfield_common(t_dt::CoDual{<:DiagOrTensorMap}, field_sym::Symb
224224
end
225225

226226
Mooncake.rrule!!(::CoDual{typeof(Mooncake.lgetfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual) =
227-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 3)
227+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 3)
228228
Mooncake.rrule!!(::CoDual{typeof(Mooncake.lgetfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual, o_do::CoDual) =
229-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 4)
229+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 4)
230230
Mooncake.rrule!!(::CoDual{typeof(getfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual) =
231-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 3)
231+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 3)
232232
Mooncake.rrule!!(::CoDual{typeof(getfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual, o_do::CoDual) =
233-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 4)
233+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 4)
234234

235235

236236
# Custom rules for constructors

ext/TensorKitMooncakeExt/tensoroperations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function Mooncake.rrule!!(
153153
C_cache = copy(C)
154154
At = if _needs_tangent(α)
155155
At = TO.tensortrace(A, p, q, false, One(), backend)
156-
add!(C, A, α, β)
156+
add!(C, At, α, β)
157157
At
158158
else
159159
TensorKit.trace_permute!(C, A, p, q, α, β, backend)

0 commit comments

Comments
 (0)