diff --git a/src/FunctionWrappersWrappers.jl b/src/FunctionWrappersWrappers.jl index f358ee8..8b5ef0a 100644 --- a/src/FunctionWrappersWrappers.jl +++ b/src/FunctionWrappersWrappers.jl @@ -111,6 +111,25 @@ end TruncatedStacktraces.@truncate_stacktrace FunctionWrappersWrapper +""" + FunctionWrappersWrapper{FW, P, CS}(f) + +Create a `FunctionWrappersWrapper` when the type parameters are specified. + +# Arguments +- `f`: The function to wrap + +# Type parameters +- `FW`: Tuple type of `FunctionWrapper`s +- `P`: Fallback policy (`Strict`, `AllowAll`, or `AllowNonIsBits`) +- `CS`: Cache storage type (`NoCacheStorage`, `SingleCacheStorage`, `DictCacheStorage`) +""" +function FunctionWrappersWrapper{FW, P, CS}(f) where {K, FW <: NTuple{K, Any}, P, CS} + fw = ntuple(i -> FW.parameters[i](f), Val(K)) + cs = CS() + return FunctionWrappersWrapper{FW, P, CS}(fw, cs) +end + """ FunctionWrappersWrapper(f, argtypes, rettypes; cache=SingleCache(), policy=AllowNonIsBits()) @@ -137,6 +156,8 @@ function FunctionWrappersWrapper( return FunctionWrappersWrapper{typeof(fwt), typeof(policy), typeof(cs)}(fwt, cs) end +Base.convert(::Type{T}, obj) where {T <: FunctionWrappersWrapper} = T(obj) +Base.convert(::Type{T}, obj::T) where {T <: FunctionWrappersWrapper} = obj # ============================================================================ # Call dispatch — entry point diff --git a/test/basictests.jl b/test/basictests.jl index 592049f..a469614 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -243,3 +243,22 @@ end # Float32 is isbits mismatch → errors @test_throws FunctionWrappersWrappers.NoFunctionWrapperFoundError fww(4.0f0, 8.0f0) end + +@testset "Conversion" begin + fww_exp = FunctionWrappersWrapper( + exp, + (Tuple{Float64}, Tuple{Float32}), + (Float64, Float32) + ) + FWW = typeof(fww_exp) + + fww_cos = FWW(cos) + @test typeof(fww_cos) == FWW + @test fww_cos(0.5) == cos(0.5) + + fww_vector = FWW[sin, tan, log] + @test eltype(fww_vector) == FWW + fww_vector[1](0.5) == sin(0.5) + fww_vector[2](0.5) == tan(0.5) + fww_vector[3](0.5) == log(0.5) +end