Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -31,6 +32,7 @@ DocStringExtensions = "0.9.3"
Flux = "0.14.16, 0.15, 0.16"
Functors = "0.4.11, 0.5"
LinearAlgebra = "1.10"
NNlib = "0.9"
Random = "1.10"
Reexport = "1.2.2"
SafeTestsets = "0.1"
Expand Down
11 changes: 11 additions & 0 deletions src/HighDimPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ using Tracker
using CUDA, cuDNN
using Random
using SparseArrays
using NNlib

# Opt CuArrays out of NNlib's `tanh` -> `tanh_fast` activation substitution.
#
# `NNlib.bias_act!` wraps the activation into a `ComposedFunction (σ, +)`
# and broadcasts it on the GPU. When σ is the polynomial `tanh_fast`, GPU
# compilation of the `Broadcasted{ComposedFunction{...}}` kernel hits an
# unsupported dynamic dispatch through the `ComposedFunction` kwsorter,
# whereas the device intrinsic `tanh` compiles cleanly. The same workaround
# was suggested for Metal in https://github.com/FluxML/Flux.jl/issues/2633.
function NNlib.fast_act(::typeof(tanh), ::CuArray)
    return tanh
end

# Abstract type declared as a subtype of `DiffEqBase.AbstractODEAlgorithm`.
# NOTE(review): presumably the common supertype of this package's solver
# algorithms — confirm against the concrete subtypes defined elsewhere.
abstract type HighDimPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
# Abstract type declared as a subtype of `SciMLBase.AbstractSciMLProblem`.
# NOTE(review): presumably the common supertype of this package's PDE problem
# types — confirm against the concrete subtypes defined elsewhere.
abstract type AbstractPDEProblem <: SciMLBase.AbstractSciMLProblem end
Expand Down
8 changes: 7 additions & 1 deletion test/qa.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
using HighDimPDE, Aqua
using NNlib: NNlib
@testset "Aqua" begin
Aqua.find_persistent_tasks_deps(HighDimPDE)
Aqua.test_ambiguities(HighDimPDE, recursive = false)
Aqua.test_deps_compat(HighDimPDE)
Aqua.test_piracies(HighDimPDE)
# `NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh` is intentional
# type piracy: a CUDA-side opt-out of NNlib's `tanh_fast` substitution
# that works around an `InvalidIRError` when broadcasting
# `ComposedFunction{tanh_fast, +}` on the GPU. See FluxML/Flux.jl#2633
# for the analogous Metal report and resolution.
Aqua.test_piracies(HighDimPDE, treat_as_own = [NNlib.fast_act])
Aqua.test_project_extras(HighDimPDE)
Aqua.test_stale_deps(HighDimPDE)
Aqua.test_unbound_args(HighDimPDE)
Expand Down
Loading