Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FunctionWrappersWrappers"
uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf"
authors = ["Chris Elrod <elrodc@gmail.com> and contributors"]
version = "0.1.5"
version = "1.0.0"

[deps]
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Expand Down
275 changes: 214 additions & 61 deletions src/FunctionWrappersWrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,151 @@ using FunctionWrappers
import TruncatedStacktraces

export FunctionWrappersWrapper, unwrap, wrapped_signatures, wrapped_return_types
export NoCache, SingleCache, DictCache
export Strict, AllowAll, AllowNonIsBits

struct FunctionWrappersWrapper{FW, FB}
# ============================================================================
# Cache modes: control how fallback FunctionWrappers are cached
# ============================================================================
abstract type AbstractCacheMode end

"""
NoCache()

No caching — every fallback call goes through dynamic dispatch (`obj[](arg...)`),
incurring 1 allocation per call.
"""
struct NoCache <: AbstractCacheMode end

"""
SingleCache()

Cache a single `FunctionWrapper` for the last-seen argument types. After the first
fallback call, subsequent calls with the same types are zero-allocation. If called with
different types, the cache is replaced (1 alloc on miss). This is the recommended default.
"""
struct SingleCache <: AbstractCacheMode end

"""
DictCache()

Cache `FunctionWrapper`s in a `Dict` keyed by argument type. Handles multiple
non-isbits types without thrashing. Slightly higher lookup overhead than `SingleCache`.
"""
struct DictCache <: AbstractCacheMode end

# ============================================================================
# Fallback policies: control when fallback is allowed
# ============================================================================
abstract type AbstractFallbackPolicy end

"""
Strict()

Never fall back — throw `NoFunctionWrapperFoundError` if no wrapper matches.
"""
struct Strict <: AbstractFallbackPolicy end

"""
AllowAll()

Always fall back to the original function when no wrapper matches.
"""
struct AllowAll <: AbstractFallbackPolicy end

"""
AllowNonIsBits()

Fall back only when argument types contain non-isbits elements (e.g., `BigFloat`,
`SparseConnectivityTracer` types). Throws `NoFunctionWrapperFoundError` for isbits
type mismatches (e.g., `Float32` when `Float64` was expected), which catches bugs.
This is the recommended default.
"""
struct AllowNonIsBits <: AbstractFallbackPolicy end

# ============================================================================
# Cache storage types
# ============================================================================
struct NoCacheStorage end
mutable struct SingleCacheStorage
cached::Any # Union{Nothing, FunctionWrapper}
SingleCacheStorage() = new(nothing)
end
struct DictCacheStorage
cache::Dict{DataType, Any}
DictCacheStorage() = new(Dict{DataType, Any}())
end

_make_cache_storage(::NoCache) = NoCacheStorage()
_make_cache_storage(::SingleCache) = SingleCacheStorage()
_make_cache_storage(::DictCache) = DictCacheStorage()

# ============================================================================
# Main type
# ============================================================================

"""
FunctionWrappersWrapper{FW, P, CS}

A wrapper around a tuple of `FunctionWrapper`s that dispatches calls to the
matching wrapper based on argument types. When no wrapper matches, behavior is
controlled by the fallback policy `P` and cache mode `CS`.

# Type parameters
- `FW`: Tuple type of `FunctionWrapper`s
- `P`: Fallback policy (`Strict`, `AllowAll`, or `AllowNonIsBits`)
- `CS`: Cache storage type (`NoCacheStorage`, `SingleCacheStorage`, `DictCacheStorage`)
"""
struct FunctionWrappersWrapper{FW, P, CS}
fw::FW
cache_storage::CS
function FunctionWrappersWrapper{FW, P, CS}(
fw::FW, cs::CS
) where {FW, P, CS}
return new{FW, P, CS}(fw, cs)
end
end

TruncatedStacktraces.@truncate_stacktrace FunctionWrappersWrapper

