Skip to content

Commit 86cbba9

Browse files
Merge pull request #34 from ChrisRackauckas-Claude/cr/cached-fallback-policy
Add configurable cache modes and fallback policies for non-isbits types
2 parents f45aed9 + cb2faed commit 86cbba9

3 files changed

Lines changed: 368 additions & 66 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FunctionWrappersWrappers"
22
uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf"
33
authors = ["Chris Elrod <elrodc@gmail.com> and contributors"]
4-
version = "0.1.5"
4+
version = "1.0.0"
55

66
[deps]
77
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"

src/FunctionWrappersWrappers.jl

Lines changed: 214 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,151 @@ using FunctionWrappers
44
import TruncatedStacktraces
55

66
export FunctionWrappersWrapper, unwrap, wrapped_signatures, wrapped_return_types
7+
export NoCache, SingleCache, DictCache
8+
export Strict, AllowAll, AllowNonIsBits
79

8-
struct FunctionWrappersWrapper{FW, FB}
10+
# ============================================================================
11+
# Cache modes: control how fallback FunctionWrappers are cached
12+
# ============================================================================
13+
abstract type AbstractCacheMode end
14+
15+
"""
16+
NoCache()
17+
18+
No caching — every fallback call goes through dynamic dispatch (`obj[](arg...)`),
19+
incurring 1 allocation per call.
20+
"""
21+
struct NoCache <: AbstractCacheMode end
22+
23+
"""
24+
SingleCache()
25+
26+
Cache a single `FunctionWrapper` for the last-seen argument types. After the first
27+
fallback call, subsequent calls with the same types are zero-allocation. If called with
28+
different types, the cache is replaced (1 alloc on miss). This is the recommended default.
29+
"""
30+
struct SingleCache <: AbstractCacheMode end
31+
32+
"""
33+
DictCache()
34+
35+
Cache `FunctionWrapper`s in a `Dict` keyed by argument type. Handles multiple
36+
non-isbits types without thrashing. Slightly higher lookup overhead than `SingleCache`.
37+
"""
38+
struct DictCache <: AbstractCacheMode end
39+
40+
# ============================================================================
41+
# Fallback policies: control when fallback is allowed
42+
# ============================================================================
43+
abstract type AbstractFallbackPolicy end
44+
45+
"""
46+
Strict()
47+
48+
Never fall back — throw `NoFunctionWrapperFoundError` if no wrapper matches.
49+
"""
50+
struct Strict <: AbstractFallbackPolicy end
51+
52+
"""
53+
AllowAll()
54+
55+
Always fall back to the original function when no wrapper matches.
56+
"""
57+
struct AllowAll <: AbstractFallbackPolicy end
58+
59+
"""
60+
AllowNonIsBits()
61+
62+
Fall back only when argument types contain non-isbits elements (e.g., `BigFloat`,
63+
`SparseConnectivityTracer` types). Throws `NoFunctionWrapperFoundError` for isbits
64+
type mismatches (e.g., `Float32` when `Float64` was expected), which catches bugs.
65+
This is the recommended default.
66+
"""
67+
struct AllowNonIsBits <: AbstractFallbackPolicy end
68+
69+
# ============================================================================
70+
# Cache storage types
71+
# ============================================================================
72+
struct NoCacheStorage end
73+
mutable struct SingleCacheStorage
74+
cached::Any # Union{Nothing, FunctionWrapper}
75+
SingleCacheStorage() = new(nothing)
76+
end
77+
struct DictCacheStorage
78+
cache::Dict{DataType, Any}
79+
DictCacheStorage() = new(Dict{DataType, Any}())
80+
end
81+
82+
_make_cache_storage(::NoCache) = NoCacheStorage()
83+
_make_cache_storage(::SingleCache) = SingleCacheStorage()
84+
_make_cache_storage(::DictCache) = DictCacheStorage()
85+
86+
# ============================================================================
87+
# Main type
88+
# ============================================================================
89+
90+
"""
91+
FunctionWrappersWrapper{FW, P, CS}
92+
93+
A wrapper around a tuple of `FunctionWrapper`s that dispatches calls to the
94+
matching wrapper based on argument types. When no wrapper matches, behavior is
95+
controlled by the fallback policy `P` and cache mode `CS`.
96+
97+
# Type parameters
98+
- `FW`: Tuple type of `FunctionWrapper`s
99+
- `P`: Fallback policy (`Strict`, `AllowAll`, or `AllowNonIsBits`)
100+
- `CS`: Cache storage type (`NoCacheStorage`, `SingleCacheStorage`, `DictCacheStorage`)
101+
"""
102+
struct FunctionWrappersWrapper{FW, P, CS}
9103
fw::FW
104+
cache_storage::CS
105+
function FunctionWrappersWrapper{FW, P, CS}(
106+
fw::FW, cs::CS
107+
) where {FW, P, CS}
108+
return new{FW, P, CS}(fw, cs)
109+
end
10110
end
11111

12112
TruncatedStacktraces.@truncate_stacktrace FunctionWrappersWrapper
13113

