From 1c0769b4e9e67123c6c8d7c63b3f69e04d32e1f7 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Wed, 6 May 2026 00:25:57 -0400 Subject: [PATCH 1/2] Override NNlib.fast_act for CuArray to fix GPU tanh compilation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DeepSplitting GPU tests with `Dense(d, hls, tanh)` activations have been failing on the self-hosted GPU runner since the SciMLBase v3 / CUDA 6 / NNlib 0.9.3x bump: InvalidIRError: compiling MethodInstance for gpu_broadcast_kernel_cartesian Reason: unsupported dynamic function invocation (call to var"#_#103"(kw::Base.Pairs{...}, c::ComposedFunction, x...) @ Base operators.jl:1041) Inside `NNlib.bias_act!`, the activation gets wrapped as a `ComposedFunction(fast_act(σ, x), +)` and the result is broadcast on GPU. For `σ = tanh`, `fast_act` substitutes `tanh_fast` (a polynomial approximation), and the resulting `ComposedFunction{tanh_fast, +}` broadcast kernel hits a dynamic dispatch on the `ComposedFunction` kwsorter that the GPU compiler rejects. The same construction with the device intrinsic `tanh` compiles cleanly — verified by Carlo Lucibello on Metal in FluxML/Flux.jl#2633. NNlib already exposes a per-array-type opt-out for exactly this case (see `NNlib.fast_act` in NNlib/src/activations.jl:897-906). Add the CuArray override so `Dense(_, _, tanh)` falls back to `Base.tanh` on the GPU. NNlib is added as a direct dep so the override is unambiguous rather than relying on Flux's transitive load. This restores the 4 previously-failing DeepSplitting GPU tests (`allen cahn`, `Black-Scholes Equation with Default Risk`, `replicator mutator`, `allen cahn non local - Neumann BC`) — the same set that was failing on `main` since PR #137. 
Co-Authored-By: Chris Rackauckas --- Project.toml | 2 ++ src/HighDimPDE.jl | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/Project.toml b/Project.toml index 1395ead..ddec63f 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -31,6 +32,7 @@ DocStringExtensions = "0.9.3" Flux = "0.14.16, 0.15, 0.16" Functors = "0.4.11, 0.5" LinearAlgebra = "1.10" +NNlib = "0.9" Random = "1.10" Reexport = "1.2.2" SafeTestsets = "0.1" diff --git a/src/HighDimPDE.jl b/src/HighDimPDE.jl index 332ac77..139f69d 100644 --- a/src/HighDimPDE.jl +++ b/src/HighDimPDE.jl @@ -16,6 +16,17 @@ using Tracker using CUDA, cuDNN using Random using SparseArrays +using NNlib + +# Disable NNlib's `tanh_fast` substitution on CuArrays. Inside +# `NNlib.bias_act!`, the activation is wrapped into +# `ComposedFunction(σ, +)` that is broadcast on GPU. With the polynomial +# `tanh_fast`, GPU compilation of the resulting +# `Broadcasted{ComposedFunction{...}}` kernel hits an unsupported +# dynamic dispatch through the `ComposedFunction` kwsorter, while the +# device intrinsic `tanh` compiles cleanly. Same fix as suggested for +# Metal in https://github.com/FluxML/Flux.jl/issues/2633. 
+NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh abstract type HighDimPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end abstract type AbstractPDEProblem <: SciMLBase.AbstractSciMLProblem end From 712fea1124a2d741569d4b9c0b98eddc44b34889 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Wed, 6 May 2026 01:07:00 -0400 Subject: [PATCH 2/2] Whitelist NNlib.fast_act override in Aqua's piracy check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The intentional `NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh` opt-out from the previous commit is type piracy by design — that's exactly the shape NNlib's per-array-type fast_act API was built for. Tell `Aqua.test_piracies` to ignore methods we add to `NNlib.fast_act`. Co-Authored-By: Chris Rackauckas --- test/qa.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/qa.jl b/test/qa.jl index 4ab460f..e8cb757 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -1,9 +1,15 @@ using HighDimPDE, Aqua +using NNlib: NNlib @testset "Aqua" begin Aqua.find_persistent_tasks_deps(HighDimPDE) Aqua.test_ambiguities(HighDimPDE, recursive = false) Aqua.test_deps_compat(HighDimPDE) - Aqua.test_piracies(HighDimPDE) + # `NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh` is intentional + # type piracy: a CUDA-side opt-out of NNlib's `tanh_fast` substitution + # that works around an `InvalidIRError` when broadcasting + # `ComposedFunction{tanh_fast, +}` on the GPU. See FluxML/Flux.jl#2633 + # for the analogous Metal report and resolution. + Aqua.test_piracies(HighDimPDE, treat_as_own = [NNlib.fast_act]) Aqua.test_project_extras(HighDimPDE) Aqua.test_stale_deps(HighDimPDE) Aqua.test_unbound_args(HighDimPDE)