@@ -4,17 +4,151 @@ using FunctionWrappers
44import TruncatedStacktraces
55
66export FunctionWrappersWrapper, unwrap, wrapped_signatures, wrapped_return_types
7+ export NoCache, SingleCache, DictCache
8+ export Strict, AllowAll, AllowNonIsBits
79
8- struct FunctionWrappersWrapper{FW, FB}
10+ # ============================================================================
11+ # Cache modes: control how fallback FunctionWrappers are cached
12+ # ============================================================================
13+ abstract type AbstractCacheMode end
14+
15+ """
16+ NoCache()
17+
18+ No caching — every fallback call goes through dynamic dispatch (`obj[](arg...)`),
19+ incurring 1 allocation per call.
20+ """
21+ struct NoCache <: AbstractCacheMode end
22+
23+ """
24+ SingleCache()
25+
26+ Cache a single `FunctionWrapper` for the last-seen argument types. After the first
27+ fallback call, subsequent calls with the same types are zero-allocation. If called with
28+ different types, the cache is replaced (1 alloc on miss). This is the recommended default.
29+ """
30+ struct SingleCache <: AbstractCacheMode end
31+
32+ """
33+ DictCache()
34+
35+ Cache `FunctionWrapper`s in a `Dict` keyed by argument type. Handles multiple
36+ non-isbits types without thrashing. Slightly higher lookup overhead than `SingleCache`.
37+ """
38+ struct DictCache <: AbstractCacheMode end
39+
40+ # ============================================================================
41+ # Fallback policies: control when fallback is allowed
42+ # ============================================================================
43+ abstract type AbstractFallbackPolicy end
44+
45+ """
46+ Strict()
47+
48+ Never fall back — throw `NoFunctionWrapperFoundError` if no wrapper matches.
49+ """
50+ struct Strict <: AbstractFallbackPolicy end
51+
52+ """
53+ AllowAll()
54+
55+ Always fall back to the original function when no wrapper matches.
56+ """
57+ struct AllowAll <: AbstractFallbackPolicy end
58+
59+ """
60+ AllowNonIsBits()
61+
62+ Fall back only when argument types contain non-isbits elements (e.g., `BigFloat`,
63+ `SparseConnectivityTracer` types). Throws `NoFunctionWrapperFoundError` for isbits
64+ type mismatches (e.g., `Float32` when `Float64` was expected), which catches bugs.
65+ This is the recommended default.
66+ """
67+ struct AllowNonIsBits <: AbstractFallbackPolicy end
68+
69+ # ============================================================================
70+ # Cache storage types
71+ # ============================================================================
72+ struct NoCacheStorage end
73+ mutable struct SingleCacheStorage
74+ cached:: Any # Union{Nothing, FunctionWrapper}
75+ SingleCacheStorage () = new (nothing )
76+ end
77+ struct DictCacheStorage
78+ cache:: Dict{DataType, Any}
79+ DictCacheStorage () = new (Dict {DataType, Any} ())
80+ end
81+
82+ _make_cache_storage (:: NoCache ) = NoCacheStorage ()
83+ _make_cache_storage (:: SingleCache ) = SingleCacheStorage ()
84+ _make_cache_storage (:: DictCache ) = DictCacheStorage ()
85+
86+ # ============================================================================
87+ # Main type
88+ # ============================================================================
89+
90+ """
91+ FunctionWrappersWrapper{FW, P, CS}
92+
93+ A wrapper around a tuple of `FunctionWrapper`s that dispatches calls to the
94+ matching wrapper based on argument types. When no wrapper matches, behavior is
95+ controlled by the fallback policy `P` and cache mode `CS`.
96+
97+ # Type parameters
98+ - `FW`: Tuple type of `FunctionWrapper`s
99+ - `P`: Fallback policy (`Strict`, `AllowAll`, or `AllowNonIsBits`)
100+ - `CS`: Cache storage type (`NoCacheStorage`, `SingleCacheStorage`, `DictCacheStorage`)
101+ """
102+ struct FunctionWrappersWrapper{FW, P, CS}
9103 fw:: FW
104+ cache_storage:: CS
105+ function FunctionWrappersWrapper {FW, P, CS} (
106+ fw:: FW , cs:: CS
107+ ) where {FW, P, CS}
108+ return new {FW, P, CS} (fw, cs)
109+ end
10110end
11111
12112TruncatedStacktraces. @truncate_stacktrace FunctionWrappersWrapper
13113
14- function (fww:: FunctionWrappersWrapper{FW, FB} )(args:: Vararg{Any, K} ) where {FW, K, FB}
114+ """
115+ FunctionWrappersWrapper(f, argtypes, rettypes; cache=SingleCache(), policy=AllowNonIsBits())
116+
117+ Create a `FunctionWrappersWrapper` with configurable fallback behavior.
118+
119+ # Arguments
120+ - `f`: The function to wrap
121+ - `argtypes`: Tuple of argument type signatures (e.g., `(Tuple{Float64, Float64},)`)
122+ - `rettypes`: Tuple of return types (e.g., `(Float64,)`)
123+
124+ # Keywords
125+ - `cache`: Cache mode for fallback path — `NoCache()`, `SingleCache()` (default), or `DictCache()`
126+ - `policy`: Fallback policy — `Strict()`, `AllowAll()`, or `AllowNonIsBits()` (default)
127+ """
128+ function FunctionWrappersWrapper (
129+ f:: F , argtypes:: Tuple{Vararg{Any, K}} , rettypes:: Tuple{Vararg{Type, K}} ;
130+ cache:: AbstractCacheMode = SingleCache (),
131+ policy:: AbstractFallbackPolicy = AllowNonIsBits ()
132+ ) where {F, K}
133+ fwt = map (argtypes, rettypes) do A, R
134+ FunctionWrappers. FunctionWrapper {R, A} (f)
135+ end
136+ cs = _make_cache_storage (cache)
137+ return FunctionWrappersWrapper {typeof(fwt), typeof(policy), typeof(cs)} (fwt, cs)
138+ end
139+
140+
141+ # ============================================================================
142+ # Call dispatch — entry point
143+ # ============================================================================
144+
145+ function (fww:: FunctionWrappersWrapper{FW, P, CS} )(
146+ args:: Vararg{Any, K}
147+ ) where {FW, K, P, CS}
15148 return _call (fww. fw, args, fww)
16149end
17150
151+ # Match path: try each FunctionWrapper in order
18152function _call (
19153 fw:: Tuple{FunctionWrappers.FunctionWrapper{R, A}, Vararg} ,
20154 arg:: A , fww:: FunctionWrappersWrapper
@@ -28,6 +162,10 @@ function _call(
28162 return _call (Base. tail (fw), arg, fww)
29163end
30164
165+ # ============================================================================
166+ # Fallback — Strict: always error
167+ # ============================================================================
168+
31169const NO_FUNCTIONWRAPPER_FOUND_MESSAGE = " No matching function wrapper was found!"
32170
33171struct NoFunctionWrapperFoundError <: Exception end
@@ -36,68 +174,96 @@ function Base.showerror(io::IO, e::NoFunctionWrapperFoundError)
36174 return print (io, NO_FUNCTIONWRAPPER_FOUND_MESSAGE)
37175end
38176
39- function _call (:: Tuple{} , arg, fww:: FunctionWrappersWrapper{<:Any, false } )
177+ function _call (:: Tuple{} , arg, fww:: FunctionWrappersWrapper{<:Any, Strict } )
40178 throw (NoFunctionWrapperFoundError ())
41179end
42- function _call (:: Tuple{} , arg, fww:: FunctionWrappersWrapper{<:Any, true} )
43- return first (fww. fw). obj[](arg... )
180+
181+ # ============================================================================
182+ # Fallback — AllowAll: always fall back
183+ # ============================================================================
184+
185+ function _call (:: Tuple{} , arg, fww:: FunctionWrappersWrapper{<:Any, AllowAll} )
186+ return _fallback (arg, fww)
44187end
45188
46- function FunctionWrappersWrapper (
47- f:: F , argtypes:: Tuple{Vararg{Any, K}} , rettypes:: Tuple{Vararg{Type, K}} ,
48- fallback:: Val{FB} = Val {false} ()
49- ) where {F, K, FB}
50- fwt = map (argtypes, rettypes) do A, R
51- FunctionWrappers. FunctionWrapper {R, A} (f)
189+ # ============================================================================
190+ # Fallback — AllowNonIsBits: fall back only for non-isbits arg types
191+ # ============================================================================
192+
193+ function _call (
194+ :: Tuple{} , arg:: A , fww:: FunctionWrappersWrapper{<:Any, AllowNonIsBits}
195+ ) where {A}
196+ if _has_non_isbits_args (A)
197+ return _fallback (arg, fww)
52198 end
53- return FunctionWrappersWrapper {typeof(fwt), FB} (fwt )
199+ throw ( NoFunctionWrapperFoundError () )
54200end
55201
56- """
57- unwrap(fww::FunctionWrappersWrapper)
58-
59- Return the original function that was wrapped. This is useful for debugging
60- wrapped functions - you can use the returned function with debugging tools
61- like Debugger.jl or Infiltrator.jl.
202+ @generated function _has_non_isbits_args (:: Type{T} ) where {T <: Tuple }
203+ checks = []
204+ for P in T. parameters
205+ if P <: AbstractArray
206+ push! (checks, :(! (isbitstype ($ (eltype (P))))))
207+ else
208+ push! (checks, :(! (isbitstype ($ P))))
209+ end
210+ end
211+ isempty (checks) && return :(false )
212+ return Expr (:|| , checks... )
213+ end
62214
63- # Example
215+ # ============================================================================
216+ # Fallback execution — dispatch on cache storage type
217+ # ============================================================================
64218
65- ```julia
66- using FunctionWrappersWrappers
219+ # --- NoCache: direct dynamic dispatch every time ---
220+ function _fallback (arg, fww:: FunctionWrappersWrapper{<:Any, <:Any, NoCacheStorage} )
221+ return first (fww. fw). obj[](arg... )
222+ end
67223
68- # Create a wrapped function
69- fww = FunctionWrappersWrapper(sin, (Tuple{Float64},), (Float64,))
224+ # --- SingleCache: cache one FunctionWrapper for the last arg types ---
225+ function _fallback (
226+ arg:: A , fww:: FunctionWrappersWrapper{<:Any, <:Any, SingleCacheStorage}
227+ ) where {A}
228+ cached = fww. cache_storage. cached
229+ if cached isa FunctionWrappers. FunctionWrapper{Any, A}
230+ return cached (arg... )
231+ end
232+ f = first (fww. fw). obj[]
233+ new_fw = FunctionWrappers. FunctionWrapper {Any, A} (f)
234+ fww. cache_storage. cached = new_fw
235+ return new_fw (arg... )
236+ end
70237
71- # Get the original function for debugging
72- f = unwrap(fww) # Returns sin
238+ # --- DictCache: cache FunctionWrappers keyed by arg type ---
239+ function _fallback (
240+ arg:: A , fww:: FunctionWrappersWrapper{<:Any, <:Any, DictCacheStorage}
241+ ) where {A}
242+ cached = get (fww. cache_storage. cache, A, nothing )
243+ if cached isa FunctionWrappers. FunctionWrapper{Any, A}
244+ return cached (arg... )
245+ end
246+ f = first (fww. fw). obj[]
247+ new_fw = FunctionWrappers. FunctionWrapper {Any, A} (f)
248+ fww. cache_storage. cache[A] = new_fw
249+ return new_fw (arg... )
250+ end
73251
74- # Now you can debug with Debugger.jl:
75- # using Debugger
76- # @enter f(0.5)
252+ # ============================================================================
253+ # Introspection
254+ # ============================================================================
77255
78- # Or use Infiltrator.jl in your original function definition
79- ```
256+ """
257+ unwrap(fww::FunctionWrappersWrapper)
80258
81- See also: [`wrapped_signatures`](@ref), [`wrapped_return_types`](@ref)
259+ Return the original function that was wrapped.
82260"""
83261unwrap (fww:: FunctionWrappersWrapper ) = first (fww. fw). obj[]
84262
85263"""
86264 wrapped_signatures(fww::FunctionWrappersWrapper)
87265
88- Return a tuple of the argument type signatures that the `FunctionWrappersWrapper`
89- can dispatch on. Each element is a `Tuple` type representing the argument types.
90-
91- # Example
92-
93- ```julia
94- using FunctionWrappersWrappers
95-
96- fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
97- wrapped_signatures(fww) # Returns (Tuple{Float64, Float64}, Tuple{Int, Int})
98- ```
99-
100- See also: [`unwrap`](@ref), [`wrapped_return_types`](@ref)
266+ Return a tuple of the argument type signatures that the wrapper can dispatch on.
101267"""
102268function wrapped_signatures (fww:: FunctionWrappersWrapper )
103269 return map (fw -> typeof (fw). parameters[2 ], fww. fw)
@@ -107,30 +273,19 @@ end
107273 wrapped_return_types(fww::FunctionWrappersWrapper)
108274
109275Return a tuple of the return types for each wrapped function signature.
110-
111- # Example
112-
113- ```julia
114- using FunctionWrappersWrappers
115-
116- fww = FunctionWrappersWrapper(+, (Tuple{Float64, Float64}, Tuple{Int, Int}), (Float64, Int))
117- wrapped_return_types(fww) # Returns (Float64, Int64)
118- ```
119-
120- See also: [`unwrap`](@ref), [`wrapped_signatures`](@ref)
121276"""
122277function wrapped_return_types (fww:: FunctionWrappersWrapper )
123278 return map (fw -> typeof (fw). parameters[1 ], fww. fw)
124279end
125280
281+ # ============================================================================
282+ # Precompilation
283+ # ============================================================================
284+
126285using PrecompileTools
127286
128287@setup_workload begin
129288 @compile_workload begin
130- # Precompile common use cases with Float64 and Int types
131- # These are the most common type combinations for numerical computations
132-
133- # Binary operation with multiple type combinations (common pattern)
134289 fw_binary = FunctionWrappersWrapper (
135290 + ,
136291 (Tuple{Float64, Float64}, Tuple{Int, Int}),
@@ -139,7 +294,6 @@ using PrecompileTools
139294 fw_binary (1.0 , 2.0 )
140295 fw_binary (1 , 2 )
141296
142- # Unary operation with multiple types (common pattern)
143297 fw_unary = FunctionWrappersWrapper (
144298 abs,
145299 (Tuple{Float64}, Tuple{Int}),
@@ -148,7 +302,6 @@ using PrecompileTools
148302 fw_unary (1.0 )
149303 fw_unary (1 )
150304
151- # Precompile introspection functions
152305 unwrap (fw_unary)
153306 wrapped_signatures (fw_binary)
154307 wrapped_return_types (fw_binary)
0 commit comments