14-
function (fww::FunctionWrappersWrapper{FW, FB})(args::Vararg{Any, K}) where {FW, K, FB}
114+
"""
115+
FunctionWrappersWrapper(f, argtypes, rettypes; cache=SingleCache(), policy=AllowNonIsBits())
116+
117+
Create a `FunctionWrappersWrapper` with configurable fallback behavior.
118+
119+
# Arguments
120+
- `f`: The function to wrap
121+
- `argtypes`: Tuple of argument type signatures (e.g., `(Tuple{Float64, Float64},)`)
122+
- `rettypes`: Tuple of return types (e.g., `(Float64,)`)
123+
124+
# Keywords
125+
- `cache`: Cache mode for fallback path — `NoCache()`, `SingleCache()` (default), or `DictCache()`
126+
- `policy`: Fallback policy — `Strict()`, `AllowAll()`, or `AllowNonIsBits()` (default)
127+
"""
128+
function FunctionWrappersWrapper(
129+
f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}};
130+
cache::AbstractCacheMode = SingleCache(),
131+
policy::AbstractFallbackPolicy = AllowNonIsBits()
132+
) where {F, K}
133+
fwt = map(argtypes, rettypes) do A, R
134+
FunctionWrappers.FunctionWrapper{R, A}(f)
135+
end
136+
cs = _make_cache_storage(cache)
137+
return FunctionWrappersWrapper{typeof(fwt), typeof(policy), typeof(cs)}(fwt, cs)
138+
end
139+
140+
141+
# ============================================================================
142+
# Call dispatch — entry point
143+
# ============================================================================
144+
145+
function (fww::FunctionWrappersWrapper{FW, P, CS})(
146+
args::Vararg{Any, K}
147+
) where {FW, K, P, CS}
15148
return _call(fww.fw, args, fww)
16149
end
17150

151+
# Match path: try each FunctionWrapper in order
18152
function _call(
19153
fw::Tuple{FunctionWrappers.FunctionWrapper{R, A}, Vararg},
20154
arg::A, fww::FunctionWrappersWrapper
@@ -28,6 +162,10 @@ function _call(
28162
return _call(Base.tail(fw), arg, fww)
29163
end
30164

165+
# ============================================================================
166+
# Fallback — Strict: always error
167+
# ============================================================================
168+
31169
const NO_FUNCTIONWRAPPER_FOUND_MESSAGE = "No matching function wrapper was found!"
32170

33171
struct NoFunctionWrapperFoundError <: Exception end
@@ -36,68 +174,96 @@ function Base.showerror(io::IO, e::NoFunctionWrapperFoundError)
36174
return print(io, NO_FUNCTIONWRAPPER_FOUND_MESSAGE)
37175
end
38176

39-
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, false})
177+
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, Strict})
40178
throw(NoFunctionWrapperFoundError())
41179
end
42-
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, true})
43-
return first(fww.fw).obj[](arg...)
180+
181+
# ============================================================================
182+
# Fallback — AllowAll: always fall back
183+
# ============================================================================
184+
185+
function _call(::Tuple{}, arg, fww::FunctionWrappersWrapper{<:Any, AllowAll})
186+
return _fallback(arg, fww)
44187
end
45188

46-
function FunctionWrappersWrapper(
47-
f::F, argtypes::Tuple{Vararg{Any, K}}, rettypes::Tuple{Vararg{Type, K}},
48-
fallback::Val{FB} = Val{false}()
49-
) where {F, K, FB}
50-
fwt = map(argtypes, rettypes) do A, R
51-
FunctionWrappers.FunctionWrapper{R, A}(f)
189+
# ============================================================================
190+
# Fallback — AllowNonIsBits: fall back only for non-isbits arg types
191+
# ============================================================================
192+
193+
function _call(
194+
::Tuple{}, arg::A, fww::FunctionWrappersWrapper{<:Any, AllowNonIsBits}
195+
) where {A}
196+
if _has_non_isbits_args(A)
197+
return _fallback(arg, fww)
52198
end
53-
return FunctionWrappersWrapper{typeof(fwt), FB}(fwt)
199+
throw(NoFunctionWrapperFoundError())
54200
end
55201

