diff --git a/Project.toml b/Project.toml index 75a3004..020412b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FunctionWrappersWrappers" uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" authors = ["Chris Elrod and contributors"] -version = "0.1.5" +version = "1.0.0" [deps] FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" diff --git a/src/FunctionWrappersWrappers.jl b/src/FunctionWrappersWrappers.jl index 5038735..f358ee8 100644 --- a/src/FunctionWrappersWrappers.jl +++ b/src/FunctionWrappersWrappers.jl @@ -4,17 +4,151 @@ using FunctionWrappers import TruncatedStacktraces export FunctionWrappersWrapper, unwrap, wrapped_signatures, wrapped_return_types +export NoCache, SingleCache, DictCache +export Strict, AllowAll, AllowNonIsBits -struct FunctionWrappersWrapper{FW, FB} +# ============================================================================ +# Cache modes: control how fallback FunctionWrappers are cached +# ============================================================================ +abstract type AbstractCacheMode end + +""" + NoCache() + +No caching — every fallback call goes through dynamic dispatch (`obj[](arg...)`), +incurring 1 allocation per call. +""" +struct NoCache <: AbstractCacheMode end + +""" + SingleCache() + +Cache a single `FunctionWrapper` for the last-seen argument types. After the first +fallback call, subsequent calls with the same types are zero-allocation. If called with +different types, the cache is replaced (1 alloc on miss). This is the recommended default. +""" +struct SingleCache <: AbstractCacheMode end + +""" + DictCache() + +Cache `FunctionWrapper`s in a `Dict` keyed by argument type. Handles multiple +non-isbits types without thrashing. Slightly higher lookup overhead than `SingleCache`. +""" +struct DictCache <: AbstractCacheMode end + +# ============================================================================ +# Fallback policies: control when fallback is allowed +# ============================================================================ +abstract type AbstractFallbackPolicy end + +""" + Strict() + +Never fall back — throw `NoFunctionWrapperFoundError` if no wrapper matches. +""" +struct Strict <: AbstractFallbackPolicy end + +""" + AllowAll() + +Always fall back to the original function when no wrapper matches. +""" +struct AllowAll <: AbstractFallbackPolicy end + +""" + AllowNonIsBits() + +Fall back only when argument types contain non-isbits elements (e.g., `BigFloat`, +`SparseConnectivityTracer` types). Throws `NoFunctionWrapperFoundError` for isbits +type mismatches (e.g., `Float32` when `Float64` was expected), which catches bugs. +This is the recommended default. +""" +struct AllowNonIsBits <: AbstractFallbackPolicy end + +# ============================================================================ +# Cache storage types +# ============================================================================ +struct NoCacheStorage end +mutable struct SingleCacheStorage + cached::Any # Union{Nothing, FunctionWrapper} + SingleCacheStorage() = new(nothing) +end +struct DictCacheStorage + cache::Dict{DataType, Any} + DictCacheStorage() = new(Dict{DataType, Any}()) +end + +_make_cache_storage(::NoCache) = NoCacheStorage() +_make_cache_storage(::SingleCache) = SingleCacheStorage() +_make_cache_storage(::DictCache) = DictCacheStorage() + +# ============================================================================ +# Main type +# ============================================================================ + +""" + FunctionWrappersWrapper{FW, P, CS} + +A wrapper around a tuple of `FunctionWrapper`s that dispatches calls to the +matching wrapper based on argument types. When no wrapper matches, behavior is +controlled by the fallback policy `P` and cache mode `CS`. + +# Type parameters +- `FW`: Tuple type of `FunctionWrapper`s +- `P`: Fallback policy (`Strict`, `AllowAll`, or `AllowNonIsBits`) +- `CS`: Cache storage type (`NoCacheStorage`, `SingleCacheStorage`, `DictCacheStorage`) +""" +struct FunctionWrappersWrapper{FW, P, CS} fw::FW + cache_storage::CS + function FunctionWrappersWrapper{FW, P, CS}( + fw::FW, cs::CS + ) where {FW, P, CS} + return new{FW, P, CS}(fw, cs) + end end TruncatedStacktraces.@truncate_stacktrace FunctionWrappersWrapper -function (fww::FunctionWrappersWrapper{FW, FB})(args::Vararg{Any, K}) where {FW, K, FB} +""" + FunctionWrappersWrapper(f, argtypes, rettypes; cache=SingleCache(), policy=AllowNonIsBits()) + +Create a `FunctionWrappersWrapper` with configurable fallback behavior. + +# Arguments +- `f`: The function to wrap +- `argtypes`: Tuple of argument type signatures (e.g., `(Tuple{Float64, Float64},)`) +- `rettypes`: Tuple of return types (e.g., `(Float64,)`) + +# Keywords +- `cache`: Cache mode for fallback path — `NoCache()`, `SingleCache()` (default), or `DictCache()` +- `policy`: Fallback policy — `Strict()`, `AllowAll()`, or `AllowNonIsBits()` (default) +""" +function FunctionWrappersWrapper( + f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}}; + cache::AbstractCacheMode = SingleCache(), + policy::AbstractFallbackPolicy = AllowNonIsBits() + ) where {F, K} + fwt = map(argtypes, rettypes) do A, R + FunctionWrappers.FunctionWrapper{R, A}(f) + end + cs = _make_cache_storage(cache) + return FunctionWrappersWrapper{typeof(fwt), typeof(policy), typeof(cs)}(fwt, cs) +end + + +# ============================================================================ +# Call dispatch — entry point +# ============================================================================ + +function (fww::FunctionWrappersWrapper{FW, P, CS})( + args::Vararg{Any, K} + ) where {FW, K, P, CS} return _call(fww.fw, args, fww) end +# Match path: try each FunctionWrapper in order function _call( fw::Tuple{FunctionWrappers.FunctionWrapper{R, A}, Vararg}, arg::A, fww::FunctionWrappersWrapper @@ -28,6 +162,10 @@ function _call( return _call(Base.tail(fw), arg, fww) end +# ============================================================================ +# Fallback — Strict: always error +# ============================================================================ + const NO_FUNCTIONWRAPPER_FOUND_MESSAGE = "No matching function wrapper was found!" struct NoFunctionWrapperFoundError <: Exception end @@ -36,68 +174,96 @@ function Base.showerror(io::IO, e::NoFunctionWrapperFoundError) return print(io, NO_FUNCTIONWRAPPER_FOUND_MESSAGE) end -function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, false}) +function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, Strict}) throw(NoFunctionWrapperFoundError()) end -function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, true}) - return first(fww.fw).obj[](arg...) + +# ============================================================================ +# Fallback — AllowAll: always fall back +# ============================================================================ + +function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, AllowAll}) + return _fallback(arg, fww) end -function FunctionWrappersWrapper( - f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}}, - fallback::Val{FB} = Val{false}() - ) where {F, K, FB} - fwt = map(argtypes, rettypes) do A, R - FunctionWrappers.FunctionWrapper{R, A}(f) +# ============================================================================ +# Fallback — AllowNonIsBits: fall back only for non-isbits arg types +# ============================================================================ + +function _call( + ::Tuple{}, arg::A, fww::FunctionWrappersWrapper{<:Any, AllowNonIsBits} + ) where {A} + if _has_non_isbits_args(A) + return _fallback(arg, fww) end - return FunctionWrappersWrapper{typeof(fwt), FB}(fwt) + throw(NoFunctionWrapperFoundError()) end -""" - unwrap(fww::FunctionWrappersWrapper) - -Return the original function that was wrapped. This is useful for debugging -wrapped functions - you can use the returned function with debugging tools -like Debugger.jl or Infiltrator.jl. +@generated function _has_non_isbits_args(::Type{T}) where {T <: Tuple} + checks = [] + for P in T.parameters + if P <: AbstractArray + push!(checks, :(!(isbitstype($(eltype(P)))))) + else + push!(checks, :(!(isbitstype($P)))) + end + end + isempty(checks) && return :(false) + return Expr(:||, checks...) +end -# Example +# ============================================================================ +# Fallback execution — dispatch on cache storage type +# ============================================================================ -```julia -using FunctionWrappersWrappers +# --- NoCache: direct dynamic dispatch every time --- +function _fallback(arg, fww::FunctionWrappersWrapper{<:Any, <:Any, NoCacheStorage}) + return first(fww.fw).obj[](arg...) +end -# Create a wrapped function -fww = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,)) +# --- SingleCache: cache one FunctionWrapper for the last arg types --- +function _fallback( + arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, SingleCacheStorage} + ) where {A} + cached = fww.cache_storage.cached + if cached isa FunctionWrappers.FunctionWrapper{Any, A} + return cached(arg...) + end + f = first(fww.fw).obj[] + new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f) + fww.cache_storage.cached = new_fw + return new_fw(arg...) +end -# Get the original function for debugging -f = unwrap(fww) # Returns sin +# --- DictCache: cache FunctionWrappers keyed by arg type --- +function _fallback( + arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, DictCacheStorage} + ) where {A} + cached = get(fww.cache_storage.cache, A, nothing) + if cached isa FunctionWrappers.FunctionWrapper{Any, A} + return cached(arg...) + end + f = first(fww.fw).obj[] + new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f) + fww.cache_storage.cache[A] = new_fw + return new_fw(arg...) +end -# Now you can debug with Debugger.jl: -# using Debugger -# @enter f(0.5) +# ============================================================================ +# Introspection +# ============================================================================ -# Or use Infiltrator.jl in your original function definition -``` +""" + unwrap(fww::FunctionWrappersWrapper) -See also: [`wrapped_signatures`](@ref), [`wrapped_return_types`](@ref) +Return the original function that was wrapped. """ unwrap(fww::FunctionWrappersWrapper) = first(fww.fw).obj[] """ wrapped_signatures(fww::FunctionWrappersWrapper) -Return a tuple of the argument type signatures that the `FunctionWrappersWrapper` -can dispatch on. Each element is a `Tuple` type representing the argument types. - -# Example - -```julia -using FunctionWrappersWrappers - -fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int)) -wrapped_signatures(fww) # Returns (Tuple{Float64, Float64}, Tuple{Int, Int}) -``` - -See also: [`unwrap`](@ref), [`wrapped_return_types`](@ref) +Return a tuple of the argument type signatures that the wrapper can dispatch on. """ function wrapped_signatures(fww::FunctionWrappersWrapper) return map(fw -> typeof(fw).parameters[2], fww.fw) @@ -107,30 +273,19 @@ end wrapped_return_types(fww::FunctionWrappersWrapper) Return a tuple of the return types for each wrapped function signature. - -# Example - -```julia -using FunctionWrappersWrappers - -fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int)) -wrapped_return_types(fww) # Returns (Float64, Int64) -``` - -See also: [`unwrap`](@ref), [`wrapped_signatures`](@ref) """ function wrapped_return_types(fww::FunctionWrappersWrapper) return map(fw -> typeof(fw).parameters[1], fww.fw) end +# ============================================================================ +# Precompilation +# ============================================================================ + using PrecompileTools @setup_workload begin @compile_workload begin - # Precompile common use cases with Float64 and Int types - # These are the most common type combinations for numerical computations - - # Binary operation with multiple type combinations (common pattern) fw_binary = FunctionWrappersWrapper( +, (Tuple{Float64, Float64}, Tuple{Int, Int}), @@ -139,7 +294,6 @@ using PrecompileTools fw_binary(1.0, 2.0) fw_binary(1, 2) - # Unary operation with multiple types (common pattern) fw_unary = FunctionWrappersWrapper( abs, (Tuple{Float64}, Tuple{Int}), @@ -148,7 +302,6 @@ using PrecompileTools fw_unary(1.0) fw_unary(1) - # Precompile introspection functions unwrap(fw_unary) wrapped_signatures(fw_binary) wrapped_return_types(fw_binary) diff --git a/test/basictests.jl b/test/basictests.jl index b8ecfaa..592049f 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -19,7 +19,6 @@ using Test end @testset "Type inference" begin - # Test return type inference fwplus = FunctionWrappersWrapper( +, (Tuple{Float64, Float64}, Tuple{Int, Int}), ( Float64, Int, @@ -37,7 +36,6 @@ end end @testset "Introspection functions" begin - # Test with a simple function fwsin = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,)) @testset "unwrap" begin @@ -56,7 +54,6 @@ end @test rets == (Float64,) end - # Test with multiple signatures fwplus = FunctionWrappersWrapper( +, (Tuple{Float64, Float64}, Tuple{Int, Int}), ( Float64, Int, @@ -79,7 +76,6 @@ end @test rets == (Float64, Int) end - # Test with a custom function my_func(x) = x^2 fwcustom = FunctionWrappersWrapper( my_func, (Tuple{Float64}, Tuple{Int}), ( @@ -94,3 +90,156 @@ end @test f(2.5) == 6.25 end end + +@testset "Fallback policies" begin + @testset "Strict" begin + fww = FunctionWrappersWrapper( + +, (Tuple{Float64, Float64},), (Float64,); + cache = NoCache(), policy = Strict() + ) + @test fww(4.0, 8.0) === 12.0 + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww(4, 8) + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww( + BigFloat(4), BigFloat(8) + ) + end + + @testset "AllowAll" begin + fww = FunctionWrappersWrapper( + +, (Tuple{Float64, Float64},), (Float64,); + cache = NoCache(), policy = AllowAll() + ) + @test fww(4.0, 8.0) === 12.0 + @test fww(4, 8) === 12 + @test fww(4.0f0, 8.0f0) == 12.0f0 + @test fww(BigFloat(4), BigFloat(8)) == BigFloat(12) + end + + @testset "AllowNonIsBits" begin + fww = FunctionWrappersWrapper( + +, (Tuple{Float64, Float64},), (Float64,); + cache = NoCache(), policy = AllowNonIsBits() + ) + @test fww(4.0, 8.0) === 12.0 + # Float32 is isbits but doesn't match Float64 wrapper → error + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww(4.0f0, 8.0f0) + # Int is isbits but doesn't match Float64 wrapper → error + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww(4, 8) + # BigFloat is non-isbits → allowed + @test fww(BigFloat(4), BigFloat(8)) == BigFloat(12) + end + + @testset "AllowNonIsBits with arrays" begin + f!(du, u) = (du[1] = u[1]^2; nothing) + fww = FunctionWrappersWrapper( + f!, (Tuple{Vector{Float64}, Vector{Float64}},), (Nothing,); + cache = NoCache(), policy = AllowNonIsBits() + ) + du_f = [0.0]; u_f = [3.0] + fww(du_f, u_f) + @test du_f[1] === 9.0 + + # Float32 arrays: eltype is isbits but doesn't match → error + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww( + Float32[0.0], Float32[3.0] + ) + + # BigFloat arrays: eltype is non-isbits → allowed + du_bf = BigFloat[0]; u_bf = BigFloat[3] + fww(du_bf, u_bf) + @test du_bf[1] == BigFloat(9) + end +end + +@testset "Cache modes" begin + f!(du, u, p, t) = (du[1] = p[1] * u[1]; nothing) + + @testset "NoCache" begin + fww = FunctionWrappersWrapper( + f!, + (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},), + (Nothing,); + cache = NoCache(), policy = AllowAll() + ) + # Float64 match + du = [0.0]; u = [2.0]; p = [3.0] + fww(du, u, p, 0.0) + @test du[1] === 6.0 + + # BigFloat fallback (NoCache: 1 alloc per call) + du_bf = BigFloat[0]; u_bf = BigFloat[2]; p_bf = BigFloat[3]; t_bf = BigFloat(0) + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + end + + @testset "SingleCache" begin + fww = FunctionWrappersWrapper( + f!, + (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},), + (Nothing,); + cache = SingleCache(), policy = AllowAll() + ) + du_bf = BigFloat[0]; u_bf = BigFloat[2]; p_bf = BigFloat[3]; t_bf = BigFloat(0) + # First call caches + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + # Second call uses cache (0 alloc) + du_bf[1] = BigFloat(0) + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + end + + @testset "DictCache" begin + fww = FunctionWrappersWrapper( + f!, + (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},), + (Nothing,); + cache = DictCache(), policy = AllowAll() + ) + du_bf = BigFloat[0]; u_bf = BigFloat[2]; p_bf = BigFloat[3]; t_bf = BigFloat(0) + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + + # Different type also works and caches separately + du_f32 = Float32[0]; u_f32 = Float32[2]; p_f32 = Float32[3]; t_f32 = Float32(0) + fww(du_f32, u_f32, p_f32, t_f32) + @test du_f32[1] === Float32(6) + + # BigFloat still cached + du_bf[1] = BigFloat(0) + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + end + + @testset "SingleCache thrashing recovers" begin + fww = FunctionWrappersWrapper( + f!, + (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},), + (Nothing,); + cache = SingleCache(), policy = AllowAll() + ) + du_bf = BigFloat[0]; u_bf = BigFloat[2]; p_bf = BigFloat[3]; t_bf = BigFloat(0) + du_f32 = Float32[0]; u_f32 = Float32[2]; p_f32 = Float32[3]; t_f32 = Float32(0) + + # Alternate types — each call replaces the cache but still works + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + fww(du_f32, u_f32, p_f32, t_f32) + @test du_f32[1] === Float32(6) + du_bf[1] = BigFloat(0) + fww(du_bf, u_bf, p_bf, t_bf) + @test du_bf[1] == BigFloat(6) + end +end + +@testset "Default constructor uses SingleCache + AllowNonIsBits" begin + fww = FunctionWrappersWrapper( + +, (Tuple{Float64, Float64},), (Float64,) + ) + # Float64 matches wrapper + @test fww(4.0, 8.0) === 12.0 + # BigFloat is non-isbits → falls back + @test fww(BigFloat(4), BigFloat(8)) == BigFloat(12) + # Float32 is isbits mismatch → errors + @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww(4.0f0, 8.0f0) +end