function (fww::FunctionWrappersWrapper{FW, FB})(args::Vararg{Any, K}) where {FW, K, FB}
"""
FunctionWrappersWrapper(f, argtypes, rettypes; cache=SingleCache(), policy=AllowNonIsBits())

Create a `FunctionWrappersWrapper` with configurable fallback behavior.

# Arguments
- `f`: The function to wrap
- `argtypes`: Tuple of argument type signatures (e.g., `(Tuple{Float64, Float64},)`)
- `rettypes`: Tuple of return types (e.g., `(Float64,)`)

# Keywords
- `cache`: Cache mode for fallback path — `NoCache()`, `SingleCache()` (default), or `DictCache()`
- `policy`: Fallback policy — `Strict()`, `AllowAll()`, or `AllowNonIsBits()` (default)
"""
function FunctionWrappersWrapper(
f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}};
cache::AbstractCacheMode = SingleCache(),
policy::AbstractFallbackPolicy = AllowNonIsBits()
) where {F, K}
fwt = map(argtypes, rettypes) do A, R
FunctionWrappers.FunctionWrapper{R, A}(f)
end
cs = _make_cache_storage(cache)
return FunctionWrappersWrapper{typeof(fwt), typeof(policy), typeof(cs)}(fwt, cs)
end


# ============================================================================
# Call dispatch — entry point
# ============================================================================

function (fww::FunctionWrappersWrapper{FW, P, CS})(
args::Vararg{Any, K}
) where {FW, K, P, CS}
return _call(fww.fw, args, fww)
end

# Match path: try each FunctionWrapper in order
function _call(
fw::Tuple{FunctionWrappers.FunctionWrapper{R, A}, Vararg},
arg::A, fww::FunctionWrappersWrapper
Expand All @@ -28,6 +162,10 @@ function _call(
return _call(Base.tail(fw), arg, fww)
end

# ============================================================================
# Fallback — Strict: always error
# ============================================================================

const NO_FUNCTIONWRAPPER_FOUND_MESSAGE = "No matching function wrapper was found!"

struct NoFunctionWrapperFoundError <: Exception end
Expand All @@ -36,68 +174,96 @@ function Base.showerror(io::IO, e::NoFunctionWrapperFoundError)
return print(io, NO_FUNCTIONWRAPPER_FOUND_MESSAGE)
end

function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, false})
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, Strict})
throw(NoFunctionWrapperFoundError())
end
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, true})
return first(fww.fw).obj[](arg...)

# ============================================================================
# Fallback — AllowAll: always fall back
# ============================================================================

function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, AllowAll})
return _fallback(arg, fww)
end

function FunctionWrappersWrapper(
f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}},
fallback::Val{FB} = Val{false}()
) where {F, K, FB}
fwt = map(argtypes, rettypes) do A, R
FunctionWrappers.FunctionWrapper{R, A}(f)
# ============================================================================
# Fallback — AllowNonIsBits: fall back only for non-isbits arg types
# ============================================================================

function _call(
::Tuple{}, arg::A, fww::FunctionWrappersWrapper{<:Any, AllowNonIsBits}
) where {A}
if _has_non_isbits_args(A)
return _fallback(arg, fww)
end
return FunctionWrappersWrapper{typeof(fwt), FB}(fwt)
throw(NoFunctionWrapperFoundError())
end