56-
"""
57-
unwrap(fww::FunctionWrappersWrapper)
58-
59-
Return the original function that was wrapped. This is useful for debugging
60-
wrapped functions - you can use the returned function with debugging tools
61-
like Debugger.jl or Infiltrator.jl.
202+
@generated function _has_non_isbits_args(::Type{T}) where {T <: Tuple}
203+
checks = []
204+
for P in T.parameters
205+
if P <: AbstractArray
206+
push!(checks, :(!(isbitstype($(eltype(P))))))
207+
else
208+
push!(checks, :(!(isbitstype($P))))
209+
end
210+
end
211+
isempty(checks) && return :(false)
212+
return Expr(:||, checks...)
213+
end
62214

63-
# Example
215+
# ============================================================================
216+
# Fallback execution — dispatch on cache storage type
217+
# ============================================================================
64218

65-
```julia
66-
using FunctionWrappersWrappers
219+
# --- NoCache: direct dynamic dispatch every time ---
220+
function _fallback(arg, fww::FunctionWrappersWrapper{<:Any, <:Any, NoCacheStorage})
221+
return first(fww.fw).obj[](arg...)
222+
end
67223

68-
# Create a wrapped function
69-
fww = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,))
224+
# --- SingleCache: cache one FunctionWrapper for the last arg types ---
225+
function _fallback(
226+
arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, SingleCacheStorage}
227+
) where {A}
228+
cached = fww.cache_storage.cached
229+
if cached isa FunctionWrappers.FunctionWrapper{Any, A}
230+
return cached(arg...)
231+
end
232+
f = first(fww.fw).obj[]
233+
new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f)
234+
fww.cache_storage.cached = new_fw
235+
return new_fw(arg...)
236+
end
70237

71-
# Get the original function for debugging
72-
f = unwrap(fww) # Returns sin
238+
# --- DictCache: cache FunctionWrappers keyed by arg type ---
239+
function _fallback(
240+
arg::A, fww::FunctionWrappersWrapper{<:Any, <:Any, DictCacheStorage}
241+
) where {A}
242+
cached = get(fww.cache_storage.cache, A, nothing)
243+
if cached isa FunctionWrappers.FunctionWrapper{Any, A}
244+
return cached(arg...)
245+
end
246+
f = first(fww.fw).obj[]
247+
new_fw = FunctionWrappers.FunctionWrapper{Any, A}(f)
248+
fww.cache_storage.cache[A] = new_fw
249+
return new_fw(arg...)
250+
end
73251

74-
# Now you can debug with Debugger.jl:
75-
# using Debugger
76-
# @enter f(0.5)
252+
# ============================================================================
253+
# Introspection
254+
# ============================================================================
77255

78-
# Or use Infiltrator.jl in your original function definition
79-
```
256+
"""
257+
unwrap(fww::FunctionWrappersWrapper)
80258
81-
See also: [`wrapped_signatures`](@ref), [`wrapped_return_types`](@ref)
259+
Return the original function that was wrapped.
82260
"""
83261
unwrap(fww::FunctionWrappersWrapper) = first(fww.fw).obj[]
84262

85263
"""
86264
wrapped_signatures(fww::FunctionWrappersWrapper)
87265
88-
Return a tuple of the argument type signatures that the `FunctionWrappersWrapper`
89-
can dispatch on. Each element is a `Tuple` type representing the argument types.
90-
91-
# Example
92-
93-
```julia
94-
using FunctionWrappersWrappers
95-
96-
fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
97-
wrapped_signatures(fww) # Returns (Tuple{Float64, Float64}, Tuple{Int, Int})
98-
```
99-
100-
See also: [`unwrap`](@ref), [`wrapped_return_types`](@ref)
266+
Return a tuple of the argument type signatures that the wrapper can dispatch on.
101267
"""
102268
function wrapped_signatures(fww::FunctionWrappersWrapper)
103269
return map(fw -> typeof(fw).parameters[2], fww.fw)
@@ -107,30 +273,19 @@ end
107273
wrapped_return_types(fww::FunctionWrappersWrapper)
108274
109275
Return a tuple of the return types for each wrapped function signature.
110-
111-
# Example
112-
113-
```julia
114-
using FunctionWrappersWrappers
115-
116-
fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
117-
wrapped_return_types(fww) # Returns (Float64, Int64)
118-
```
119-
120-
See also: [`unwrap`](@ref), [`wrapped_signatures`](@ref)
121276
"""
122277
function wrapped_return_types(fww::FunctionWrappersWrapper)
123278
return map(fw -> typeof(fw).parameters[1], fww.fw)
124279
end
125280

281+
# ============================================================================
282+
# Precompilation
283+
# ============================================================================
284+
126285
using PrecompileTools
127286

128287
@setup_workload begin
129288
@compile_workload begin
130-
# Precompile common use cases with Float64 and Int types
131-
# These are the most common type combinations for numerical computations
132-
133-
# Binary operation with multiple type combinations (common pattern)
134289
fw_binary = FunctionWrappersWrapper(
135290
+,
136291
(Tuple{Float64, Float64}, Tuple{Int, Int}),
@@ -139,7 +294,6 @@ using PrecompileTools
139294
fw_binary(1.0, 2.0)
140295
fw_binary(1, 2)
141296

142-
# Unary operation with multiple types (common pattern)
143297
fw_unary = FunctionWrappersWrapper(
144298
abs,
145299
(Tuple{Float64}, Tuple{Int}),
@@ -148,7 +302,6 @@ using PrecompileTools
148302
fw_unary(1.0)
149303
fw_unary(1)
150304

151-
# Precompile introspection functions
152305
unwrap(fw_unary)
153306
wrapped_signatures(fw_binary)
154307
wrapped_return_types(fw_binary)

0 commit comments

Comments
 (0)