@@ -454,91 +454,4 @@ function EnzymeRules.reverse(
454454 return (nothing , nothing , nothing )
455455end
456456
457- # single-output projections: project_hermitian!, project_antihermitian!
458- # single-output projections: project_hermitian!, project_antihermitian!
459- for (f!, project_f) in (
460- (project_hermitian!, project_hermitian),
461- (project_antihermitian!, project_antihermitian),
462- )
463- @eval begin
464- function EnzymeRules. augmented_primal (
465- config:: EnzymeRules.RevConfigWidth{1} ,
466- func:: Const{typeof($f!)} ,
467- :: Type{RT} ,
468- A:: Annotation ,
469- arg:: Annotation{TA} ,
470- alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
471- ) where {RT, TA}
472- ret = func. val (A. val, arg. val, alg. val)
473- cache_arg = (arg. val != = ret) || EnzymeRules. overwritten (config)[3 ] ? copy (ret) : nothing
474- dret = if EnzymeRules. needs_shadow (config)
475- (TA == Nothing || isa (arg, Const)) ? zero (ret) : arg. dval
476- else
477- nothing
478- end
479- primal = EnzymeRules. needs_primal (config) ? ret : nothing
480- return EnzymeRules. AugmentedReturn (primal, dret, (cache_arg, dret))
481- end
482- function EnzymeRules. reverse (
483- config:: EnzymeRules.RevConfigWidth{1} ,
484- func:: Const{typeof($f!)} ,
485- :: Type{RT} ,
486- cache,
487- A:: Annotation ,
488- arg:: Annotation ,
489- alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
490- ) where {RT}
491- cache_arg, darg = cache
492- argdval = something (darg, arg. dval)
493- if ! isa (A, Const)
494- A. dval .+ = $ project_f (argdval)
495- end
496- ! isa (arg, Const) && make_zero! (arg. dval)
497- return (nothing , nothing , nothing )
498- end
499- end
500- end
501-
502- # project_isometric! needs special handling: compute full polar decomposition
503- function EnzymeRules. augmented_primal (
504- config:: EnzymeRules.RevConfigWidth{1} ,
505- func:: Const{typeof(project_isometric!)} ,
506- :: Type{RT} ,
507- A:: Annotation ,
508- W:: Annotation{TW} ,
509- alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
510- ) where {RT, TW}
511- # Compute the full polar decomposition for the pullback
512- Ac = copy (A. val)
513- m, n = size (A. val)
514- P = similar (A. val, n, n)
515- WP = left_polar! (Ac, (W. val, P), alg. val)
516- cache_WP = EnzymeRules. overwritten (config)[3 ] ? copy .(WP) : nothing
517- dret = if EnzymeRules. needs_shadow (config)
518- (TW == Nothing || isa (W, Const)) ? zero (WP[1 ]) : W. dval
519- else
520- nothing
521- end
522- primal = EnzymeRules. needs_primal (config) ? WP[1 ] : nothing
523- return EnzymeRules. AugmentedReturn (primal, dret, (cache_WP, dret))
524- end
525- function EnzymeRules. reverse (
526- config:: EnzymeRules.RevConfigWidth{1} ,
527- func:: Const{typeof(project_isometric!)} ,
528- :: Type{RT} ,
529- cache,
530- A:: Annotation ,
531- W:: Annotation ,
532- alg:: Const{<:MatrixAlgebraKit.AbstractAlgorithm} ,
533- ) where {RT}
534- cache_WP, dW = cache
535- Aval = nothing
536- WPval = something (cache_WP, (W. val, cache_WP[2 ]))
537- if ! isa (A, Const)
538- left_polar_pullback! (A. dval, Aval, WPval, (dW, nothing ))
539- end
540- ! isa (W, Const) && make_zero! (W. dval)
541- return (nothing , nothing , nothing )
542- end
543-
544457end
0 commit comments