"""
unwrap(fww::FunctionWrappersWrapper)

Return the original function that was wrapped. This is useful for debugging
wrapped functions - you can use the returned function with debugging tools
like Debugger.jl or Infiltrator.jl.
@generated function _has_non_isbits_args(::Type{T}) where {T <: Tuple}
checks = []
for P in T.parameters
if P <: AbstractArray
push!(checks, :(!(isbitstype($(eltype(P))))))
else
push!(checks, :(!(isbitstype($P))))
end
end
isempty(checks) && return :(false)
return Expr(:||, checks...)
end

# Example
# ============================================================================
# Fallback execution — dispatch on cache storage type
# ============================================================================

```julia
using FunctionWrappersWrappers
# --- NoCache: direct dynamic dispatch every time ---
function _fallback(arg, fww::FunctionWrappersWrapper{<:Any, <:Any, NoCacheStorage})
return first(fww.fw).obj[](arg...)
end

# Create a wrapped function
fww = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,))
# --- SingleCache: cache one FunctionWrapper for the last arg types ---
function _fallback(
arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, SingleCacheStorage}
) where {A}
cached = fww.cache_storage.cached
if cached isa FunctionWrappers.FunctionWrapper{Any, A}
return cached(arg...)
end
f = first(fww.fw).obj[]
new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f)
fww.cache_storage.cached = new_fw
return new_fw(arg...)
end

# Get the original function for debugging
f = unwrap(fww) # Returns sin
# --- DictCache: cache FunctionWrappers keyed by arg type ---
function _fallback(
arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, DictCacheStorage}
) where {A}
cached = get(fww.cache_storage.cache, A, nothing)
if cached isa FunctionWrappers.FunctionWrapper{Any, A}
return cached(arg...)
end
f = first(fww.fw).obj[]
new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f)
fww.cache_storage.cache[A] = new_fw
return new_fw(arg...)
end

# Now you can debug with Debugger.jl:
# using Debugger
# @enter f(0.5)
# ============================================================================
# Introspection
# ============================================================================

# Or use Infiltrator.jl in your original function definition
```
"""
unwrap(fww::FunctionWrappersWrapper)

See also: [`wrapped_signatures`](@ref), [`wrapped_return_types`](@ref)
Return the original function that was wrapped.
"""
unwrap(fww::FunctionWrappersWrapper) = first(fww.fw).obj[]

"""
wrapped_signatures(fww::FunctionWrappersWrapper)

Return a tuple of the argument type signatures that the `FunctionWrappersWrapper`
can dispatch on. Each element is a `Tuple` type representing the argument types.

# Example

```julia
using FunctionWrappersWrappers

fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
wrapped_signatures(fww) # Returns (Tuple{Float64, Float64}, Tuple{Int, Int})
```

See also: [`unwrap`](@ref), [`wrapped_return_types`](@ref)
Return a tuple of the argument type signatures that the wrapper can dispatch on.
"""
function wrapped_signatures(fww::FunctionWrappersWrapper)
return map(fw -> typeof(fw).parameters[2], fww.fw)
Expand All @@ -107,30 +273,19 @@ end
wrapped_return_types(fww::FunctionWrappersWrapper)

Return a tuple of the return types for each wrapped function signature.

# Example

```julia
using FunctionWrappersWrappers

fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
wrapped_return_types(fww) # Returns (Float64, Int64)
```

See also: [`unwrap`](@ref), [`wrapped_signatures`](@ref)
"""
function wrapped_return_types(fww::FunctionWrappersWrapper)
return map(fw -> typeof(fw).parameters[1], fww.fw)
end

# ============================================================================
# Precompilation
# ============================================================================

using PrecompileTools

@setup_workload begin
@compile_workload begin
# Precompile common use cases with Float64 and Int types
# These are the most common type combinations for numerical computations

# Binary operation with multiple type combinations (common pattern)
fw_binary = FunctionWrappersWrapper(
+,
(Tuple{Float64, Float64}, Tuple{Int, Int}),
Expand All @@ -139,7 +294,6 @@ using PrecompileTools
fw_binary(1.0, 2.0)
fw_binary(1, 2)

# Unary operation with multiple types (common pattern)
fw_unary = FunctionWrappersWrapper(
abs,
(Tuple{Float64}, Tuple{Int}),
Expand All @@ -148,7 +302,6 @@ using PrecompileTools
fw_unary(1.0)
fw_unary(1)

# Precompile introspection functions
unwrap(fw_unary)
wrapped_signatures(fw_binary)
wrapped_return_types(fw_binary)
Expand Down
Loading
Loading