Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ updates:
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
ignore:
- dependency-name: "crate-ci/typos"
update-types: ["version-update:semver-patch", "version-update:semver-minor"]
- package-ecosystem: "julia"
directories:
- "/"
Expand Down
21 changes: 21 additions & 0 deletions .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: "Documentation"

on:
pull_request:
branches:
- main
- 'release-'
push:
branches:
- main
tags: '*'

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref_name != github.event.repository.default_branch || github.ref != 'refs/tags/v*' }}

jobs:
docs:
name: Documentation
uses: "SciML/.github/.github/workflows/documentation.yml@v1"
secrets: "inherit"
33 changes: 33 additions & 0 deletions .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Downgrade

on:
pull_request:
branches:
- main
- 'release-'
paths-ignore:
- 'docs/**'
push:
branches:
- main
paths-ignore:
- 'docs/**'

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref_name != github.event.repository.default_branch || github.ref != 'refs/tags/v*' }}

jobs:
downgrade:
name: Downgrade
strategy:
fail-fast: false
matrix:
group:
- Core
uses: "SciML/.github/.github/workflows/downgrade.yml@v1"
with:
group: "${{ matrix.group }}"
julia-version: "lts"
skip: "Pkg,TOML"
secrets: "inherit"
12 changes: 3 additions & 9 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@ on:

jobs:
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v3
with:
version: '1'
- uses: fredrikekre/runic-action@v1
with:
version: '1'
name: Runic
uses: "SciML/.github/.github/workflows/runic.yml@v1"
secrets: "inherit"
9 changes: 9 additions & 0 deletions .github/workflows/SpellCheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: Spell Check

on: [pull_request]

jobs:
typos-check:
name: Spell Check with Typos
uses: "SciML/.github/.github/workflows/spellcheck.yml@v1"
secrets: "inherit"
20 changes: 0 additions & 20 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,3 @@ jobs:
julia-version: "${{ matrix.version }}"
julia-arch: "${{ matrix.arch }}"
secrets: "inherit"

docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v3
with:
version: '1'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
- run: |
julia --project=docs -e '
using Documenter: DocMeta, doctest
using FunctionWrappersWrappers
DocMeta.setdocmeta!(FunctionWrappersWrappers, :DocTestSetup, :(using FunctionWrappersWrappers); recursive=true)
doctest(FunctionWrappersWrappers)'
136 changes: 69 additions & 67 deletions ext/FunctionWrappersWrappersEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ EnzymeCore.EnzymeRules.inactive_type(::Type{NoCacheStorage}) = true
# `ForwardMode` matching the outer config so the delegated call inherits
# those flags.
@inline function _fwd_mode(
::Val{NeedsPrimal}, ::Val{RuntimeActivity}, ::Val{StrongZero}
) where {NeedsPrimal, RuntimeActivity, StrongZero}
::Val{NeedsPrimal}, ::Val{RuntimeActivity}, ::Val{StrongZero}
) where {NeedsPrimal, RuntimeActivity, StrongZero}
mode = NeedsPrimal ? ForwardWithPrimal : Forward
RuntimeActivity && (mode = Enzyme.set_runtime_activity(mode))
StrongZero && (mode = Enzyme.set_strong_zero(mode))
Expand All @@ -63,11 +63,11 @@ end
# meaningful tangent — so the function shadow is ignored and the inner
# `Enzyme.autodiff` call uses `Const(f_orig)`.
function EnzymeRules.forward(
::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N, RuntimeActivity, StrongZero}
::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N, RuntimeActivity, StrongZero}
f_orig = unwrap(func.val)
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
if W == 1
Expand All @@ -86,11 +86,11 @@ end

# Both primal and shadow (ForwardWithPrimal mode)
function EnzymeRules.forward(
::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N, RuntimeActivity, StrongZero}
::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N, RuntimeActivity, StrongZero}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
primal = f_orig(pargs...)::T
Expand All @@ -113,11 +113,11 @@ end

