Skip to content

Commit fbfcbcc

Browse files
committed
remove enzyme
1 parent 62dd5ca commit fbfcbcc

1 file changed

Lines changed: 0 additions & 87 deletions

File tree

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -454,91 +454,4 @@ function EnzymeRules.reverse(
454454
return (nothing, nothing, nothing)
455455
end
456456

457-
# single-output projections: project_hermitian!, project_antihermitian!
458-
# single-output projections: project_hermitian!, project_antihermitian!
459-
for (f!, project_f) in (
460-
(project_hermitian!, project_hermitian),
461-
(project_antihermitian!, project_antihermitian),
462-
)
463-
@eval begin
464-
function EnzymeRules.augmented_primal(
465-
config::EnzymeRules.RevConfigWidth{1},
466-
func::Const{typeof($f!)},
467-
::Type{RT},
468-
A::Annotation,
469-
arg::Annotation{TA},
470-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
471-
) where {RT, TA}
472-
ret = func.val(A.val, arg.val, alg.val)
473-
cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing
474-
dret = if EnzymeRules.needs_shadow(config)
475-
(TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval
476-
else
477-
nothing
478-
end
479-
primal = EnzymeRules.needs_primal(config) ? ret : nothing
480-
return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret))
481-
end
482-
function EnzymeRules.reverse(
483-
config::EnzymeRules.RevConfigWidth{1},
484-
func::Const{typeof($f!)},
485-
::Type{RT},
486-
cache,
487-
A::Annotation,
488-
arg::Annotation,
489-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
490-
) where {RT}
491-
cache_arg, darg = cache
492-
argdval = something(darg, arg.dval)
493-
if !isa(A, Const)
494-
A.dval .+= $project_f(argdval)
495-
end
496-
!isa(arg, Const) && make_zero!(arg.dval)
497-
return (nothing, nothing, nothing)
498-
end
499-
end
500-
end
501-
502-
# project_isometric! needs special handling: compute full polar decomposition
503-
function EnzymeRules.augmented_primal(
504-
config::EnzymeRules.RevConfigWidth{1},
505-
func::Const{typeof(project_isometric!)},
506-
::Type{RT},
507-
A::Annotation,
508-
W::Annotation{TW},
509-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
510-
) where {RT, TW}
511-
# Compute the full polar decomposition for the pullback
512-
Ac = copy(A.val)
513-
m, n = size(A.val)
514-
P = similar(A.val, n, n)
515-
WP = left_polar!(Ac, (W.val, P), alg.val)
516-
cache_WP = EnzymeRules.overwritten(config)[3] ? copy.(WP) : nothing
517-
dret = if EnzymeRules.needs_shadow(config)
518-
(TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval
519-
else
520-
nothing
521-
end
522-
primal = EnzymeRules.needs_primal(config) ? WP[1] : nothing
523-
return EnzymeRules.AugmentedReturn(primal, dret, (cache_WP, dret))
524-
end
525-
function EnzymeRules.reverse(
526-
config::EnzymeRules.RevConfigWidth{1},
527-
func::Const{typeof(project_isometric!)},
528-
::Type{RT},
529-
cache,
530-
A::Annotation,
531-
W::Annotation,
532-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
533-
) where {RT}
534-
cache_WP, dW = cache
535-
Aval = nothing
536-
WPval = something(cache_WP, (W.val, cache_WP[2]))
537-
if !isa(A, Const)
538-
left_polar_pullback!(A.dval, Aval, WPval, (dW, nothing))
539-
end
540-
!isa(W, Const) && make_zero!(W.dval)
541-
return (nothing, nothing, nothing)
542-
end
543-
544457
end

0 commit comments

Comments
 (0)