diff --git a/Project.toml b/Project.toml index 1395ead..ddec63f 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -31,6 +32,7 @@ DocStringExtensions = "0.9.3" Flux = "0.14.16, 0.15, 0.16" Functors = "0.4.11, 0.5" LinearAlgebra = "1.10" +NNlib = "0.9" Random = "1.10" Reexport = "1.2.2" SafeTestsets = "0.1" diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 332ac77..139f69d 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -16,6 +16,17 @@ using Tracker using CUDA, cuDNN using Random using SparseArrays +using NNlib + +# Disable NNlib's `tanh_fast` substitution on CuArrays. Inside +# `NNlib.bias_act!`, the activation is wrapped into a `ComposedFunction +# (σ, +)` that is broadcast on GPU. With the polynomial `tanh_fast`, +# GPU compilation of the resulting `Broadcasted{ComposedFunction{...}}` +# kernel hits an unsupported dynamic dispatch through the +# `ComposedFunction` kwsorter, while the device intrinsic `tanh` +# compiles cleanly. Same fix as suggested for Metal in +# https://github.com/FluxML/Flux.jl/issues/2633. +NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh abstract type HighDimPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end abstract type AbstractPDEProblem <: SciMLBase.AbstractSciMLProblem end diff --git a/test/qa.jl b/test/qa.jl index 4ab460f..e8cb757 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -1,9 +1,15 @@ using HighDimPDE, Aqua +using NNlib: NNlib @testset "Aqua" begin Aqua.find_persistent_tasks_deps(HighDimPDE) Aqua.test_ambiguities(HighDimPDE, recursive = false) Aqua.test_deps_compat(HighDimPDE) - Aqua.test_piracies(HighDimPDE) + # `NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh` is intentional + # type piracy: a CUDA-side opt-out of NNlib's `tanh_fast` substitution + # that works around an `InvalidIRError` when broadcasting + # `ComposedFunction{tanh_fast, +}` on the GPU. See FluxML/Flux.jl#2633 + # for the analogous Metal report and resolution. + Aqua.test_piracies(HighDimPDE, treat_as_own = [NNlib.fast_act]) Aqua.test_project_extras(HighDimPDE) Aqua.test_stale_deps(HighDimPDE) Aqua.test_unbound_args(HighDimPDE)