# Primal only (Const return type) — width-independent
function EnzymeRules.forward(
::EnzymeRules.FwdConfig{true, false, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation},
args::Vararg{EnzymeCore.Annotation, N}
) where {W, N, RuntimeActivity, StrongZero}
::EnzymeRules.FwdConfig{true, false, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation},
args::Vararg{EnzymeCore.Annotation, N}
) where {W, N, RuntimeActivity, StrongZero}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
return f_orig(pargs...)
Expand All @@ -137,13 +137,13 @@ end
# IMPORTANT: forward the `RuntimeActivity` and `StrongZero` flags from the
# outer config into the delegated `Enzyme.autodiff` call. Prior to this
# fix the rule hard-coded `Forward`, silently dropping
# `set_runtime_activity(Forward)` on the way down into `f_orig`.
# `set_runtime_activity(Forward)` on the way down into `f_orig`.
function EnzymeRules.forward(
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation},
args::Vararg{EnzymeCore.Annotation, N}
) where {W, N, RuntimeActivity, StrongZero}
::EnzymeRules.FwdConfig{false, false, W, RuntimeActivity, StrongZero},
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Annotation},
args::Vararg{EnzymeCore.Annotation, N}
) where {W, N, RuntimeActivity, StrongZero}
f_orig = unwrap(func.val)
mode = _fwd_mode(Val(false), Val(RuntimeActivity), Val(StrongZero))
Enzyme.autodiff(mode, Const(f_orig), Const, args...)
Expand All @@ -155,11 +155,11 @@ end
# =============================================================================

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Active{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Active{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, N}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
result = f_orig(pargs...)::T
Expand All @@ -175,11 +175,11 @@ end
# return). Just run the primal for its side effects; no tape is needed because
# the reverse pass has nothing to propagate back from the return.
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Const},
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Const},
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
f_orig(pargs...)
Expand All @@ -189,11 +189,11 @@ end
# Duplicated / BatchDuplicated return: record the primal so that reverse has
# it available when propagating dret through the arguments.
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Duplicated{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.Duplicated{T}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, N}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
primal = f_orig(pargs...)::T
Expand All @@ -205,11 +205,11 @@ function EnzymeRules.augmented_primal(
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.BatchDuplicated{T, W}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
RT::Type{<:EnzymeCore.BatchDuplicated{T, W}},
args::Vararg{EnzymeCore.Annotation, N}
) where {T, W, N}
f_orig = unwrap(func.val)
pargs = ntuple(i -> args[i].val, Val(N))
primal = f_orig(pargs...)::T
Expand All @@ -235,41 +235,43 @@ end
# function, then scale by dret. This avoids type-inference issues that arise
# from calling autodiff(Reverse, Const{Any}(...), ...).
@generated function _fww_reverse_grads(
mode, f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
) where {T, N}
mode, f_orig, dret_val::T, args::Vararg{EnzymeCore.Active, N}
) where {T, N}
# Build forward-mode calls for each partial derivative
exprs = []
for i in 1:N
seeds = [j == i ? :(one(eltype(typeof(args[$j])))) : :(zero(eltype(typeof(args[$j])))) for j in 1:N]
dups = [:(Duplicated(args[$j].val, $(seeds[j]))) for j in 1:N]
Ti = :(eltype(typeof(args[$i])))
push!(exprs, quote
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated{$T}, $(dups...))
$Ti(fwd[1] * dret_val)::$Ti
end)
push!(
exprs, quote
fwd = Enzyme.autodiff(mode, Const(f_orig), Duplicated{$T}, $(dups...))
$Ti(fwd[1] * dret_val)::$Ti
end
)
end
return Expr(:tuple, exprs...)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::EnzymeCore.Active{T},
tape,
args::Vararg{EnzymeCore.Active, N}
) where {T, N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::EnzymeCore.Active{T},
tape,
args::Vararg{EnzymeCore.Active, N}
) where {T, N}
f_orig = unwrap(func.val)
return _fww_reverse_grads(_fwd_mode_from_rev(config), f_orig, dret.val, args...)
end

# Handle mixed Active/Const args: return nothing for Const, gradient for Active
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::EnzymeCore.Active,
tape,
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::EnzymeCore.Active,
tape,
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
f_orig = unwrap(func.val)
dret_val = dret.val
mode = _fwd_mode_from_rev(config)
Expand Down Expand Up @@ -303,12 +305,12 @@ end
# `BatchDuplicated` args return `nothing` because their gradients are
# accumulated in-place by the `Enzyme.autodiff(Reverse, …)` call above.
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::Type{<:EnzymeCore.Const},
tape,
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
config::EnzymeRules.RevConfig,
func::EnzymeCore.Annotation{<:FunctionWrappersWrapper},
dret::Type{<:EnzymeCore.Const},
tape,
args::Vararg{EnzymeCore.Annotation, N}
) where {N}
f_orig = unwrap(func.val)
# Only worth invoking Enzyme.autodiff when at least one arg is
# Duplicated/BatchDuplicated — otherwise there's nothing to accumulate.
Expand Down
11 changes: 5 additions & 6 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ end
# du_shadow gives the accumulation into u_shadow:
# u_shadow[1] += a*y + b*2x
# u_shadow[2] += a*x + b*3y^2
f!(du, u) = (du[1] = u[1]*u[2]; du[2] = u[1]^2 + u[2]^3; nothing)
f!(du, u) = (du[1] = u[1] * u[2]; du[2] = u[1]^2 + u[2]^3; nothing)
fww = FunctionWrappersWrapper(
f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,)
)
Expand All @@ -300,9 +300,9 @@ end
Reverse, Const(fww), Const,
Duplicated(du, du_shadow), Duplicated(u, u_shadow)
)
@test du ≈ [x*y, x^2 + y^3]
@test u_shadow[1] ≈ a*y + b*2*x # 5 + 2 = 7
@test u_shadow[2] ≈ a*x + b*3*y^2 # 2 + 37.5 = 39.5
@test du ≈ [x * y, x^2 + y^3]
@test u_shadow[1] ≈ a * y + b * 2 * x # 5 + 2 = 7
@test u_shadow[2] ≈ a * x + b * 3 * y^2 # 2 + 37.5 = 39.5
end

@testset "Enzyme ReverseWithPrimal: IIP with Duplicated args" begin
Expand Down Expand Up @@ -432,7 +432,7 @@ end
)

du = [0.0]; du_shadow = [1.0]
u = [3.0]; u_shadow = [0.0]
u = [3.0]; u_shadow = [0.0]

rconfig = EnzymeRules.RevConfig{false, false, 1, (false, false), false, false}()
aug = EnzymeRules.augmented_primal(
Expand Down Expand Up @@ -515,4 +515,3 @@ end
@test du_h[1] ≈ 2.0 * 3.5 # primal: u[1] * t = 7.0
@test ddu_h[1] ≈ 2.0 # ∂(u[1]*t)/∂t = u[1]
end

Loading