From 3145f5d768a6117659a1829b7f318061465d11c5 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 13:42:41 +0100 Subject: [PATCH 01/17] Add Evaluators module: prepare/Prepared interface, evaluator shapes, and vectorisation utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces `AbstractPPL.Evaluators`, a new submodule providing structural plumbing for preparing and calling model evaluators, with or without AD. Core types and functions: - `Prepared{AD,E,C}` — wraps an evaluator with an AD backend type (for dispatch) and an optional backend-specific cache; exported from AbstractPPL - `prepare(problem, x; check_dims=true)` / `prepare(adtype, problem, x; check_dims=true)` — structural and AD-aware preparation; AD-backend extensions implement the three-argument form and thread `check_dims` to `VectorEvaluator{check_dims}` - `value_and_gradient!!(prepared, x)` / `value_and_jacobian!!(prepared, x)` — stub interface for derivative computation; returned arrays may alias internal cache buffers (callers copy if they need to retain past the next call) - `evaluate!!(evaluator, x)` — extends the existing AbstractPPL.evaluate!! for Prepared, VectorEvaluator, and NamedTupleEvaluator Evaluator shapes: - `VectorEvaluator{CheckInput}` — wraps a callable with a flat-vector input contract; CheckInput=false skips the per-call length check for AD hot paths - `NamedTupleEvaluator{CheckInput}` — same for NamedTuple inputs with a fixed prototype Vectorisation utilities: - `flatten_to!!(buf, x)` / `unflatten_to!!(x, buf)` — round-trip vectorisation for Real, Complex, AbstractArray, Tuple, and NamedTuple; @generated NamedTuple unflatten preserves type inferability LogDensityProblems extension (updated): - logdensity, dimension, capabilities for Prepared and VectorEvaluator - logdensity_and_gradient copies the gradient out of the aliasable !! buffer Co-Authored-By: Claude Sonnet 4.6 --- .github/workflows/CI.yml | 20 ++ Project.toml | 5 + docs/Project.toml | 4 +- docs/make.jl | 2 +- docs/src/evaluators.md | 139 +++++++++++ docs/src/pplapi.md | 4 + ext/AbstractPPLLogDensityProblemsExt.jl | 28 +++ src/AbstractPPL.jl | 7 + src/evaluate.jl | 2 +- src/evaluators/Evaluators.jl | 216 ++++++++++++++++++ src/evaluators/utils.jl | 143 ++++++++++++ test/Project.toml | 3 + test/evaluators/Evaluators.jl | 123 ++++++++++ test/evaluators/utils.jl | 74 ++++++ test/ext/logdensityproblems/Project.toml | 11 + .../logdensityproblems/logdensityproblems.jl | 51 +++++ test/run_extras.jl | 15 ++ test/runtests.jl | 2 + 18 files changed, 846 insertions(+), 3 deletions(-) create mode 100644 docs/src/evaluators.md create mode 100644 ext/AbstractPPLLogDensityProblemsExt.jl create mode 100644 src/evaluators/Evaluators.jl create mode 100644 src/evaluators/utils.jl create mode 100644 test/evaluators/Evaluators.jl create mode 100644 test/evaluators/utils.jl create mode 100644 test/ext/logdensityproblems/Project.toml create mode 100644 test/ext/logdensityproblems/logdensityproblems.jl create mode 100644 test/run_extras.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 3a992da5..ef4049b5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -52,3 +52,23 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true + + ext: + name: Ext (logdensityproblems, ${{ matrix.version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - '1' + - 'min' + steps: + - uses: actions/checkout@v6 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/cache@v3 + - uses: julia-actions/julia-buildpkg@v1 + - run: julia --project=. test/run_extras.jl + env: + LABEL: logdensityproblems diff --git a/Project.toml b/Project.toml index 00e48d00..bf497189 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ desc = "Common interfaces for probabilistic programming" version = "0.14.2" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -19,11 +20,15 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" [extensions] AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] +AbstractPPLLogDensityProblemsExt = ["LogDensityProblems"] [compat] +ADTypes = "1" +LogDensityProblems = "2" AbstractMCMC = "2, 3, 4, 5" Accessors = "0.1" BangBang = "0.4" diff --git a/docs/Project.toml b/docs/Project.toml index 9aed942a..1ef0eca3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,8 +1,10 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [sources] -AbstractPPL = {path = "../"} +AbstractPPL = {path = ".."} diff --git a/docs/make.jl b/docs/make.jl index abab3f69..ead3820f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,7 +9,7 @@ DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive= makedocs(; sitename="AbstractPPL", modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], - pages=["index.md", "varname.md", "pplapi.md", "interface.md"], + pages=["index.md", "varname.md", "pplapi.md", "evaluators.md", "interface.md"], checkdocs=:exports, doctest=false, ) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md new file mode 100644 index 00000000..ac55f930 --- /dev/null +++ b/docs/src/evaluators.md @@ -0,0 +1,139 @@ +# Evaluator preparation and AD + +AbstractPPL provides a small interface for preparing callables and asking a +prepared evaluator for values and derivatives. `prepare` binds a callable to a +sample input that establishes the expected input shape and type; +`value_and_gradient!!` and `value_and_jacobian!!` then return the value and +derivative together. + +The `!!` suffix signals that the returned gradient or Jacobian **may alias +internal cache buffers** of the prepared evaluator. Copy if you need to retain +the result past the next call. + +## Quick start + +```@example ad +using AbstractPPL +using AbstractPPL: prepare, value_and_gradient!!, Prepared +using AbstractPPL.Evaluators: VectorEvaluator, NamedTupleEvaluator +using ADTypes: AutoForwardDiff +using ForwardDiff: ForwardDiff + +function AbstractPPL.prepare(adtype::AutoForwardDiff, f, x::AbstractVector{<:Real}) + return Prepared(adtype, VectorEvaluator(f, length(x))) +end + +function AbstractPPL.value_and_gradient!!( + p::Prepared{AutoForwardDiff}, x::AbstractVector{<:Real} +) + return (p(x), ForwardDiff.gradient(p.evaluator.f, x)) +end + +mvnormal_logp(x) = -0.5 * sum(abs2, x) # standard normal log density (up to constant) +prepared = prepare(AutoForwardDiff(), mvnormal_logp, zeros(3)) +value_and_gradient!!(prepared, [1.0, 2.0, 3.0]) +``` + +## Two input styles + +### Vector inputs + +When the callable accepts a flat vector, pass a sample vector whose length +matches the expected input: + +```@example ad +prepared([1.0, 2.0, 3.0]) +``` + +For vector-valued callables, use `value_and_jacobian!!`. The returned Jacobian +has shape `(length(value), length(x))`: + +```@example ad +using AbstractPPL: value_and_jacobian!! + +vecfun(x) = [x[1] * x[2], x[2] + x[3]] + +function AbstractPPL.value_and_jacobian!!( + p::Prepared{AutoForwardDiff}, x::AbstractVector{<:Real} +) + return (p(x), ForwardDiff.jacobian(p.evaluator.f, x)) +end + +prepared_vec = prepare(AutoForwardDiff(), vecfun, zeros(3)) +value_and_jacobian!!(prepared_vec, [2.0, 3.0, 4.0]) +``` + +### NamedTuple inputs + +When the callable accepts a `NamedTuple`, pass a sample `NamedTuple` whose +field names and value types match the expected input. An extension can define a +`prepare` overload that wraps the function in a `NamedTupleEvaluator`: + +```@example ad +function AbstractPPL.prepare(adtype::AutoForwardDiff, f, values::NamedTuple) + return Prepared(adtype, NamedTupleEvaluator(f, values)) +end + +ntfun(v::NamedTuple) = v.a^2 + sum(abs2, v.b) +prepared_nt = prepare(AutoForwardDiff(), ntfun, (a=0.0, b=zeros(2))) +prepared_nt((a=1.0, b=[2.0, 3.0])) +``` + +## AD backends + +Automatic differentiation packages extend the interface by implementing +`value_and_gradient!!` and `value_and_jacobian!!` for specific cache types +stored in `prepared.cache`: + +```julia +prepared = prepare(adtype, problem, prototype) # returns Prepared{AD,E,Cache} +value_and_gradient!!(prepared, x) # may return aliased cache buffer +value_and_jacobian!!(prepared, x) +``` + +`Prepared` has three fields: `adtype`, `evaluator` (the user-facing callable), +and `cache` (backend-specific pre-allocated state such as ForwardDiff configs or +Mooncake tapes). Backend extensions dispatch on the cache type: + +```julia +function AbstractPPL.prepare( + adtype::MyADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true +) + f = # extract callable from problem + cache = MyCache(f, x) + return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x)), cache) +end + +function AbstractPPL.value_and_gradient!!( + p::Prepared{<:AbstractADType, <:VectorEvaluator, <:MyCache}, + x::AbstractVector{<:Real}, +) + # use p.cache to avoid allocations + ... +end +``` + +Pass `check_dims=false` in your `prepare` implementation to construct a +`VectorEvaluator{false}`, which skips the per-call length check. AD libraries +that guarantee input shape (ForwardDiff, Mooncake, etc.) should do this to +avoid redundant checks in the dual/shadow hot path. + +## Without an AD backend + +The two-argument form `prepare(problem, x)` is available without any AD package. +It returns the callable unchanged by default, so code that calls `prepare` +unconditionally works regardless of which backends are loaded: + +```@example ad +sumsimple(x) = sum(x) +p = prepare(sumsimple, zeros(3)) +p([1.0, 2.0, 3.0]) +``` + +## API reference + +```@docs +AbstractPPL.prepare +AbstractPPL.value_and_gradient!! +AbstractPPL.value_and_jacobian!! +``` diff --git a/docs/src/pplapi.md b/docs/src/pplapi.md index 492ab294..20a4fcf9 100644 --- a/docs/src/pplapi.md +++ b/docs/src/pplapi.md @@ -18,3 +18,7 @@ evaluate!! ```@docs AbstractModelTrace ``` + +## Evaluators interface + +See [Evaluator preparation and AD](@ref) for a full guide and API reference. diff --git a/ext/AbstractPPLLogDensityProblemsExt.jl b/ext/AbstractPPLLogDensityProblemsExt.jl new file mode 100644 index 00000000..c9b8568e --- /dev/null +++ b/ext/AbstractPPLLogDensityProblemsExt.jl @@ -0,0 +1,28 @@ +module AbstractPPLLogDensityProblemsExt + +using AbstractPPL: AbstractPPL +using AbstractPPL.Evaluators: Prepared, VectorEvaluator +using LogDensityProblems: LogDensityProblems + +LogDensityProblems.logdensity(p::Prepared, x) = p(x) +LogDensityProblems.logdensity(e::VectorEvaluator, x) = e(x) + +LogDensityProblems.dimension(p::Prepared) = LogDensityProblems.dimension(p.evaluator) +LogDensityProblems.dimension(e::VectorEvaluator) = e.dim + +# `Prepared` is the AD-aware shape, so it always advertises gradient capability. +LogDensityProblems.capabilities(::Type{<:Prepared}) = LogDensityProblems.LogDensityOrder{1}() +LogDensityProblems.capabilities(p::Prepared) = LogDensityProblems.capabilities(typeof(p)) + +# A bare `VectorEvaluator` is the no-AD shape; only `Prepared` advertises gradient. +function LogDensityProblems.capabilities(::Type{<:VectorEvaluator}) + return LogDensityProblems.LogDensityOrder{0}() +end +LogDensityProblems.capabilities(e::VectorEvaluator) = LogDensityProblems.capabilities(typeof(e)) + +function LogDensityProblems.logdensity_and_gradient(p::Prepared, x) + val, grad = AbstractPPL.value_and_gradient!!(p, x) + return (val, copy(grad)) +end + +end # module diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index e08ee8b4..eed7c226 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -10,6 +10,13 @@ export AbstractModelTrace include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") +include("evaluators/Evaluators.jl") +using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!!, Prepared +export Prepared +@static if VERSION >= v"1.11.0" + eval(Meta.parse("public prepare, value_and_gradient!!, value_and_jacobian!!")) +end + include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") diff --git a/src/evaluate.jl b/src/evaluate.jl index 57062062..19338565 100644 --- a/src/evaluate.jl +++ b/src/evaluate.jl @@ -5,7 +5,7 @@ Common base type for evaluation contexts. """ abstract type AbstractContext end -""" +""" evaluate!! General API for model operations, e.g. prior evaluation, log density, log joint etc. diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl new file mode 100644 index 00000000..d8fb7aa6 --- /dev/null +++ b/src/evaluators/Evaluators.jl @@ -0,0 +1,216 @@ +module Evaluators + +using ADTypes: AbstractADType +import ..evaluate!! + +include("utils.jl") + +""" + Prepared{AD<:AbstractADType,E,C}(adtype, evaluator, cache) + Prepared(adtype, evaluator) # cache defaults to `nothing` + +AD-prepared evaluator parameterised by backend type `AD`. + +- `adtype` — the backend, used for dispatch. +- `evaluator` — the user-facing callable (typically a `VectorEvaluator` or + `NamedTupleEvaluator`); forwarded on `p(x)`. +- `cache` — backend-specific pre-allocated state (ForwardDiff configs, Mooncake + caches, DifferentiationInterface preps, etc.). `Nothing` when the backend requires + no cached state. + +Extension packages implement `value_and_gradient!!` (and optionally +`value_and_jacobian!!`) by specialising on the `cache` type: + +```julia +function AbstractPPL.value_and_gradient!!( + p::Prepared{<:AbstractADType, <:VectorEvaluator, <:MyCache}, x::AbstractVector +) + ... +end +``` +""" +struct Prepared{AD<:AbstractADType,E,C} + adtype::AD + evaluator::E + cache::C +end + +Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing) + +(p::Prepared)(x) = p.evaluator(x) + +""" + prepare(problem, values::NamedTuple; check_dims::Bool=true) + prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) + prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true) + +Prepare a callable evaluator for `problem`. + +Use the two-argument form with a `NamedTuple` when the evaluator works with +named inputs, or with a vector when it works with vector inputs. The +three-argument form, contributed by AD-backend extensions, additionally +prepares gradient or jacobian machinery for vector inputs. + +`check_dims` (default `true`) is forwarded to the evaluator constructor by +AD-backend extensions (three-argument form). Pass `check_dims=false` to skip +per-call shape validation, e.g. when the AD backend already guarantees the +input shape. The two-argument stubs ignore this keyword. +""" +function prepare end + +# Downstream packages (e.g. DynamicPPL) pass already-callable objects, +# so the safe default is to return them unchanged. +prepare(problem, values::NamedTuple; check_dims::Bool=true) = problem +prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) = problem + +""" + value_and_gradient!!(prepared, x::AbstractVector{<:Real}) + +Return `(value, gradient)` for a scalar-valued evaluator, potentially reusing +internal cache buffers of `prepared`. The returned gradient may alias +`prepared`'s internal storage; copy if you need to retain it past the next call. +""" +function value_and_gradient!! end + +""" + value_and_jacobian!!(prepared, x::AbstractVector{<:Real}) + +Return `(value::AbstractVector, jacobian::AbstractMatrix)` for a vector-valued +evaluator, potentially reusing internal cache buffers. The returned arrays may +alias `prepared`'s internal storage; copy if needed. +The Jacobian has shape `(length(value), length(x))`. +""" +function value_and_jacobian!! end + +""" + VectorEvaluator{CheckInput}(f, dim) + VectorEvaluator(f, dim) # equivalent to `VectorEvaluator{true}(f, dim)` + +Evaluator shape for scalar functions of a vector input. Part of the extension +author API; end users interact with the wrapping `Prepared` instead. + +`CheckInput` controls whether each call validates the input length. The default +(`true`) is the safe shape exposed via `prepared(x)`. Pass `CheckInput=false` +(via `check_dims=false` in `prepare`) for the callable handed to AD libraries, +where input shape is already guaranteed and the runtime check would persist in +the dual/shadow hot path. + +A bare `VectorEvaluator` is *not* differentiable; gradient capability is the +contract of the wrapping `Prepared` returned by `prepare(adtype, ...)`. +""" +struct VectorEvaluator{CheckInput,F} + f::F + dim::Int + function VectorEvaluator{CheckInput}(f::F, dim::Int) where {CheckInput,F} + CheckInput isa Bool || throw(ArgumentError("`CheckInput` must be a Bool.")) + dim >= 0 || throw(ArgumentError("`dim` must be non-negative, got $dim.")) + return new{CheckInput,F}(f, dim) + end +end + +VectorEvaluator(f, dim::Int) = VectorEvaluator{true}(f, dim) + +""" + NamedTupleEvaluator{CheckInput}(f, inputspec) + NamedTupleEvaluator(f, inputspec) # equivalent to `NamedTupleEvaluator{true}(f, inputspec)` + +Evaluator shape for functions of a `NamedTuple` input with a stable prototype. +Part of the extension author API; end users interact with the wrapping `Prepared`. + +`CheckInput` controls whether each call validates that the input `NamedTuple` +has the same type as the prototype captured during preparation. +""" +struct NamedTupleEvaluator{CheckInput,F,P<:NamedTuple} + f::F + inputspec::P + function NamedTupleEvaluator{CheckInput}( + f::F, inputspec::P + ) where {CheckInput,F,P<:NamedTuple} + CheckInput isa Bool || throw(ArgumentError("`CheckInput` must be a Bool.")) + return new{CheckInput,F,P}(f, inputspec) + end +end + +NamedTupleEvaluator(f, inputspec::NamedTuple) = NamedTupleEvaluator{true}(f, inputspec) + +function (e::VectorEvaluator{true})(x::AbstractVector) + length(x) == e.dim || throw( + DimensionMismatch( + "Expected a vector of length $(e.dim), but got length $(length(x))." + ), + ) + return e.f(x) +end + +(e::VectorEvaluator{false})(x::AbstractVector) = e.f(x) + +function (e::NamedTupleEvaluator{true})(values::NamedTuple) + _assert_namedtuple_shape(e, values) + return e.f(values) +end +(e::NamedTupleEvaluator{false})(values::NamedTuple) = e.f(values) + +# Reject integer vectors with a clear error rather than letting them flow into +# AD backends (which usually fail confusingly). Split per `CheckInput` to avoid +# an ambiguity with the `(::VectorEvaluator{true})(::AbstractVector)` method above. +function _reject_integer_input(::VectorEvaluator, x) + throw( + ArgumentError( + "VectorEvaluator requires a vector of floating-point values, but received an `$(typeof(x))`. Convert to a floating-point vector (e.g. `Float64.(x)`) before calling.", + ), + ) +end +(e::VectorEvaluator{true})(x::AbstractVector{<:Integer}) = _reject_integer_input(e, x) +(e::VectorEvaluator{false})(x::AbstractVector{<:Integer}) = _reject_integer_input(e, x) + +""" + _assert_namedtuple_shape(e::NamedTupleEvaluator, values) + +Throw `ArgumentError` unless `values` has the same type as the prototype captured +during preparation. No-op when `e` was constructed with `CheckInput=false`. +""" +function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) + typeof(values) === typeof(e.inputspec) || throw( + ArgumentError( + "Expected the same NamedTuple structure that was used to prepare this evaluator.", + ), + ) + return nothing +end +_assert_namedtuple_shape(::NamedTupleEvaluator{false}, _) = nothing + +function _assert_jacobian_output(y) + y isa AbstractVector || throw( + ArgumentError( + "`value_and_jacobian!!` requires the prepared function to return an AbstractVector; got $(typeof(y)).", + ), + ) + return nothing +end + +function _assert_supported_output(y) + (y isa Number || y isa AbstractVector) || throw( + ArgumentError( + "A prepared AD evaluator must return a scalar or AbstractVector; got $(typeof(y)).", + ), + ) + return nothing +end + +evaluate!!(p::Prepared, x) = p(x) +evaluate!!(e::VectorEvaluator, x) = e(x) +evaluate!!(e::NamedTupleEvaluator, x) = e(x) + +function __init__() + Base.Experimental.register_error_hint(MethodError) do io, exc, args, kwargs + # `args` are argument types, not values (see `Base.Experimental.show_error_hints`). + if exc.f === prepare && length(args) >= 1 && args[1] <: AbstractADType + print( + io, + "\nCalling `prepare` with an AD backend requires loading the corresponding extension (e.g., `using DifferentiationInterface`).", + ) + end + end +end + +end # module diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl new file mode 100644 index 00000000..eab36471 --- /dev/null +++ b/src/evaluators/utils.jl @@ -0,0 +1,143 @@ +# Vectorisation utilities + +# This utility only supports a small structural subset so flattening stays +# predictable and reconstruction can use `x` as the template. + +flat_length(x::Union{Real,Complex}) = 1 +flat_length(x::AbstractArray{<:Union{Real,Complex}}) = length(x) +flat_length(x::Tuple) = mapreduce(flat_length, +, x; init=0) +flat_length(x::NamedTuple) = mapreduce(flat_length, +, values(x); init=0) +flat_length(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) + +flat_eltype(x::Union{Real,Complex}) = typeof(x) +flat_eltype(x::AbstractArray{T}) where {T<:Union{Real,Complex}} = T +flat_eltype(::Tuple{}) = Float64 +flat_eltype(x::Tuple) = mapreduce(flat_eltype, promote_type, x) +flat_eltype(::NamedTuple{(),Tuple{}}) = Float64 +flat_eltype(x::NamedTuple) = mapreduce(flat_eltype, promote_type, values(x)) +flat_eltype(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) + +""" + flatten_to!!(buf, x) + +Flatten `x` into the vector-like buffer `buf`. + +Supported `x` values are: +- `Real` +- `Complex` +- `AbstractArray{<:Union{Real,Complex}}` +- `Tuple` recursively containing supported values +- `NamedTuple` recursively containing supported values + +Pass `nothing` as `buf` to allocate a new vector. +""" +function flatten_to!!(::Nothing, x) + buf = Vector{flat_eltype(x)}(undef, flat_length(x)) + _flatten_to!(buf, x, 1) + return buf +end + +function flatten_to!!(buf::AbstractVector, x) + n = flat_length(x) + length(buf) == n || throw( + DimensionMismatch("Expected a vector of length $n, but got length $(length(buf))."), + ) + _flatten_to!(buf, x, 1) + return buf +end + +function _flatten_to!(buf::AbstractVector, x::Union{Real,Complex}, offset::Int) + buf[offset] = x + return offset + 1 +end + +function _flatten_to!( + buf::AbstractVector, x::AbstractArray{<:Union{Real,Complex}}, offset::Int +) + n = length(x) + copyto!(buf, offset, x, 1, n) + return offset + n +end + +function _flatten_to!(buf::AbstractVector, x::Tuple, offset::Int) + for value in x + offset = _flatten_to!(buf, value, offset) + end + return offset +end + +function _flatten_to!(buf::AbstractVector, x::NamedTuple, offset::Int) + for value in values(x) + offset = _flatten_to!(buf, value, offset) + end + return offset +end + +function _flatten_to!(buf::AbstractVector, x, ::Int) + throw(ArgumentError("This value cannot be flattened into a vector.")) +end + +function _unflatten(x::Union{Real,Complex}, buf::AbstractVector, offset::Int) + return buf[offset], offset + 1 +end + +function _unflatten( + x::AbstractArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int +) + n = length(x) + value = similar(x, promote_type(eltype(x), eltype(buf))) + copyto!(value, 1, buf, offset, n) + return value, offset + n +end + +_unflatten(::Tuple{}, buf::AbstractVector, offset::Int) = (), offset + +function _unflatten(x::Tuple, buf::AbstractVector, offset::Int) + first_value, offset = _unflatten(first(x), buf, offset) + rest_value, offset = _unflatten(Base.tail(x), buf, offset) + return (first_value, rest_value...), offset +end + +# Generated to keep the result `NamedTuple` type inferable: a recursive `merge` +# over `Base.tail(Names)` erases parameters and breaks `@inferred` callers. +@generated function _unflatten( + x::NamedTuple{Names}, buf::AbstractVector, offset::Int +) where {Names} + if isempty(Names) + return :((NamedTuple(), offset)) + end + block = Expr(:block, :(off = offset)) + val_syms = Symbol[] + for name in Names + v = gensym(name) + push!(val_syms, v) + push!( + block.args, :(($v, off) = _unflatten(getfield(x, $(QuoteNode(name))), buf, off)) + ) + end + push!(block.args, :(return (NamedTuple{$Names}(($(val_syms...),)), off))) + return block +end + +""" + unflatten_to!!(x, buf) + +Reconstruct a value from the vector-like buffer `buf` using `x` as the structural template. + +Supported `x` values are: +- `Real` +- `Complex` +- `AbstractArray{<:Union{Real,Complex}}` +- `Tuple` recursively containing supported values +- `NamedTuple` recursively containing supported values +""" +# Always allocates: `_unflatten` calls `similar` for each array field. Gains from +# buffer reuse are negligible relative to gradient computation cost. +function unflatten_to!!(x, buf::AbstractVector) + n = flat_length(x) + length(buf) == n || throw( + DimensionMismatch("Expected a vector of length $n, but got length $(length(buf))."), + ) + value, _ = _unflatten(x, buf, 1) + return value +end diff --git a/test/Project.toml b/test/Project.toml index a43fb829..fb93b3bd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" @@ -13,6 +15,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +ADTypes = "1" Accessors = "0.1" Aqua = "0.8" DimensionalData = "0.29, 0.30" diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl new file mode 100644 index 00000000..8966e45e --- /dev/null +++ b/test/evaluators/Evaluators.jl @@ -0,0 +1,123 @@ +using AbstractPPL +using AbstractPPL: prepare, value_and_gradient!!, evaluate!! +using AbstractPPL.Evaluators: Prepared, VectorEvaluator +using ADTypes: ADTypes +using Test + +struct DummyProblem end + +struct DummyPrepared + prototype_keys::Tuple +end + +function AbstractPPL.prepare(problem::DummyProblem, values::NamedTuple) + return DummyPrepared(keys(values)) +end + +function (p::DummyPrepared)(values::NamedTuple) + keys(values) == p.prototype_keys || + error("expected fields $(p.prototype_keys), got $(keys(values))") + return sum(x -> x isa AbstractArray ? sum(x) : x, values) +end + +struct DummyADType <: ADTypes.AbstractADType end + +function AbstractPPL.prepare( + adtype::DummyADType, problem::DummyProblem, x::AbstractVector{<:Real}; + check_dims::Bool=true, +) + f = x -> sum(x) + return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x))) +end + +function AbstractPPL.value_and_gradient!!( + p::Prepared{DummyADType}, x::AbstractVector{<:Real} +) + return (sum(x), ones(length(x))) +end + +@testset "ADProblem interface" begin + @testset "explicit evaluator shapes" begin + ve = AbstractPPL.Evaluators.VectorEvaluator(sum, 3) + @test ve([1.0, 2.0, 3.0]) == 6.0 + @test_throws DimensionMismatch ve([1.0, 2.0]) + @test_throws r"floating-point" ve([1, 2, 3]) + + ne = AbstractPPL.Evaluators.NamedTupleEvaluator( + x -> x.a + sum(x.b), (a=0.0, b=zeros(2)) + ) + @test ne((a=1.0, b=[2.0, 3.0])) == 6.0 + @test ne.inputspec == (a=0.0, b=zeros(2)) + @test_throws MethodError ne([1.0, 2.0, 3.0]) + + # `CheckInput=false` skips the per-call shape checks. + ve_unchecked = AbstractPPL.Evaluators.VectorEvaluator{false}(sum, 3) + @test ve_unchecked([1.0, 2.0]) == 3.0 + + ne_unchecked = AbstractPPL.Evaluators.NamedTupleEvaluator{false}( + x -> 0.0, (a=0.0, b=zeros(2)) + ) + @test AbstractPPL.Evaluators._assert_namedtuple_shape( + ne_unchecked, (totally=:wrong,) + ) === nothing + @test_throws r"same NamedTuple structure" AbstractPPL.Evaluators._assert_namedtuple_shape( + ne, (totally=:wrong,) + ) + end + + @testset "prepare (structural)" begin + problem = DummyProblem() + values = (x=0.0, y=[1.0, 2.0]) + prepared = prepare(problem, values) + @test prepared isa DummyPrepared + @test prepared.prototype_keys == (:x, :y) + + lp = prepared((x=0.5, y=[1.5, 2.5])) + @test lp ≈ 0.5 + 1.5 + 2.5 + + @test_throws ErrorException prepared((a=1.0, b=2.0)) + end + + @testset "prepare (AD-aware)" begin + problem = DummyProblem() + x0 = zeros(3) + adtype = DummyADType() + prepared = prepare(adtype, problem, x0) + @test prepared isa Prepared{DummyADType} + + x = [0.5, 1.5, 2.5] + @test prepared(x) ≈ 0.5 + 1.5 + 2.5 + + val, grad = value_and_gradient!!(prepared, x) + @test val ≈ 0.5 + 1.5 + 2.5 + @test grad ≈ [1.0, 1.0, 1.0] + + # check_dims=false skips the per-call dimension check. + prepared_unchecked = prepare(adtype, problem, x0; check_dims=false) + @test prepared_unchecked([1.0, 2.0]) ≈ 3.0 # wrong length, no error + end + + @testset "missing AD package extensions" begin + problem = DummyProblem() + x0 = zeros(3) + + @test_throws MethodError AbstractPPL.Evaluators.prepare( + ADTypes.AutoEnzyme(), problem, x0 + ) + end + + @testset "evaluate!!" begin + ve = AbstractPPL.Evaluators.VectorEvaluator(sum, 3) + @test evaluate!!(ve, [1.0, 2.0, 3.0]) == 6.0 + @test_throws DimensionMismatch evaluate!!(ve, [1.0, 2.0]) + + ne = AbstractPPL.Evaluators.NamedTupleEvaluator( + x -> x.a + sum(x.b), (a=0.0, b=zeros(2)) + ) + @test evaluate!!(ne, (a=1.0, b=[2.0, 3.0])) == 6.0 + + adtype = DummyADType() + prepared = prepare(adtype, DummyProblem(), zeros(3)) + @test evaluate!!(prepared, [0.5, 1.5, 2.5]) ≈ 4.5 + end +end diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl new file mode 100644 index 00000000..c264361e --- /dev/null +++ b/test/evaluators/utils.jl @@ -0,0 +1,74 @@ +using AbstractPPL +using AbstractPPL.Evaluators: flatten_to!!, unflatten_to!! +using Test + +@testset "vectorisation utilities" begin + @testset "scalar round-trip" begin + x = 1.5 + v = flatten_to!!(nothing, x) + @test v == [1.5] + @test unflatten_to!!(x, v) == 1.5 + + z = 1.0 + 2.0im + vz = flatten_to!!(nothing, z) + @test vz == ComplexF64[1.0 + 2.0im] + @test unflatten_to!!(z, vz) == z + end + + @testset "array round-trip" begin + x = [1.0 2.0; 3.0 4.0] + v = flatten_to!!(nothing, x) + @test v == [1.0, 3.0, 2.0, 4.0] + @test unflatten_to!!(x, v) == x + + z = ComplexF64[1.0 + 1.0im, 2.0 + 0.0im] + vz = flatten_to!!(nothing, z) + @test vz == z + @test unflatten_to!!(z, vz) == z + end + + @testset "tuple round-trip" begin + x = (1.0, [2.0, 3.0], (4.0 + 1.0im,)) + v = flatten_to!!(nothing, x) + @test v == ComplexF64[1.0, 2.0, 3.0, 4.0 + 1.0im] + @test unflatten_to!!(x, v) == x + end + + @testset "named tuple round-trip" begin + x = (a=1.0, b=([2.0, 3.0], (c=4.0 + 1.0im,))) + v = flatten_to!!(nothing, x) + @test v == ComplexF64[1.0, 2.0, 3.0, 4.0 + 1.0im] + @test unflatten_to!!(x, v) == x + end + + @testset "buffer length mismatch" begin + @test_throws r"Expected a vector of length 4" flatten_to!!( + Vector{Float64}(undef, 3), zeros(2, 2) + ) + end + + @testset "vector length mismatch" begin + x = (a=1.0, b=[2.0, 3.0]) + @test_throws r"Expected a vector of length 3" unflatten_to!!(x, [1.0, 2.0]) + end + + @testset "edge cases" begin + empty = NamedTuple() + @test flatten_to!!(nothing, empty) == Float64[] + @test unflatten_to!!(empty, Float64[]) == empty + + view_values = (x=@view([1.0, 2.0, 3.0][2:3]),) + flat = flatten_to!!(nothing, view_values) + rebuilt = unflatten_to!!(view_values, flat) + @test collect(rebuilt.x) == [2.0, 3.0] + end + + @testset "unflatten_to!! type stability" begin + @inferred unflatten_to!!((a=1.0, b=2.0, c=3.0), zeros(3)) + @inferred unflatten_to!!((a=1.0, b=[2.0, 3.0], c=3.0), zeros(4)) + @inferred unflatten_to!!((a=(p=1.0, q=2.0), b=3.0), zeros(3)) + @inferred unflatten_to!!((a=1.0, b=(2.0, 3.0)), zeros(3)) + @inferred unflatten_to!!(NamedTuple(), Float64[]) + @inferred unflatten_to!!((1.0, [2.0, 3.0], 3.0), zeros(4)) + end +end diff --git a/test/ext/logdensityproblems/Project.toml b/test/ext/logdensityproblems/Project.toml new file mode 100644 index 00000000..a6d089ab --- /dev/null +++ b/test/ext/logdensityproblems/Project.toml @@ -0,0 +1,11 @@ +[deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ADTypes = "1" +LogDensityProblems = "2" +julia = "1.10" diff --git a/test/ext/logdensityproblems/logdensityproblems.jl b/test/ext/logdensityproblems/logdensityproblems.jl new file mode 100644 index 00000000..bc31a7a6 --- /dev/null +++ b/test/ext/logdensityproblems/logdensityproblems.jl @@ -0,0 +1,51 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) +Pkg.instantiate() + +using AbstractPPL +using AbstractPPL.Evaluators: Prepared, VectorEvaluator, NamedTupleEvaluator +using ADTypes: AbstractADType, AutoForwardDiff +using LogDensityProblems: LogDensityProblems +using Test + +struct TestADType <: AbstractADType end + +function AbstractPPL.value_and_gradient!!(p::Prepared{TestADType}, x::AbstractVector{<:Real}) + return (p(x), ones(length(x))) +end + +@testset "AbstractPPLLogDensityProblemsExt" begin + @testset "VectorEvaluator" begin + ve = VectorEvaluator(sum, 3) + @test LogDensityProblems.dimension(ve) == 3 + @test LogDensityProblems.logdensity(ve, [1.0, 2.0, 3.0]) == 6.0 + # A bare VectorEvaluator never advertises gradient capability; + # only the wrapping `Prepared` does. + @test LogDensityProblems.capabilities(ve) == LogDensityProblems.LogDensityOrder{0}() + end + + @testset "Prepared advertises gradient" begin + p_vec = Prepared(AutoForwardDiff(), VectorEvaluator(sum, 3)) + @test LogDensityProblems.capabilities(p_vec) == + LogDensityProblems.LogDensityOrder{1}() + @test LogDensityProblems.capabilities(typeof(p_vec)) == + LogDensityProblems.LogDensityOrder{1}() + + p_nt = Prepared( + AutoForwardDiff(), + NamedTupleEvaluator(x -> x.a + sum(x.b), (a=0.0, b=zeros(2))), + ) + @test LogDensityProblems.capabilities(p_nt) == + LogDensityProblems.LogDensityOrder{1}() + end + + @testset "logdensity_and_gradient" begin + f = x -> -0.5 * sum(abs2, x) + p = Prepared(TestADType(), VectorEvaluator(f, 3)) + x = [1.0, 2.0, 3.0] + val, grad = LogDensityProblems.logdensity_and_gradient(p, x) + @test val ≈ f(x) + @test grad ≈ ones(3) + end +end diff --git a/test/run_extras.jl b/test/run_extras.jl new file mode 100644 index 00000000..1a5e8e5a --- /dev/null +++ b/test/run_extras.jl @@ -0,0 +1,15 @@ +# Run a named extension test in its own isolated Julia environment. +# +# Usage (from the repo root): +# LABEL=logdensityproblems julia test/run_extras.jl + +const TEST_SUBDIRS = (logdensityproblems="ext",) +const VALID_LABELS = string.(keys(TEST_SUBDIRS)) + +label = get(ENV, "LABEL", nothing) +label === nothing && error("Set LABEL to one of: $(join(VALID_LABELS, ", "))") +label in VALID_LABELS || + error("Unknown LABEL=$label. Valid options: $(join(VALID_LABELS, ", "))") + +subdir = TEST_SUBDIRS[Symbol(label)] +include(joinpath(@__DIR__, subdir, label, label * ".jl")) diff --git a/test/runtests.jl b/test/runtests.jl index fc8d3980..ad97e108 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,8 @@ const GROUP = get(ENV, "GROUP", "All") if GROUP == "All" || GROUP == "Tests" include("Aqua.jl") include("abstractprobprog.jl") + include("evaluators/Evaluators.jl") + include("evaluators/utils.jl") include("varname/optic.jl") include("varname/varname.jl") include("varname/subsumes.jl") From 0a3a9353b37e42bbb9f8fc0b554691b8f1fc435f Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 13:47:58 +0100 Subject: [PATCH 02/17] Format with JuliaFormatter (blue style) Co-Authored-By: Claude Sonnet 4.6 --- docs/src/evaluators.md | 7 +++---- ext/AbstractPPLLogDensityProblemsExt.jl | 8 ++++++-- test/evaluators/Evaluators.jl | 4 +++- test/ext/logdensityproblems/logdensityproblems.jl | 7 ++++--- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index ac55f930..facc0961 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -100,16 +100,15 @@ function AbstractPPL.prepare( adtype::MyADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true ) f = # extract callable from problem - cache = MyCache(f, x) + cache = MyCache(f, x) return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x)), cache) end function AbstractPPL.value_and_gradient!!( - p::Prepared{<:AbstractADType, <:VectorEvaluator, <:MyCache}, - x::AbstractVector{<:Real}, + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:MyCache}, x::AbstractVector{<:Real} ) # use p.cache to avoid allocations - ... + return ... end ``` diff --git a/ext/AbstractPPLLogDensityProblemsExt.jl b/ext/AbstractPPLLogDensityProblemsExt.jl index c9b8568e..29afbb2c 100644 --- a/ext/AbstractPPLLogDensityProblemsExt.jl +++ b/ext/AbstractPPLLogDensityProblemsExt.jl @@ -11,14 +11,18 @@ LogDensityProblems.dimension(p::Prepared) = LogDensityProblems.dimension(p.evalu LogDensityProblems.dimension(e::VectorEvaluator) = e.dim # `Prepared` is the AD-aware shape, so it always advertises gradient capability. -LogDensityProblems.capabilities(::Type{<:Prepared}) = LogDensityProblems.LogDensityOrder{1}() +function LogDensityProblems.capabilities(::Type{<:Prepared}) + return LogDensityProblems.LogDensityOrder{1}() +end LogDensityProblems.capabilities(p::Prepared) = LogDensityProblems.capabilities(typeof(p)) # A bare `VectorEvaluator` is the no-AD shape; only `Prepared` advertises gradient. function LogDensityProblems.capabilities(::Type{<:VectorEvaluator}) return LogDensityProblems.LogDensityOrder{0}() end -LogDensityProblems.capabilities(e::VectorEvaluator) = LogDensityProblems.capabilities(typeof(e)) +function LogDensityProblems.capabilities(e::VectorEvaluator) + return LogDensityProblems.capabilities(typeof(e)) +end function LogDensityProblems.logdensity_and_gradient(p::Prepared, x) val, grad = AbstractPPL.value_and_gradient!!(p, x) diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 8966e45e..35dce1a1 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -23,7 +23,9 @@ end struct DummyADType <: ADTypes.AbstractADType end function AbstractPPL.prepare( - adtype::DummyADType, problem::DummyProblem, x::AbstractVector{<:Real}; + adtype::DummyADType, + problem::DummyProblem, + x::AbstractVector{<:Real}; check_dims::Bool=true, ) f = x -> sum(x) diff --git a/test/ext/logdensityproblems/logdensityproblems.jl b/test/ext/logdensityproblems/logdensityproblems.jl index bc31a7a6..b43f2759 100644 --- a/test/ext/logdensityproblems/logdensityproblems.jl +++ b/test/ext/logdensityproblems/logdensityproblems.jl @@ -11,7 +11,9 @@ using Test struct TestADType <: AbstractADType end -function AbstractPPL.value_and_gradient!!(p::Prepared{TestADType}, x::AbstractVector{<:Real}) +function AbstractPPL.value_and_gradient!!( + p::Prepared{TestADType}, x::AbstractVector{<:Real} +) return (p(x), ones(length(x))) end @@ -33,8 +35,7 @@ end LogDensityProblems.LogDensityOrder{1}() p_nt = Prepared( - AutoForwardDiff(), - NamedTupleEvaluator(x -> x.a + sum(x.b), (a=0.0, b=zeros(2))), + AutoForwardDiff(), NamedTupleEvaluator(x -> x.a + sum(x.b), (a=0.0, b=zeros(2))) ) @test LogDensityProblems.capabilities(p_nt) == LogDensityProblems.LogDensityOrder{1}() From 6fba5b13131f4577d0bcd8aa6858a1c74ebd257f Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 13:52:38 +0100 Subject: [PATCH 03/17] Fix docs example dispatch; rename test fixtures away from Problem MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - docs/src/evaluators.md: use `Prepared{<:AutoForwardDiff}` instead of `Prepared{AutoForwardDiff}` so dispatch matches the parametric instance type - test/evaluators/Evaluators.jl: rename DummyProblem → DummyModel, rename the testset and parameter names to drop ADProblems-era terminology Co-Authored-By: Claude Sonnet 4.6 --- docs/src/evaluators.md | 4 ++-- test/evaluators/Evaluators.jl | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index facc0961..85f90f78 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -24,7 +24,7 @@ function AbstractPPL.prepare(adtype::AutoForwardDiff, f, x::AbstractVector{<:Rea end function AbstractPPL.value_and_gradient!!( - p::Prepared{AutoForwardDiff}, x::AbstractVector{<:Real} + p::Prepared{<:AutoForwardDiff}, x::AbstractVector{<:Real} ) return (p(x), ForwardDiff.gradient(p.evaluator.f, x)) end @@ -54,7 +54,7 @@ using AbstractPPL: value_and_jacobian!! vecfun(x) = [x[1] * x[2], x[2] + x[3]] function AbstractPPL.value_and_jacobian!!( - p::Prepared{AutoForwardDiff}, x::AbstractVector{<:Real} + p::Prepared{<:AutoForwardDiff}, x::AbstractVector{<:Real} ) return (p(x), ForwardDiff.jacobian(p.evaluator.f, x)) end diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 35dce1a1..8785f60e 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -4,13 +4,13 @@ using AbstractPPL.Evaluators: Prepared, VectorEvaluator using ADTypes: ADTypes using Test -struct DummyProblem end +struct DummyModel end struct DummyPrepared prototype_keys::Tuple end -function AbstractPPL.prepare(problem::DummyProblem, values::NamedTuple) +function AbstractPPL.prepare(model::DummyModel, values::NamedTuple) return DummyPrepared(keys(values)) end @@ -24,7 +24,7 @@ struct DummyADType <: ADTypes.AbstractADType end function AbstractPPL.prepare( adtype::DummyADType, - problem::DummyProblem, + model::DummyModel, x::AbstractVector{<:Real}; check_dims::Bool=true, ) @@ -38,7 +38,7 @@ function AbstractPPL.value_and_gradient!!( return (sum(x), ones(length(x))) end -@testset "ADProblem interface" begin +@testset "Evaluators interface" begin @testset "explicit evaluator shapes" begin ve = AbstractPPL.Evaluators.VectorEvaluator(sum, 3) @test ve([1.0, 2.0, 3.0]) == 6.0 @@ -68,9 +68,9 @@ end end @testset "prepare (structural)" begin - problem = DummyProblem() + model = DummyModel() values = (x=0.0, y=[1.0, 2.0]) - prepared = prepare(problem, values) + prepared = prepare(model, values) @test prepared isa DummyPrepared @test prepared.prototype_keys == (:x, :y) @@ -81,10 +81,10 @@ end end @testset "prepare (AD-aware)" begin - problem = DummyProblem() + model = DummyModel() x0 = zeros(3) adtype = DummyADType() - prepared = prepare(adtype, problem, x0) + prepared = prepare(adtype, model, x0) @test prepared isa Prepared{DummyADType} x = [0.5, 1.5, 2.5] @@ -95,16 +95,16 @@ end @test grad ≈ [1.0, 1.0, 1.0] # check_dims=false skips the per-call dimension check. - prepared_unchecked = prepare(adtype, problem, x0; check_dims=false) + prepared_unchecked = prepare(adtype, model, x0; check_dims=false) @test prepared_unchecked([1.0, 2.0]) ≈ 3.0 # wrong length, no error end @testset "missing AD package extensions" begin - problem = DummyProblem() + model = DummyModel() x0 = zeros(3) @test_throws MethodError AbstractPPL.Evaluators.prepare( - ADTypes.AutoEnzyme(), problem, x0 + ADTypes.AutoEnzyme(), model, x0 ) end @@ -119,7 +119,7 @@ end @test evaluate!!(ne, (a=1.0, b=[2.0, 3.0])) == 6.0 adtype = DummyADType() - prepared = prepare(adtype, DummyProblem(), zeros(3)) + prepared = prepare(adtype, DummyModel(), zeros(3)) @test evaluate!!(prepared, [0.5, 1.5, 2.5]) ≈ 4.5 end end From d632db577820e7b57448b0ed7cb659c3357dfd27 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 13:54:46 +0100 Subject: [PATCH 04/17] format --- test/evaluators/Evaluators.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 8785f60e..20425aba 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -23,10 +23,7 @@ end struct DummyADType <: ADTypes.AbstractADType end function AbstractPPL.prepare( - adtype::DummyADType, - model::DummyModel, - x::AbstractVector{<:Real}; - check_dims::Bool=true, + adtype::DummyADType, model::DummyModel, x::AbstractVector{<:Real}; check_dims::Bool=true ) f = x -> sum(x) return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x))) From 154b0091d5605ebbf577e4e1688eff5eb1256b0e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 14:02:32 +0100 Subject: [PATCH 05/17] Make Prepared strictly internal to AbstractPPL.Evaluators - Remove `export Prepared` and the `using .Evaluators: Prepared` re-import - Update docs example to import `Prepared` from `AbstractPPL.Evaluators` - Add a brief comment on the `evaluate!!` overloads Co-Authored-By: Claude Sonnet 4.6 --- docs/src/evaluators.md | 4 ++-- src/AbstractPPL.jl | 3 +-- src/evaluators/Evaluators.jl | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 85f90f78..bf40fe1d 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -14,8 +14,8 @@ the result past the next call. ```@example ad using AbstractPPL -using AbstractPPL: prepare, value_and_gradient!!, Prepared -using AbstractPPL.Evaluators: VectorEvaluator, NamedTupleEvaluator +using AbstractPPL: prepare, value_and_gradient!! +using AbstractPPL.Evaluators: Prepared, VectorEvaluator, NamedTupleEvaluator using ADTypes: AutoForwardDiff using ForwardDiff: ForwardDiff diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index eed7c226..48a2cb44 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -11,8 +11,7 @@ include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") include("evaluators/Evaluators.jl") -using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!!, Prepared -export Prepared +using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! @static if VERSION >= v"1.11.0" eval(Meta.parse("public prepare, value_and_gradient!!, value_and_jacobian!!")) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index d8fb7aa6..5773b14f 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -197,6 +197,7 @@ function _assert_supported_output(y) return nothing end +# Make prepared evaluators usable through the same `evaluate!!` API as models. evaluate!!(p::Prepared, x) = p(x) evaluate!!(e::VectorEvaluator, x) = e(x) evaluate!!(e::NamedTupleEvaluator, x) = e(x) From 0db151a45b97438a469d50b0d13c3eab12552c4a Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 30 Apr 2026 16:01:01 +0100 Subject: [PATCH 06/17] Restrict LDP integration to vector-input evaluators; default capabilities to order 0 - Dispatch all LDP methods on `Prepared{<:Any,<:VectorEvaluator}` so `NamedTupleEvaluator`-backed `Prepared` does not match (it cannot satisfy LDP's vector-input contract). - Default `capabilities` returns `LogDensityOrder{0}`; AD-backend extensions opt into `LogDensityOrder{1}` by overloading on their cache type. Without the overload, `logdensity_and_gradient` would hit the `value_and_gradient!!` stub and fail. - Test: demonstrate the backend opt-in pattern (TestADType overloads `capabilities`), and verify NamedTupleEvaluator-backed `Prepared` has no LDP methods defined (`dimension` throws, `capabilities` falls through to LDP's `nothing` default). Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLLogDensityProblemsExt.jl | 26 +++++++++++----- .../logdensityproblems/logdensityproblems.jl | 30 +++++++++++++++---- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/ext/AbstractPPLLogDensityProblemsExt.jl b/ext/AbstractPPLLogDensityProblemsExt.jl index 29afbb2c..fd4a2585 100644 --- a/ext/AbstractPPLLogDensityProblemsExt.jl +++ b/ext/AbstractPPLLogDensityProblemsExt.jl @@ -4,19 +4,29 @@ using AbstractPPL: AbstractPPL using AbstractPPL.Evaluators: Prepared, VectorEvaluator using LogDensityProblems: LogDensityProblems -LogDensityProblems.logdensity(p::Prepared, x) = p(x) +# LDP integration is restricted to vector-input evaluators; `NamedTupleEvaluator` +# does not satisfy LDP's vector-input contract. Scalar output is a runtime +# contract the user must satisfy. + +LogDensityProblems.logdensity(p::Prepared{<:Any,<:VectorEvaluator}, x) = p(x) LogDensityProblems.logdensity(e::VectorEvaluator, x) = e(x) -LogDensityProblems.dimension(p::Prepared) = LogDensityProblems.dimension(p.evaluator) +function LogDensityProblems.dimension(p::Prepared{<:Any,<:VectorEvaluator}) + return LogDensityProblems.dimension(p.evaluator) +end LogDensityProblems.dimension(e::VectorEvaluator) = e.dim -# `Prepared` is the AD-aware shape, so it always advertises gradient capability. -function LogDensityProblems.capabilities(::Type{<:Prepared}) - return LogDensityProblems.LogDensityOrder{1}() +# Generic fallback: order 0. AD-backend extensions (DifferentiationInterface, +# ForwardDiff, Mooncake, etc.) must overload this for their cache type to +# advertise `LogDensityOrder{1}` — without that overload, +# `logdensity_and_gradient` will hit the `value_and_gradient!!` stub and fail. +function LogDensityProblems.capabilities(::Type{<:Prepared{<:Any,<:VectorEvaluator}}) + return LogDensityProblems.LogDensityOrder{0}() +end +function LogDensityProblems.capabilities(p::Prepared{<:Any,<:VectorEvaluator}) + return LogDensityProblems.capabilities(typeof(p)) end -LogDensityProblems.capabilities(p::Prepared) = LogDensityProblems.capabilities(typeof(p)) -# A bare `VectorEvaluator` is the no-AD shape; only `Prepared` advertises gradient. function LogDensityProblems.capabilities(::Type{<:VectorEvaluator}) return LogDensityProblems.LogDensityOrder{0}() end @@ -24,7 +34,7 @@ function LogDensityProblems.capabilities(e::VectorEvaluator) return LogDensityProblems.capabilities(typeof(e)) end -function LogDensityProblems.logdensity_and_gradient(p::Prepared, x) +function LogDensityProblems.logdensity_and_gradient(p::Prepared{<:Any,<:VectorEvaluator}, x) val, grad = AbstractPPL.value_and_gradient!!(p, x) return (val, copy(grad)) end diff --git a/test/ext/logdensityproblems/logdensityproblems.jl b/test/ext/logdensityproblems/logdensityproblems.jl index b43f2759..c0176035 100644 --- a/test/ext/logdensityproblems/logdensityproblems.jl +++ b/test/ext/logdensityproblems/logdensityproblems.jl @@ -9,6 +9,9 @@ using ADTypes: AbstractADType, AutoForwardDiff using LogDensityProblems: LogDensityProblems using Test +# A NamedTupleEvaluator does not satisfy LDP's vector-input contract, so the +# extension does not define LDP methods for it. + struct TestADType <: AbstractADType end function AbstractPPL.value_and_gradient!!( @@ -17,6 +20,13 @@ function AbstractPPL.value_and_gradient!!( return (p(x), ones(length(x))) end +# Backend extensions opt into gradient capability by overloading `capabilities` +# (typically on their cache type, e.g. `<:Prepared{<:Any,<:VectorEvaluator,<:MyCache}`). +# Here we dispatch on the AD type for simplicity. +function LogDensityProblems.capabilities(::Type{<:Prepared{TestADType,<:VectorEvaluator}}) + return LogDensityProblems.LogDensityOrder{1}() +end + @testset "AbstractPPLLogDensityProblemsExt" begin @testset "VectorEvaluator" begin ve = VectorEvaluator(sum, 3) @@ -27,18 +37,26 @@ end @test LogDensityProblems.capabilities(ve) == LogDensityProblems.LogDensityOrder{0}() end - @testset "Prepared advertises gradient" begin - p_vec = Prepared(AutoForwardDiff(), VectorEvaluator(sum, 3)) - @test LogDensityProblems.capabilities(p_vec) == + @testset "Prepared capabilities" begin + # Without a backend overload the fallback advertises order 0 only. + p_no_overload = Prepared(AutoForwardDiff(), VectorEvaluator(sum, 3)) + @test LogDensityProblems.capabilities(p_no_overload) == + LogDensityProblems.LogDensityOrder{0}() + + # A backend that overloads capabilities advertises order 1. + p_overloaded = Prepared(TestADType(), VectorEvaluator(sum, 3)) + @test LogDensityProblems.capabilities(p_overloaded) == LogDensityProblems.LogDensityOrder{1}() - @test LogDensityProblems.capabilities(typeof(p_vec)) == + @test LogDensityProblems.capabilities(typeof(p_overloaded)) == LogDensityProblems.LogDensityOrder{1}() + # NamedTupleEvaluator-backed Prepared has no LDP methods defined; the + # extension only integrates vector-input evaluators. p_nt = Prepared( AutoForwardDiff(), NamedTupleEvaluator(x -> x.a + sum(x.b), (a=0.0, b=zeros(2))) ) - @test LogDensityProblems.capabilities(p_nt) == - LogDensityProblems.LogDensityOrder{1}() + @test_throws MethodError LogDensityProblems.dimension(p_nt) + @test LogDensityProblems.capabilities(p_nt) === nothing end @testset "logdensity_and_gradient" begin From d0ebb6c63f606a298fc8e642ca315ab30fd32801 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:03:33 +0100 Subject: [PATCH 07/17] Preserve leaf types in unflatten_to!! round-trip Reconstruct each leaf using `x`'s types rather than `eltype(buf)`, so a heterogeneous round-trip preserves `typeof(x)` instead of widening to the flat buffer's eltype. Add an opt-in `check_eltype` keyword that warns when `eltype(buf)` differs from `flat_eltype(x)`. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/evaluators/utils.jl | 26 ++++++++++++++++++++++---- test/evaluators/utils.jl | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index eab36471..8d57a2fc 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -78,14 +78,14 @@ function _flatten_to!(buf::AbstractVector, x, ::Int) end function _unflatten(x::Union{Real,Complex}, buf::AbstractVector, offset::Int) - return buf[offset], offset + 1 + return convert(typeof(x), buf[offset]), offset + 1 end function _unflatten( x::AbstractArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int ) n = length(x) - value = similar(x, promote_type(eltype(x), eltype(buf))) + value = similar(x) copyto!(value, 1, buf, offset, n) return value, offset + n end @@ -120,7 +120,7 @@ end end """ - unflatten_to!!(x, buf) + unflatten_to!!(x, buf; check_eltype::Bool=false) Reconstruct a value from the vector-like buffer `buf` using `x` as the structural template. @@ -130,14 +130,32 @@ Supported `x` values are: - `AbstractArray{<:Union{Real,Complex}}` - `Tuple` recursively containing supported values - `NamedTuple` recursively containing supported values + +Pass `check_eltype=true` to emit a warning when `eltype(buf)` differs from +`flat_eltype(x)` (off by default to keep hot paths quiet). """ # Always allocates: `_unflatten` calls `similar` for each array field. Gains from # buffer reuse are negligible relative to gradient computation cost. -function unflatten_to!!(x, buf::AbstractVector) +# +# Heterogeneous round-trip: the flat buffer widens, but leaves are rebuilt +# from `x`'s types, so `typeof(x2) == typeof(x)`. E.g. +# +# x = (1.0, [2.0, 3.0], (4.0 + 1.0im,)) # buffer widens to ComplexF64 +# x2 = unflatten_to!!(x, flatten_to!!(nothing, x)) +# # x2 == (1.0, [2.0, 3.0], (4.0 + 1.0im,)) +# # x2 == x → true +# # typeof(x2) == typeof(x) → true +function unflatten_to!!(x, buf::AbstractVector; check_eltype::Bool=false) n = flat_length(x) length(buf) == n || throw( DimensionMismatch("Expected a vector of length $n, but got length $(length(buf))."), ) + if check_eltype + expected = flat_eltype(x) + eltype(buf) === expected || @warn( + "Buffer eltype `$(eltype(buf))` differs from `flat_eltype(x) = $expected`; reconstructing using the leaf types from `x`." + ) + end value, _ = _unflatten(x, buf, 1) return value end diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index c264361e..41f1c8e0 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -41,6 +41,29 @@ using Test @test unflatten_to!!(x, v) == x end + @testset "heterogeneous container preserves leaf types" begin + # The flat buffer widens to ComplexF64, but `unflatten_to!!` rebuilds + # leaves using `x`'s types, so the round-trip preserves `typeof(x)`. + x = (1.0, [2.0, 3.0], (4.0 + 1.0im,)) + x2 = unflatten_to!!(x, flatten_to!!(nothing, x)) + @test x2 == x + @test typeof(x2) == typeof(x) + end + + @testset "check_eltype opt-in warning" begin + x = (a=1.0, b=[2.0, 3.0]) + # buf eltype is ComplexF64, but `flat_eltype(x) == Float64` — a + # mismatch that should warn only with `check_eltype=true`. Imag parts + # are zero, so the convert succeeds. + buf = ComplexF64[1.0, 2.0, 3.0] + x2 = @test_logs unflatten_to!!(x, buf) # default: silent + @test x2 == x + @test typeof(x2) == typeof(x) + x3 = @test_logs (:warn, r"differs from") unflatten_to!!(x, buf; check_eltype=true) + @test x3 == x + @test typeof(x3) == typeof(x) + end + @testset "buffer length mismatch" begin @test_throws r"Expected a vector of length 4" flatten_to!!( Vector{Float64}(undef, 3), zeros(2, 2) From 73c97e96f32d5f8f486b8ea297d6996f56f6478b Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:06:05 +0100 Subject: [PATCH 08/17] Reframe check_dims rationale around the outer entry point The previous wording attributed the input-shape guarantee to AD libraries, but ForwardDiff/Mooncake do not validate user input. The actual contract is that the outer entry point validates `length(x)` once, and dual arrays derived from that `x` then flow through the AD loop without re-checking. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index bf40fe1d..c9bbbb52 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -113,9 +113,11 @@ end ``` Pass `check_dims=false` in your `prepare` implementation to construct a -`VectorEvaluator{false}`, which skips the per-call length check. AD libraries -that guarantee input shape (ForwardDiff, Mooncake, etc.) should do this to -avoid redundant checks in the dual/shadow hot path. +`VectorEvaluator{false}`, which skips the per-call length check. The outer +entry point (`prepared(x)` or `value_and_gradient!!(prepared, x)`) already +validates `length(x)` once, and the AD differentiation loop then invokes the +inner callable many times with same-length dual arrays derived from that +`x` — re-checking on each invocation is redundant work in the hot path. ## Without an AD backend From c2b84b28f8b3829a8208ae23feb6bbf8cf99aec8 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:08:30 +0100 Subject: [PATCH 09/17] Clarify check_dims trust model and no-AD interface rationale - check_dims=false is an opt-in trust mode where the caller takes responsibility for length(x); the AD inner-loop motivation is now framed as a typical use, not a guaranteed validation chain. - The no-AD `prepare(problem, x)` form exists so callers can write `prepare(...)` uniformly without knowing whether a backend is loaded, and downstream primal-only consumers can accept the result either way. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index c9bbbb52..1b3e14ec 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -113,17 +113,20 @@ end ``` Pass `check_dims=false` in your `prepare` implementation to construct a -`VectorEvaluator{false}`, which skips the per-call length check. The outer -entry point (`prepared(x)` or `value_and_gradient!!(prepared, x)`) already -validates `length(x)` once, and the AD differentiation loop then invokes the -inner callable many times with same-length dual arrays derived from that -`x` — re-checking on each invocation is redundant work in the hot path. +`VectorEvaluator{false}`, which skips the per-call length check. This is an +opt-in trust mode — the caller takes responsibility for `length(x)`. The +typical use is inside a backend's `value_and_gradient!!`, where the AD +library invokes the inner callable many times with same-length dual arrays +derived from a single user-supplied `x`; re-validating on each invocation +would be redundant work in the hot path. ## Without an AD backend -The two-argument form `prepare(problem, x)` is available without any AD package. -It returns the callable unchanged by default, so code that calls `prepare` -unconditionally works regardless of which backends are loaded: +The two-argument form `prepare(problem, x)` is available without any AD +package. It returns the callable unchanged by default, so the caller doesn't +need to know whether an AD backend is loaded — the same `prepare(...)` call +works either way, and downstream code that only needs primal evaluation +(e.g. log-density only, no gradient) can accept the result uniformly: ```@example ad sumsimple(x) = sum(x) From 3ad1396e02347cc72b87012a5e7582a5d435cebb Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:10:35 +0100 Subject: [PATCH 10/17] Document why the AD output-shape assertions live here These helpers are unused inside this PR but are intended for AD-backend extensions to share, so each `value_and_gradient!!` / `value_and_jacobian!!` produces a uniform error rather than rolling its own. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/evaluators/Evaluators.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index 5773b14f..e5993132 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -179,6 +179,9 @@ function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) end _assert_namedtuple_shape(::NamedTupleEvaluator{false}, _) = nothing +# Output-shape assertions for AD-backend extensions to share. Centralised here +# so each backend's `value_and_gradient!!` / `value_and_jacobian!!` produces +# the same error message rather than rolling its own. function _assert_jacobian_output(y) y isa AbstractVector || throw( ArgumentError( From 6c9e4f26415c633eac8c31c2a5437d518fdd357e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:17:02 +0100 Subject: [PATCH 11/17] Address remaining PR review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test/run_extras.jl: drop the label→subdir mapping, use the label as the path directly, rename the included test file to main.jl - src/evaluators/utils.jl: - simplify flat_length tuple/namedtuple cases to sum(flat_length, ...) - replace getfield with x[name] in the _unflatten generator - reject Symmetric/Hermitian/Diagonal/Triangular/Tri/Bi-diagonal arrays up front rather than emitting broken round-trip results - src/evaluators/Evaluators.jl: fold integer-input rejection into the VectorEvaluator method body (parametrise on T, branch on T<:Integer); drop the unused first arg of _reject_integer_input - docs/src/evaluators.md: fix the f=...\\ncache=... indentation typo; expand the !! aliasing note with an example showing safe usage Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/CI.yml | 2 +- docs/src/evaluators.md | 21 ++++++++--- src/evaluators/Evaluators.jl | 32 +++++++++-------- src/evaluators/utils.jl | 35 ++++++++++++++++--- test/evaluators/utils.jl | 12 +++++++ .../{logdensityproblems.jl => main.jl} | 0 test/run_extras.jl | 8 ++--- 7 files changed, 80 insertions(+), 30 deletions(-) rename test/ext/logdensityproblems/{logdensityproblems.jl => main.jl} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ef4049b5..60609e16 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -71,4 +71,4 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - run: julia --project=. test/run_extras.jl env: - LABEL: logdensityproblems + LABEL: ext/logdensityproblems diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 1b3e14ec..9522b9d4 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -7,8 +7,21 @@ sample input that establishes the expected input shape and type; derivative together. The `!!` suffix signals that the returned gradient or Jacobian **may alias -internal cache buffers** of the prepared evaluator. Copy if you need to retain -the result past the next call. +internal cache buffers** of the prepared evaluator. The next call to +`value_and_gradient!!` (or `value_and_jacobian!!`) may overwrite that buffer +in place, so a previously-returned reference will silently change. Copy +before holding on to a result: + +```julia +val, grad = value_and_gradient!!(prepared, x1) +saved = copy(grad) # safe to keep +val2, grad2 = value_and_gradient!!(prepared, x2) +# `grad` may now reflect `x2`; `saved` still reflects `x1` +``` + +Backends that always allocate fresh output (e.g. `ForwardDiff.gradient`) do +not actually alias, but consumers should not rely on that — write to the +contract, not the implementation. ## Quick start @@ -99,8 +112,8 @@ Mooncake tapes). Backend extensions dispatch on the cache type: function AbstractPPL.prepare( adtype::MyADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true ) - f = # extract callable from problem - cache = MyCache(f, x) + f = ... # extract callable from problem + cache = MyCache(f, x) return Prepared(adtype, VectorEvaluator{check_dims}(f, length(x)), cache) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index e5993132..be582b1f 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -133,7 +133,19 @@ end NamedTupleEvaluator(f, inputspec::NamedTuple) = NamedTupleEvaluator{true}(f, inputspec) -function (e::VectorEvaluator{true})(x::AbstractVector) +# Reject integer vectors with a clear error rather than letting them flow into +# AD backends (which usually fail confusingly). The `T <: Integer` branch is +# resolved at compile time, so non-integer inputs pay nothing. +function _reject_integer_input(x) + throw( + ArgumentError( + "VectorEvaluator requires a vector of floating-point values, but received an `$(typeof(x))`. Convert to a floating-point vector (e.g. `Float64.(x)`) before calling.", + ), + ) +end + +function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} + T <: Integer && _reject_integer_input(x) length(x) == e.dim || throw( DimensionMismatch( "Expected a vector of length $(e.dim), but got length $(length(x))." @@ -142,7 +154,10 @@ function (e::VectorEvaluator{true})(x::AbstractVector) return e.f(x) end -(e::VectorEvaluator{false})(x::AbstractVector) = e.f(x) +function (e::VectorEvaluator{false})(x::AbstractVector{T}) where {T} + T <: Integer && _reject_integer_input(x) + return e.f(x) +end function (e::NamedTupleEvaluator{true})(values::NamedTuple) _assert_namedtuple_shape(e, values) @@ -150,19 +165,6 @@ function (e::NamedTupleEvaluator{true})(values::NamedTuple) end (e::NamedTupleEvaluator{false})(values::NamedTuple) = e.f(values) -# Reject integer vectors with a clear error rather than letting them flow into -# AD backends (which usually fail confusingly). Split per `CheckInput` to avoid -# an ambiguity with the `(::VectorEvaluator{true})(::AbstractVector)` method above. -function _reject_integer_input(::VectorEvaluator, x) - throw( - ArgumentError( - "VectorEvaluator requires a vector of floating-point values, but received an `$(typeof(x))`. Convert to a floating-point vector (e.g. `Float64.(x)`) before calling.", - ), - ) -end -(e::VectorEvaluator{true})(x::AbstractVector{<:Integer}) = _reject_integer_input(e, x) -(e::VectorEvaluator{false})(x::AbstractVector{<:Integer}) = _reject_integer_input(e, x) - """ _assert_namedtuple_shape(e::NamedTupleEvaluator, values) diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index 8d57a2fc..c05e79ee 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -3,13 +3,40 @@ # This utility only supports a small structural subset so flattening stays # predictable and reconstruction can use `x` as the template. +using LinearAlgebra: + AbstractTriangular, + Bidiagonal, + Diagonal, + Hermitian, + Symmetric, + SymTridiagonal, + Tridiagonal + +# Structured wrappers from LinearAlgebra have `length(x) > # of independent +# entries`, so a naive round-trip is lossy or fails inside `copyto!`. Reject up +# front with a clear error rather than emitting broken results. Cholesky/LU/QR +# are not <:AbstractArray and already fall through to the catch-all. +const _StructuredArray = Union{ + AbstractTriangular,Bidiagonal,Diagonal,Hermitian,Symmetric,SymTridiagonal,Tridiagonal +} + +function _reject_structured(x) + throw( + ArgumentError( + "Structured array `$(typeof(x))` is not supported by the flatten/unflatten utilities; convert to a plain `Array` first.", + ), + ) +end + flat_length(x::Union{Real,Complex}) = 1 +flat_length(x::_StructuredArray) = _reject_structured(x) flat_length(x::AbstractArray{<:Union{Real,Complex}}) = length(x) -flat_length(x::Tuple) = mapreduce(flat_length, +, x; init=0) -flat_length(x::NamedTuple) = mapreduce(flat_length, +, values(x); init=0) +flat_length(x::Tuple) = sum(flat_length, x; init=0) +flat_length(x::NamedTuple) = sum(flat_length, values(x); init=0) flat_length(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) flat_eltype(x::Union{Real,Complex}) = typeof(x) +flat_eltype(x::_StructuredArray) = _reject_structured(x) flat_eltype(x::AbstractArray{T}) where {T<:Union{Real,Complex}} = T flat_eltype(::Tuple{}) = Float64 flat_eltype(x::Tuple) = mapreduce(flat_eltype, promote_type, x) @@ -111,9 +138,7 @@ end for name in Names v = gensym(name) push!(val_syms, v) - push!( - block.args, :(($v, off) = _unflatten(getfield(x, $(QuoteNode(name))), buf, off)) - ) + push!(block.args, :(($v, off) = _unflatten(x[$(QuoteNode(name))], buf, off))) end push!(block.args, :(return (NamedTuple{$Names}(($(val_syms...),)), off))) return block diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index 41f1c8e0..b6707841 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -1,5 +1,6 @@ using AbstractPPL using AbstractPPL.Evaluators: flatten_to!!, unflatten_to!! +using LinearAlgebra: Diagonal, Symmetric, UpperTriangular using Test @testset "vectorisation utilities" begin @@ -64,6 +65,17 @@ using Test @test typeof(x3) == typeof(x) end + @testset "structured arrays rejected" begin + for x in ( + Symmetric([1.0 2.0; 2.0 3.0]), + Diagonal([1.0, 2.0]), + UpperTriangular([1.0 2.0; 0.0 3.0]), + ) + @test_throws r"Structured array" flatten_to!!(nothing, x) + @test_throws r"Structured array" unflatten_to!!(x, [1.0, 2.0, 3.0, 4.0]) + end + end + @testset "buffer length mismatch" begin @test_throws r"Expected a vector of length 4" flatten_to!!( Vector{Float64}(undef, 3), zeros(2, 2) diff --git a/test/ext/logdensityproblems/logdensityproblems.jl b/test/ext/logdensityproblems/main.jl similarity index 100% rename from test/ext/logdensityproblems/logdensityproblems.jl rename to test/ext/logdensityproblems/main.jl diff --git a/test/run_extras.jl b/test/run_extras.jl index 1a5e8e5a..4b40e320 100644 --- a/test/run_extras.jl +++ b/test/run_extras.jl @@ -1,15 +1,13 @@ # Run a named extension test in its own isolated Julia environment. # # Usage (from the repo root): -# LABEL=logdensityproblems julia test/run_extras.jl +# LABEL=ext/logdensityproblems julia test/run_extras.jl -const TEST_SUBDIRS = (logdensityproblems="ext",) -const VALID_LABELS = string.(keys(TEST_SUBDIRS)) +const VALID_LABELS = ("ext/logdensityproblems",) label = get(ENV, "LABEL", nothing) label === nothing && error("Set LABEL to one of: $(join(VALID_LABELS, ", "))") label in VALID_LABELS || error("Unknown LABEL=$label. Valid options: $(join(VALID_LABELS, ", "))") -subdir = TEST_SUBDIRS[Symbol(label)] -include(joinpath(@__DIR__, subdir, label, label * ".jl")) +include(joinpath(@__DIR__, label, "main.jl")) From 88f1661118464d4f0b35df1fb2adf07d86e581f5 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 11:17:34 +0100 Subject: [PATCH 12/17] Note TODO for proper structured-array / Cholesky support Co-Authored-By: Claude Opus 4.7 (1M context) --- src/evaluators/utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index c05e79ee..4bb677a5 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -16,6 +16,10 @@ using LinearAlgebra: # entries`, so a naive round-trip is lossy or fails inside `copyto!`. Reject up # front with a clear error rather than emitting broken results. Cholesky/LU/QR # are not <:AbstractArray and already fall through to the catch-all. +# +# TODO: extend `flatten_to!!` / `unflatten_to!!` with proper support for +# structured arrays (independent-entry packing) and factorisation types +# (Cholesky in particular is needed for PPL covariance parameters). const _StructuredArray = Union{ AbstractTriangular,Bidiagonal,Diagonal,Hermitian,Symmetric,SymTridiagonal,Tridiagonal } From a8ff7e1e8312f9ee86d8a2cd587b46b6d5c6e029 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 14:05:42 +0100 Subject: [PATCH 13/17] Address late-round review feedback on utils.jl - Drop `init=0` from `flat_length` tuple/namedtuple sums and add explicit empty-case methods to mirror `flat_eltype`'s pattern. - Reject non-one-based arrays via `Base.require_one_based_indexing` (previously `copyto!(buf, offset, x, 1, n)` BoundsErrored on OffsetArrays silently). Checked at user-facing buf entry points and at the array dispatches in `_flatten_to!` / `_unflatten`. - Regression tests for both rejection paths. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/evaluators/utils.jl | 10 ++++++++-- test/evaluators/utils.jl | 13 +++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index 4bb677a5..a287e167 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -35,8 +35,10 @@ end flat_length(x::Union{Real,Complex}) = 1 flat_length(x::_StructuredArray) = _reject_structured(x) flat_length(x::AbstractArray{<:Union{Real,Complex}}) = length(x) -flat_length(x::Tuple) = sum(flat_length, x; init=0) -flat_length(x::NamedTuple) = sum(flat_length, values(x); init=0) +flat_length(::Tuple{}) = 0 +flat_length(x::Tuple) = sum(flat_length, x) +flat_length(::NamedTuple{(),Tuple{}}) = 0 +flat_length(x::NamedTuple) = sum(flat_length, values(x)) flat_length(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) flat_eltype(x::Union{Real,Complex}) = typeof(x) @@ -69,6 +71,7 @@ function flatten_to!!(::Nothing, x) end function flatten_to!!(buf::AbstractVector, x) + Base.require_one_based_indexing(buf) n = flat_length(x) length(buf) == n || throw( DimensionMismatch("Expected a vector of length $n, but got length $(length(buf))."), @@ -85,6 +88,7 @@ end function _flatten_to!( buf::AbstractVector, x::AbstractArray{<:Union{Real,Complex}}, offset::Int ) + Base.require_one_based_indexing(x) n = length(x) copyto!(buf, offset, x, 1, n) return offset + n @@ -115,6 +119,7 @@ end function _unflatten( x::AbstractArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int ) + Base.require_one_based_indexing(x) n = length(x) value = similar(x) copyto!(value, 1, buf, offset, n) @@ -175,6 +180,7 @@ Pass `check_eltype=true` to emit a warning when `eltype(buf)` differs from # # x2 == x → true # # typeof(x2) == typeof(x) → true function unflatten_to!!(x, buf::AbstractVector; check_eltype::Bool=false) + Base.require_one_based_indexing(buf) n = flat_length(x) length(buf) == n || throw( DimensionMismatch("Expected a vector of length $n, but got length $(length(buf))."), diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index b6707841..9021505d 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -1,6 +1,7 @@ using AbstractPPL using AbstractPPL.Evaluators: flatten_to!!, unflatten_to!! using LinearAlgebra: Diagonal, Symmetric, UpperTriangular +using OffsetArrays: OffsetArray using Test @testset "vectorisation utilities" begin @@ -65,6 +66,18 @@ using Test @test typeof(x3) == typeof(x) end + @testset "non-one-based arrays rejected" begin + oa = OffsetArray([1.0, 2.0, 3.0], 0:2) + @test_throws ArgumentError flatten_to!!(nothing, oa) + @test_throws ArgumentError flatten_to!!(zeros(3), oa) + @test_throws ArgumentError unflatten_to!!(oa, [1.0, 2.0, 3.0]) + # Non-one-based buf is rejected even when `x` is fine. + @test_throws ArgumentError flatten_to!!(OffsetArray(zeros(3), 0:2), [1.0, 2.0, 3.0]) + @test_throws ArgumentError unflatten_to!!( + [1.0, 2.0, 3.0], OffsetArray(zeros(3), 0:2) + ) + end + @testset "structured arrays rejected" begin for x in ( Symmetric([1.0 2.0; 2.0 3.0]), From 0ad98f5d7b1892c5aeca6b0801d75ed12c283f88 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 14:47:45 +0100 Subject: [PATCH 14/17] Validate nested-array shapes and supported leaves in NamedTupleEvaluator The previous shape check relied on `typeof` equality, which treats `(b=zeros(2),)` and `(b=[1.0],)` as the same shape because both are `Vector{Float64}`. Add a structural `_shapes_match` recursion that compares `size` for array leaves (recursing into non-numeric eltypes) and rejects unsupported leaf types up front, matching the `Real`/`Complex`/`AbstractArray`/`Tuple`/`NamedTuple` contract used by the flatten/unflatten utilities. Document the supported leaves in the `NamedTupleEvaluator` docstring and the user-facing docs. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 7 +++-- src/evaluators/Evaluators.jl | 52 +++++++++++++++++++++++++++++++++-- test/evaluators/Evaluators.jl | 14 ++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 9522b9d4..37444d9a 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -79,8 +79,11 @@ value_and_jacobian!!(prepared_vec, [2.0, 3.0, 4.0]) ### NamedTuple inputs When the callable accepts a `NamedTuple`, pass a sample `NamedTuple` whose -field names and value types match the expected input. An extension can define a -`prepare` overload that wraps the function in a `NamedTupleEvaluator`: +field names and value types match the expected input. The prototype's leaves +must be `Real`, `Complex`, `AbstractArray` (recursively), `Tuple`, or +`NamedTuple` — the same structural model used by `flatten_to!!` / +`unflatten_to!!`. An extension can define a `prepare` overload that wraps the +function in a `NamedTupleEvaluator`: ```@example ad function AbstractPPL.prepare(adtype::AutoForwardDiff, f, values::NamedTuple) diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index be582b1f..b24e1e6f 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -117,8 +117,18 @@ VectorEvaluator(f, dim::Int) = VectorEvaluator{true}(f, dim) Evaluator shape for functions of a `NamedTuple` input with a stable prototype. Part of the extension author API; end users interact with the wrapping `Prepared`. +The `inputspec` prototype's leaves must be one of: + +- `Real` or `Complex` (scalar) +- `AbstractArray` whose elements are themselves supported leaves +- `Tuple` or `NamedTuple` recursively containing supported leaves + +This matches the structural model used by [`flatten_to!!`](@ref) / +[`unflatten_to!!`](@ref). Other leaf types (e.g. `String`, `Symbol`, custom +structs) trigger an `ArgumentError` from the per-call shape check. + `CheckInput` controls whether each call validates that the input `NamedTuple` -has the same type as the prototype captured during preparation. +matches the prototype's `typeof` and per-leaf array `size`. """ struct NamedTupleEvaluator{CheckInput,F,P<:NamedTuple} f::F @@ -169,7 +179,10 @@ end _assert_namedtuple_shape(e::NamedTupleEvaluator, values) Throw `ArgumentError` unless `values` has the same type as the prototype captured -during preparation. No-op when `e` was constructed with `CheckInput=false`. +during preparation, including matching `size` for any nested `AbstractArray` +leaves. Also throws if the prototype contains a leaf type outside the supported +set (`Real`, `Complex`, `AbstractArray`, `Tuple`, `NamedTuple`). No-op when `e` +was constructed with `CheckInput=false`. """ function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) typeof(values) === typeof(e.inputspec) || throw( @@ -177,10 +190,45 @@ function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) "Expected the same NamedTuple structure that was used to prepare this evaluator.", ), ) + _shapes_match(values, e.inputspec) || throw( + ArgumentError( + "Nested array shape differs from the prototype captured during preparation." + ), + ) return nothing end _assert_namedtuple_shape(::NamedTupleEvaluator{false}, _) = nothing +# Complements the `typeof` check above: same-typed arrays can differ in `size`. +# Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch +# inner mismatches. Unknown leaves throw, mirroring the supported-leaves +# contract of the flatten/unflatten utilities in `utils.jl`. +# +# `Tuple` recursion uses `first`/`Base.tail` rather than a `zip` loop so each +# leaf call sees concrete element types — same idiom as `_unflatten`. +_shapes_match(::Union{Real,Complex}, ::Union{Real,Complex}) = true +function _shapes_match(a::AbstractArray, b::AbstractArray) + size(a) == size(b) || return false + eltype(a) <: Union{Real,Complex} && return true + for (ai, bi) in zip(a, b) + _shapes_match(ai, bi) || return false + end + return true +end +_shapes_match(::Tuple{}, ::Tuple{}) = true +function _shapes_match(a::Tuple, b::Tuple) + _shapes_match(first(a), first(b)) || return false + return _shapes_match(Base.tail(a), Base.tail(b)) +end +_shapes_match(a::NamedTuple, b::NamedTuple) = _shapes_match(values(a), values(b)) +function _shapes_match(a, _) + throw( + ArgumentError( + "Cannot validate shape for prototype leaf of type `$(typeof(a))`. Supported leaves are `Real`, `Complex`, `AbstractArray`, `Tuple`, and `NamedTuple`.", + ), + ) +end + # Output-shape assertions for AD-backend extensions to share. Centralised here # so each backend's `value_and_gradient!!` / `value_and_jacobian!!` produces # the same error message rather than rolling its own. diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 20425aba..6df35b1c 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -62,6 +62,20 @@ end @test_throws r"same NamedTuple structure" AbstractPPL.Evaluators._assert_namedtuple_shape( ne, (totally=:wrong,) ) + + # Nested array shape: same `typeof` (Vector{Float64}), different size. + @test_throws r"Nested array" ne((a=1.0, b=[2.0])) + + # Array-of-arrays: same `typeof` and outer size, mismatched inner size. + ne_nested = AbstractPPL.Evaluators.NamedTupleEvaluator( + x -> sum(sum, x.b), (b=[zeros(2), zeros(2)],) + ) + @test ne_nested((b=[[1.0, 2.0], [3.0, 4.0]],)) == 10.0 + @test_throws r"Nested array" ne_nested((b=[[1.0], [2.0]],)) + + # Unsupported leaf types are rejected rather than silently passing. + ne_string = AbstractPPL.Evaluators.NamedTupleEvaluator(x -> length(x.s), (s="abc",)) + @test_throws r"Supported leaves" ne_string((s="abcde",)) end @testset "prepare (structural)" begin From 2eecfa96eee69c8d661c7d1d63b8a1eab43a8bbe Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 1 May 2026 23:41:05 +0100 Subject: [PATCH 15/17] Switch flatten/unflatten to opt-in array allow-list Replace the `_StructuredArray` reject-list with a `FlattableArray = Union{Array,SubArray}` opt-in. This rejects `Adjoint`/`Transpose` (whose `similar` strips the wrapper and silently breaks the type round-trip) and any other custom `AbstractArray` whose round-trip we cannot guarantee. Also clarify the `check_eltype` warning to mention the `InexactError` that follows when narrowing conversions fail. Co-Authored-By: Claude Opus 4.7 (1M context) --- Project.toml | 2 +- docs/src/evaluators.md | 24 +++++++++------ src/evaluators/Evaluators.jl | 21 ++++++++----- src/evaluators/utils.jl | 57 +++++++++++------------------------ test/evaluators/Evaluators.jl | 19 +++++++++++- test/evaluators/utils.jl | 23 ++++++++------ 6 files changed, 78 insertions(+), 68 deletions(-) diff --git a/Project.toml b/Project.toml index bf497189..3678edc6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.14.2" +version = "0.14.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 37444d9a..4f1e1512 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -59,19 +59,22 @@ prepared([1.0, 2.0, 3.0]) ``` For vector-valued callables, use `value_and_jacobian!!`. The returned Jacobian -has shape `(length(value), length(x))`: +has shape `(length(value), length(x))`. The same backend extension that +defines `value_and_gradient!!` typically also defines `value_and_jacobian!!` +on the same `Prepared` type — they are separate generic functions, so the +two methods coexist without conflict and the caller picks whichever applies +to their function: ```@example ad using AbstractPPL: value_and_jacobian!! -vecfun(x) = [x[1] * x[2], x[2] + x[3]] - function AbstractPPL.value_and_jacobian!!( p::Prepared{<:AutoForwardDiff}, x::AbstractVector{<:Real} ) return (p(x), ForwardDiff.jacobian(p.evaluator.f, x)) end +vecfun(x) = [x[1] * x[2], x[2] + x[3]] prepared_vec = prepare(AutoForwardDiff(), vecfun, zeros(3)) value_and_jacobian!!(prepared_vec, [2.0, 3.0, 4.0]) ``` @@ -81,8 +84,7 @@ value_and_jacobian!!(prepared_vec, [2.0, 3.0, 4.0]) When the callable accepts a `NamedTuple`, pass a sample `NamedTuple` whose field names and value types match the expected input. The prototype's leaves must be `Real`, `Complex`, `AbstractArray` (recursively), `Tuple`, or -`NamedTuple` — the same structural model used by `flatten_to!!` / -`unflatten_to!!`. An extension can define a `prepare` overload that wraps the +`NamedTuple`. An extension can define a `prepare` overload that wraps the function in a `NamedTupleEvaluator`: ```@example ad @@ -139,14 +141,16 @@ would be redundant work in the hot path. ## Without an AD backend The two-argument form `prepare(problem, x)` is available without any AD -package. It returns the callable unchanged by default, so the caller doesn't -need to know whether an AD backend is loaded — the same `prepare(...)` call -works either way, and downstream code that only needs primal evaluation -(e.g. log-density only, no gradient) can accept the result uniformly: +package. By default it wraps `problem` in a `VectorEvaluator{check_dims}` +(or `NamedTupleEvaluator{check_dims}` for the `NamedTuple` form), giving you +a callable that runs the per-call shape check before forwarding to +`problem`. Downstream code that only needs primal evaluation (e.g. +log-density only, no gradient) can call `prepare(...)` uniformly without +knowing whether an AD backend is loaded: ```@example ad sumsimple(x) = sum(x) -p = prepare(sumsimple, zeros(3)) +p = prepare(sumsimple, zeros(3)) # `VectorEvaluator{true}(sumsimple, 3)` p([1.0, 2.0, 3.0]) ``` diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index b24e1e6f..8942e183 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -51,17 +51,22 @@ named inputs, or with a vector when it works with vector inputs. The three-argument form, contributed by AD-backend extensions, additionally prepares gradient or jacobian machinery for vector inputs. -`check_dims` (default `true`) is forwarded to the evaluator constructor by -AD-backend extensions (three-argument form). Pass `check_dims=false` to skip -per-call shape validation, e.g. when the AD backend already guarantees the -input shape. The two-argument stubs ignore this keyword. +`check_dims` (default `true`) controls whether the returned evaluator validates +the input shape on each call. Pass `check_dims=false` to skip the per-call +check, e.g. inside an AD backend's hot path where the input shape is already +guaranteed. """ function prepare end -# Downstream packages (e.g. DynamicPPL) pass already-callable objects, -# so the safe default is to return them unchanged. -prepare(problem, values::NamedTuple; check_dims::Bool=true) = problem -prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) = problem +# Default: wrap the callable in the appropriate evaluator so per-call shape +# checks fire even without a backend-specific `prepare` method. Downstream +# packages (e.g. DynamicPPL) override these for their problem types. +function prepare(problem, values::NamedTuple; check_dims::Bool=true) + return NamedTupleEvaluator{check_dims}(problem, values) +end +function prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) + return VectorEvaluator{check_dims}(problem, length(x)) +end """ value_and_gradient!!(prepared, x::AbstractVector{<:Real}) diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index a287e167..201bf300 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -1,40 +1,18 @@ # Vectorisation utilities -# This utility only supports a small structural subset so flattening stays -# predictable and reconstruction can use `x` as the template. - -using LinearAlgebra: - AbstractTriangular, - Bidiagonal, - Diagonal, - Hermitian, - Symmetric, - SymTridiagonal, - Tridiagonal - -# Structured wrappers from LinearAlgebra have `length(x) > # of independent -# entries`, so a naive round-trip is lossy or fails inside `copyto!`. Reject up -# front with a clear error rather than emitting broken results. Cholesky/LU/QR -# are not <:AbstractArray and already fall through to the catch-all. +# Opt-in: `Array` and `SubArray` are the only `AbstractArray` subtypes that +# round-trip cleanly through `similar` + `copyto!` while preserving structure. +# Structured wrappers (`Symmetric`, `Diagonal`, …), lazy wrappers +# (`Adjoint`/`Transpose`), `OffsetArray`, and custom array types fall through +# to the catch-all rejection — callers must `collect` first. # -# TODO: extend `flatten_to!!` / `unflatten_to!!` with proper support for -# structured arrays (independent-entry packing) and factorisation types -# (Cholesky in particular is needed for PPL covariance parameters). -const _StructuredArray = Union{ - AbstractTriangular,Bidiagonal,Diagonal,Hermitian,Symmetric,SymTridiagonal,Tridiagonal -} - -function _reject_structured(x) - throw( - ArgumentError( - "Structured array `$(typeof(x))` is not supported by the flatten/unflatten utilities; convert to a plain `Array` first.", - ), - ) -end +# TODO: extend with proper support for structured arrays (independent-entry +# packing) and factorisation types (Cholesky in particular is needed for PPL +# covariance parameters). +const FlattableArray{T} = Union{Array{T},SubArray{T}} flat_length(x::Union{Real,Complex}) = 1 -flat_length(x::_StructuredArray) = _reject_structured(x) -flat_length(x::AbstractArray{<:Union{Real,Complex}}) = length(x) +flat_length(x::FlattableArray{<:Union{Real,Complex}}) = length(x) flat_length(::Tuple{}) = 0 flat_length(x::Tuple) = sum(flat_length, x) flat_length(::NamedTuple{(),Tuple{}}) = 0 @@ -42,8 +20,7 @@ flat_length(x::NamedTuple) = sum(flat_length, values(x)) flat_length(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) flat_eltype(x::Union{Real,Complex}) = typeof(x) -flat_eltype(x::_StructuredArray) = _reject_structured(x) -flat_eltype(x::AbstractArray{T}) where {T<:Union{Real,Complex}} = T +flat_eltype(x::FlattableArray{T}) where {T<:Union{Real,Complex}} = T flat_eltype(::Tuple{}) = Float64 flat_eltype(x::Tuple) = mapreduce(flat_eltype, promote_type, x) flat_eltype(::NamedTuple{(),Tuple{}}) = Float64 @@ -58,7 +35,8 @@ Flatten `x` into the vector-like buffer `buf`. Supported `x` values are: - `Real` - `Complex` -- `AbstractArray{<:Union{Real,Complex}}` +- `Array{<:Union{Real,Complex}}` or one-based `SubArray` thereof (other + `AbstractArray` subtypes must be `collect`ed first) - `Tuple` recursively containing supported values - `NamedTuple` recursively containing supported values @@ -86,7 +64,7 @@ function _flatten_to!(buf::AbstractVector, x::Union{Real,Complex}, offset::Int) end function _flatten_to!( - buf::AbstractVector, x::AbstractArray{<:Union{Real,Complex}}, offset::Int + buf::AbstractVector, x::FlattableArray{<:Union{Real,Complex}}, offset::Int ) Base.require_one_based_indexing(x) n = length(x) @@ -117,7 +95,7 @@ function _unflatten(x::Union{Real,Complex}, buf::AbstractVector, offset::Int) end function _unflatten( - x::AbstractArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int + x::FlattableArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int ) Base.require_one_based_indexing(x) n = length(x) @@ -161,7 +139,8 @@ Reconstruct a value from the vector-like buffer `buf` using `x` as the structura Supported `x` values are: - `Real` - `Complex` -- `AbstractArray{<:Union{Real,Complex}}` +- `Array{<:Union{Real,Complex}}` or one-based `SubArray` thereof (other + `AbstractArray` subtypes must be `collect`ed first) - `Tuple` recursively containing supported values - `NamedTuple` recursively containing supported values @@ -188,7 +167,7 @@ function unflatten_to!!(x, buf::AbstractVector; check_eltype::Bool=false) if check_eltype expected = flat_eltype(x) eltype(buf) === expected || @warn( - "Buffer eltype `$(eltype(buf))` differs from `flat_eltype(x) = $expected`; reconstructing using the leaf types from `x`." + "Buffer eltype `$(eltype(buf))` differs from `flat_eltype(x) = $expected`; reconstructing using the leaf types from `x`. An `InexactError` will be thrown if any value in `buf` cannot be converted back to the corresponding leaf type." ) end value, _ = _unflatten(x, buf, 1) diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 6df35b1c..68828a6a 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -1,6 +1,6 @@ using AbstractPPL using AbstractPPL: prepare, value_and_gradient!!, evaluate!! -using AbstractPPL.Evaluators: Prepared, VectorEvaluator +using AbstractPPL.Evaluators: Prepared, VectorEvaluator, NamedTupleEvaluator using ADTypes: ADTypes using Test @@ -89,6 +89,23 @@ end @test lp ≈ 0.5 + 1.5 + 2.5 @test_throws ErrorException prepared((a=1.0, b=2.0)) + + # Generic fallback wraps a plain callable in the appropriate evaluator + # so per-call shape checks fire even without a backend-specific override. + pv = prepare(sum, zeros(3)) + @test pv isa VectorEvaluator{true} + @test pv([1.0, 2.0, 3.0]) == 6.0 + @test_throws DimensionMismatch pv([1.0, 2.0]) + + ntfun = v -> v.a + sum(v.b) + pn = prepare(ntfun, (a=0.0, b=zeros(2))) + @test pn isa NamedTupleEvaluator{true} + @test pn((a=1.0, b=[2.0, 3.0])) == 6.0 + + # check_dims=false propagates to the wrapper. + pv_unchecked = prepare(sum, zeros(3); check_dims=false) + @test pv_unchecked isa VectorEvaluator{false} + @test pv_unchecked([1.0, 2.0]) == 3.0 # wrong length, no error end @testset "prepare (AD-aware)" begin diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index 9021505d..ab572041 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -66,26 +66,31 @@ using Test @test typeof(x3) == typeof(x) end - @testset "non-one-based arrays rejected" begin - oa = OffsetArray([1.0, 2.0, 3.0], 0:2) - @test_throws ArgumentError flatten_to!!(nothing, oa) - @test_throws ArgumentError flatten_to!!(zeros(3), oa) - @test_throws ArgumentError unflatten_to!!(oa, [1.0, 2.0, 3.0]) - # Non-one-based buf is rejected even when `x` is fine. + @testset "non-one-based buf rejected" begin + # Non-one-based `buf` fails `require_one_based_indexing` even when `x` + # is a plain `Array`. (Non-one-based `x` is covered by the catch-all + # in the next testset.) @test_throws ArgumentError flatten_to!!(OffsetArray(zeros(3), 0:2), [1.0, 2.0, 3.0]) @test_throws ArgumentError unflatten_to!!( [1.0, 2.0, 3.0], OffsetArray(zeros(3), 0:2) ) end - @testset "structured arrays rejected" begin + @testset "non-Array/SubArray AbstractArrays rejected" begin + # Only `Array` and `SubArray` are opted in. Structured wrappers, + # `Adjoint`/`Transpose`, and any other `AbstractArray` fall through to + # the catch-all rejection. + M = [1.0 2.0; 3.0 4.0] for x in ( Symmetric([1.0 2.0; 2.0 3.0]), Diagonal([1.0, 2.0]), UpperTriangular([1.0 2.0; 0.0 3.0]), + adjoint(M), + transpose(M), + OffsetArray([1.0, 2.0, 3.0], 0:2), ) - @test_throws r"Structured array" flatten_to!!(nothing, x) - @test_throws r"Structured array" unflatten_to!!(x, [1.0, 2.0, 3.0, 4.0]) + @test_throws r"cannot be flattened" flatten_to!!(nothing, x) + @test_throws r"cannot be flattened" unflatten_to!!(x, [1.0, 2.0, 3.0, 4.0]) end end From 49b20aa44b9f34ade0b7a49659a8327b6187fdb4 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sat, 2 May 2026 10:13:39 +0100 Subject: [PATCH 16/17] Address final round of PR review feedback - src/evaluators/utils.jl: drop SubArray from FlattableArray allowlist; views silently lost their wrapper through similar(::SubArray) so the typeof round-trip invariant did not hold. - src/evaluators/Evaluators.jl: extend NamedTupleEvaluator docstring to mirror VectorEvaluator's CheckInput=false guidance, calling out that the prototype typeof check rejects dual/shadow leaves; narrow the AD-extension error hint to fire only when no extension has registered an AD-aware prepare method; add a brief comment on VectorEvaluator{false} noting the `T <: Integer` branch resolves at compile time. - ext/AbstractPPLLogDensityProblemsExt.jl: drop redundant instance-method capabilities delegations now that we rely on LogDensityProblems' built-in capabilities(x) = capabilities(typeof(x)). - test/evaluators/utils.jl: extend rejection test with a SubArray case; drop obsolete view round-trip edge case. - test/evaluators/Evaluators.jl: replace the internal _assert_namedtuple_shape direct-call asserts with public-call coverage. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLLogDensityProblemsExt.jl | 11 +++-------- src/evaluators/Evaluators.jl | 26 ++++++++++++++++++------- src/evaluators/utils.jl | 22 +++++++++++---------- test/evaluators/Evaluators.jl | 8 ++------ test/evaluators/utils.jl | 14 +++++-------- 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/ext/AbstractPPLLogDensityProblemsExt.jl b/ext/AbstractPPLLogDensityProblemsExt.jl index fd4a2585..d1ba52ef 100644 --- a/ext/AbstractPPLLogDensityProblemsExt.jl +++ b/ext/AbstractPPLLogDensityProblemsExt.jl @@ -16,23 +16,18 @@ function LogDensityProblems.dimension(p::Prepared{<:Any,<:VectorEvaluator}) end LogDensityProblems.dimension(e::VectorEvaluator) = e.dim -# Generic fallback: order 0. AD-backend extensions (DifferentiationInterface, +# Order 0 by default. AD-backend extensions (DifferentiationInterface, # ForwardDiff, Mooncake, etc.) must overload this for their cache type to # advertise `LogDensityOrder{1}` — without that overload, # `logdensity_and_gradient` will hit the `value_and_gradient!!` stub and fail. +# LDP defines `capabilities(x) = capabilities(typeof(x))`, so the type method +# alone covers both call shapes. function LogDensityProblems.capabilities(::Type{<:Prepared{<:Any,<:VectorEvaluator}}) return LogDensityProblems.LogDensityOrder{0}() end -function LogDensityProblems.capabilities(p::Prepared{<:Any,<:VectorEvaluator}) - return LogDensityProblems.capabilities(typeof(p)) -end - function LogDensityProblems.capabilities(::Type{<:VectorEvaluator}) return LogDensityProblems.LogDensityOrder{0}() end -function LogDensityProblems.capabilities(e::VectorEvaluator) - return LogDensityProblems.capabilities(typeof(e)) -end function LogDensityProblems.logdensity_and_gradient(p::Prepared{<:Any,<:VectorEvaluator}, x) val, grad = AbstractPPL.value_and_gradient!!(p, x) diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index 8942e183..fb518ba2 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -133,7 +133,13 @@ This matches the structural model used by [`flatten_to!!`](@ref) / structs) trigger an `ArgumentError` from the per-call shape check. `CheckInput` controls whether each call validates that the input `NamedTuple` -matches the prototype's `typeof` and per-leaf array `size`. +matches the prototype's `typeof` and per-leaf array `size`. The default (`true`) +is the safe shape exposed via `prepared(x)`. Pass `CheckInput=false` (via +`check_dims=false` in `prepare`) for the callable handed to AD libraries: the +prototype's `typeof` is captured at preparation time using the original element +types, so a `CheckInput=true` evaluator will reject inputs whose leaves are +dual/shadow numbers (or any other widened element type) even when the structure +is otherwise correct. """ struct NamedTupleEvaluator{CheckInput,F,P<:NamedTuple} f::F @@ -169,6 +175,8 @@ function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} return e.f(x) end +# `T <: Integer` resolves at compile time; the AD hot path (Float/dual `T`) +# elides the branch entirely. function (e::VectorEvaluator{false})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) return e.f(x) @@ -263,12 +271,16 @@ evaluate!!(e::NamedTupleEvaluator, x) = e(x) function __init__() Base.Experimental.register_error_hint(MethodError) do io, exc, args, kwargs # `args` are argument types, not values (see `Base.Experimental.show_error_hints`). - if exc.f === prepare && length(args) >= 1 && args[1] <: AbstractADType - print( - io, - "\nCalling `prepare` with an AD backend requires loading the corresponding extension (e.g., `using DifferentiationInterface`).", - ) - end + # Only fire when no extension has registered any AD-aware `prepare` method yet — + # once a backend is loaded, the candidate list in the `MethodError` is more + # informative than a generic "load an extension" hint. + exc.f === prepare || return nothing + length(args) >= 1 && args[1] <: AbstractADType || return nothing + any(m -> m.nargs >= 4, methods(prepare)) && return nothing + print( + io, + "\nCalling `prepare` with an AD backend requires loading the corresponding extension (e.g., `using DifferentiationInterface`).", + ) end end diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index 201bf300..aa83d62f 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -1,15 +1,17 @@ # Vectorisation utilities -# Opt-in: `Array` and `SubArray` are the only `AbstractArray` subtypes that -# round-trip cleanly through `similar` + `copyto!` while preserving structure. -# Structured wrappers (`Symmetric`, `Diagonal`, …), lazy wrappers -# (`Adjoint`/`Transpose`), `OffsetArray`, and custom array types fall through -# to the catch-all rejection — callers must `collect` first. +# Opt-in: `Array` is the only `AbstractArray` subtype that round-trips cleanly +# through `similar` + `copyto!` while preserving structure. `SubArray` is +# excluded because `similar(::SubArray)` returns a plain `Array`, so the view +# wrapper is silently dropped on `unflatten_to!!` and `typeof(x2) == typeof(x)` +# would not hold. Structured wrappers (`Symmetric`, `Diagonal`, …), lazy +# wrappers (`Adjoint`/`Transpose`), `OffsetArray`, and custom array types fall +# through to the catch-all rejection — callers must `collect` first. # # TODO: extend with proper support for structured arrays (independent-entry # packing) and factorisation types (Cholesky in particular is needed for PPL # covariance parameters). -const FlattableArray{T} = Union{Array{T},SubArray{T}} +const FlattableArray{T} = Array{T} flat_length(x::Union{Real,Complex}) = 1 flat_length(x::FlattableArray{<:Union{Real,Complex}}) = length(x) @@ -35,8 +37,8 @@ Flatten `x` into the vector-like buffer `buf`. Supported `x` values are: - `Real` - `Complex` -- `Array{<:Union{Real,Complex}}` or one-based `SubArray` thereof (other - `AbstractArray` subtypes must be `collect`ed first) +- `Array{<:Union{Real,Complex}}` (other `AbstractArray` subtypes, including + views, must be `collect`ed first) - `Tuple` recursively containing supported values - `NamedTuple` recursively containing supported values @@ -139,8 +141,8 @@ Reconstruct a value from the vector-like buffer `buf` using `x` as the structura Supported `x` values are: - `Real` - `Complex` -- `Array{<:Union{Real,Complex}}` or one-based `SubArray` thereof (other - `AbstractArray` subtypes must be `collect`ed first) +- `Array{<:Union{Real,Complex}}` (other `AbstractArray` subtypes, including + views, must be `collect`ed first) - `Tuple` recursively containing supported values - `NamedTuple` recursively containing supported values diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 68828a6a..c8454438 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -56,12 +56,8 @@ end ne_unchecked = AbstractPPL.Evaluators.NamedTupleEvaluator{false}( x -> 0.0, (a=0.0, b=zeros(2)) ) - @test AbstractPPL.Evaluators._assert_namedtuple_shape( - ne_unchecked, (totally=:wrong,) - ) === nothing - @test_throws r"same NamedTuple structure" AbstractPPL.Evaluators._assert_namedtuple_shape( - ne, (totally=:wrong,) - ) + @test ne_unchecked((totally=:wrong,)) == 0.0 + @test_throws r"same NamedTuple structure" ne((totally=:wrong,)) # Nested array shape: same `typeof` (Vector{Float64}), different size. @test_throws r"Nested array" ne((a=1.0, b=[2.0])) diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index ab572041..c472d66e 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -76,10 +76,10 @@ using Test ) end - @testset "non-Array/SubArray AbstractArrays rejected" begin - # Only `Array` and `SubArray` are opted in. Structured wrappers, - # `Adjoint`/`Transpose`, and any other `AbstractArray` fall through to - # the catch-all rejection. + @testset "non-Array AbstractArrays rejected" begin + # Only `Array` is opted in. Structured wrappers, `Adjoint`/`Transpose`, + # `SubArray` (views), and any other `AbstractArray` fall through to the + # catch-all rejection. M = [1.0 2.0; 3.0 4.0] for x in ( Symmetric([1.0 2.0; 2.0 3.0]), @@ -88,6 +88,7 @@ using Test adjoint(M), transpose(M), OffsetArray([1.0, 2.0, 3.0], 0:2), + @view([1.0, 2.0, 3.0, 4.0][1:4]), ) @test_throws r"cannot be flattened" flatten_to!!(nothing, x) @test_throws r"cannot be flattened" unflatten_to!!(x, [1.0, 2.0, 3.0, 4.0]) @@ -109,11 +110,6 @@ using Test empty = NamedTuple() @test flatten_to!!(nothing, empty) == Float64[] @test unflatten_to!!(empty, Float64[]) == empty - - view_values = (x=@view([1.0, 2.0, 3.0][2:3]),) - flat = flatten_to!!(nothing, view_values) - rebuilt = unflatten_to!!(view_values, flat) - @test collect(rebuilt.x) == [2.0, 3.0] end @testset "unflatten_to!! type stability" begin From 7ac3e3a8c1d2c1a5ce538e5904cde46901ae5991 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 4 May 2026 18:04:53 +0100 Subject: [PATCH 17/17] Simplify Evaluators module after review - src/evaluators/Evaluators.jl: collapse three trivial evaluate!! overloads into one Union method; remove unused _assert_jacobian_output and _assert_supported_output helpers (will be reintroduced by AD-extension PRs that actually call them); consolidate the integer-rejection compile-time rationale onto the helper; document the m.nargs >= 4 gate in the error hint. - src/evaluators/utils.jl: drop the single-use FlattableArray{T} = Array{T} alias and inline Array{...} at its 4 use sites; tighten the allow-list rationale onto the typeof round-trip contract; fold the heterogeneous round-trip note into the unflatten_to!! docstring and drop the verbose worked-example comment block (covered by the dedicated test). - ext/AbstractPPLLogDensityProblemsExt.jl: add a one-line WHY for the copy(grad) in logdensity_and_gradient (LDP requires a non-aliased gradient). - test/evaluators/utils.jl: trim the seven-element non-Array rejection loop to three representatives (structured wrapper, view, non-one-based) and drop the unused Diagonal/UpperTriangular imports. - test/run_extras.jl: collapse the two-step LABEL validation into a single membership check. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLLogDensityProblemsExt.jl | 1 + src/evaluators/Evaluators.jl | 32 +++----------------- src/evaluators/utils.jl | 40 ++++++++----------------- test/evaluators/utils.jl | 12 ++------ test/run_extras.jl | 3 +- 5 files changed, 21 insertions(+), 67 deletions(-) diff --git a/ext/AbstractPPLLogDensityProblemsExt.jl b/ext/AbstractPPLLogDensityProblemsExt.jl index d1ba52ef..8f00ebb7 100644 --- a/ext/AbstractPPLLogDensityProblemsExt.jl +++ b/ext/AbstractPPLLogDensityProblemsExt.jl @@ -31,6 +31,7 @@ end function LogDensityProblems.logdensity_and_gradient(p::Prepared{<:Any,<:VectorEvaluator}, x) val, grad = AbstractPPL.value_and_gradient!!(p, x) + # `value_and_gradient!!` may alias internal storage; LDP requires a stable result. return (val, copy(grad)) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index fb518ba2..ff2d5407 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -155,8 +155,8 @@ end NamedTupleEvaluator(f, inputspec::NamedTuple) = NamedTupleEvaluator{true}(f, inputspec) # Reject integer vectors with a clear error rather than letting them flow into -# AD backends (which usually fail confusingly). The `T <: Integer` branch is -# resolved at compile time, so non-integer inputs pay nothing. +# AD backends (which usually fail confusingly). `T <: Integer` resolves at +# compile time, so the AD hot path (Float/dual `T`) elides the branch entirely. function _reject_integer_input(x) throw( ArgumentError( @@ -175,8 +175,6 @@ function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} return e.f(x) end -# `T <: Integer` resolves at compile time; the AD hot path (Float/dual `T`) -# elides the branch entirely. function (e::VectorEvaluator{false})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) return e.f(x) @@ -242,31 +240,8 @@ function _shapes_match(a, _) ) end -# Output-shape assertions for AD-backend extensions to share. Centralised here -# so each backend's `value_and_gradient!!` / `value_and_jacobian!!` produces -# the same error message rather than rolling its own. -function _assert_jacobian_output(y) - y isa AbstractVector || throw( - ArgumentError( - "`value_and_jacobian!!` requires the prepared function to return an AbstractVector; got $(typeof(y)).", - ), - ) - return nothing -end - -function _assert_supported_output(y) - (y isa Number || y isa AbstractVector) || throw( - ArgumentError( - "A prepared AD evaluator must return a scalar or AbstractVector; got $(typeof(y)).", - ), - ) - return nothing -end - # Make prepared evaluators usable through the same `evaluate!!` API as models. -evaluate!!(p::Prepared, x) = p(x) -evaluate!!(e::VectorEvaluator, x) = e(x) -evaluate!!(e::NamedTupleEvaluator, x) = e(x) +evaluate!!(e::Union{Prepared,VectorEvaluator,NamedTupleEvaluator}, x) = e(x) function __init__() Base.Experimental.register_error_hint(MethodError) do io, exc, args, kwargs @@ -276,6 +251,7 @@ function __init__() # informative than a generic "load an extension" hint. exc.f === prepare || return nothing length(args) >= 1 && args[1] <: AbstractADType || return nothing + # `nargs` counts `self`, so `>= 4` matches the AD-aware 3-positional form. any(m -> m.nargs >= 4, methods(prepare)) && return nothing print( io, diff --git a/src/evaluators/utils.jl b/src/evaluators/utils.jl index aa83d62f..5e97d16c 100644 --- a/src/evaluators/utils.jl +++ b/src/evaluators/utils.jl @@ -1,20 +1,17 @@ # Vectorisation utilities -# Opt-in: `Array` is the only `AbstractArray` subtype that round-trips cleanly -# through `similar` + `copyto!` while preserving structure. `SubArray` is -# excluded because `similar(::SubArray)` returns a plain `Array`, so the view -# wrapper is silently dropped on `unflatten_to!!` and `typeof(x2) == typeof(x)` -# would not hold. Structured wrappers (`Symmetric`, `Diagonal`, …), lazy -# wrappers (`Adjoint`/`Transpose`), `OffsetArray`, and custom array types fall -# through to the catch-all rejection — callers must `collect` first. +# Opt-in: only `Array` round-trips cleanly through `similar` + `copyto!` +# preserving `typeof`. `SubArray` is excluded because `similar(::SubArray)` +# returns a plain `Array`, silently breaking the typeof round-trip contract +# advertised by `unflatten_to!!`. Structured/lazy wrappers and `OffsetArray` +# fall through to the catch-all — callers must `collect` first. # # TODO: extend with proper support for structured arrays (independent-entry # packing) and factorisation types (Cholesky in particular is needed for PPL # covariance parameters). -const FlattableArray{T} = Array{T} flat_length(x::Union{Real,Complex}) = 1 -flat_length(x::FlattableArray{<:Union{Real,Complex}}) = length(x) +flat_length(x::Array{<:Union{Real,Complex}}) = length(x) flat_length(::Tuple{}) = 0 flat_length(x::Tuple) = sum(flat_length, x) flat_length(::NamedTuple{(),Tuple{}}) = 0 @@ -22,7 +19,7 @@ flat_length(x::NamedTuple) = sum(flat_length, values(x)) flat_length(x) = throw(ArgumentError("This value cannot be flattened into a vector.")) flat_eltype(x::Union{Real,Complex}) = typeof(x) -flat_eltype(x::FlattableArray{T}) where {T<:Union{Real,Complex}} = T +flat_eltype(x::Array{T}) where {T<:Union{Real,Complex}} = T flat_eltype(::Tuple{}) = Float64 flat_eltype(x::Tuple) = mapreduce(flat_eltype, promote_type, x) flat_eltype(::NamedTuple{(),Tuple{}}) = Float64 @@ -65,9 +62,7 @@ function _flatten_to!(buf::AbstractVector, x::Union{Real,Complex}, offset::Int) return offset + 1 end -function _flatten_to!( - buf::AbstractVector, x::FlattableArray{<:Union{Real,Complex}}, offset::Int -) +function _flatten_to!(buf::AbstractVector, x::Array{<:Union{Real,Complex}}, offset::Int) Base.require_one_based_indexing(x) n = length(x) copyto!(buf, offset, x, 1, n) @@ -96,9 +91,7 @@ function _unflatten(x::Union{Real,Complex}, buf::AbstractVector, offset::Int) return convert(typeof(x), buf[offset]), offset + 1 end -function _unflatten( - x::FlattableArray{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int -) +function _unflatten(x::Array{<:Union{Real,Complex}}, buf::AbstractVector, offset::Int) Base.require_one_based_indexing(x) n = length(x) value = similar(x) @@ -148,18 +141,11 @@ Supported `x` values are: Pass `check_eltype=true` to emit a warning when `eltype(buf)` differs from `flat_eltype(x)` (off by default to keep hot paths quiet). + +Leaves are rebuilt using `x`'s types, so `typeof(unflatten_to!!(x, buf)) == typeof(x)` +even when `buf`'s element type is widened (e.g. real `x` flattened into a `ComplexF64` +buffer). Always allocates: each array leaf goes through `similar`. """ -# Always allocates: `_unflatten` calls `similar` for each array field. Gains from -# buffer reuse are negligible relative to gradient computation cost. -# -# Heterogeneous round-trip: the flat buffer widens, but leaves are rebuilt -# from `x`'s types, so `typeof(x2) == typeof(x)`. E.g. -# -# x = (1.0, [2.0, 3.0], (4.0 + 1.0im,)) # buffer widens to ComplexF64 -# x2 = unflatten_to!!(x, flatten_to!!(nothing, x)) -# # x2 == (1.0, [2.0, 3.0], (4.0 + 1.0im,)) -# # x2 == x → true -# # typeof(x2) == typeof(x) → true function unflatten_to!!(x, buf::AbstractVector; check_eltype::Bool=false) Base.require_one_based_indexing(buf) n = flat_length(x) diff --git a/test/evaluators/utils.jl b/test/evaluators/utils.jl index c472d66e..fb7f8571 100644 --- a/test/evaluators/utils.jl +++ b/test/evaluators/utils.jl @@ -1,6 +1,6 @@ using AbstractPPL using AbstractPPL.Evaluators: flatten_to!!, unflatten_to!! -using LinearAlgebra: Diagonal, Symmetric, UpperTriangular +using LinearAlgebra: Symmetric using OffsetArrays: OffsetArray using Test @@ -77,18 +77,10 @@ using Test end @testset "non-Array AbstractArrays rejected" begin - # Only `Array` is opted in. Structured wrappers, `Adjoint`/`Transpose`, - # `SubArray` (views), and any other `AbstractArray` fall through to the - # catch-all rejection. - M = [1.0 2.0; 3.0 4.0] for x in ( Symmetric([1.0 2.0; 2.0 3.0]), - Diagonal([1.0, 2.0]), - UpperTriangular([1.0 2.0; 0.0 3.0]), - adjoint(M), - transpose(M), - OffsetArray([1.0, 2.0, 3.0], 0:2), @view([1.0, 2.0, 3.0, 4.0][1:4]), + OffsetArray([1.0, 2.0, 3.0], 0:2), ) @test_throws r"cannot be flattened" flatten_to!!(nothing, x) @test_throws r"cannot be flattened" unflatten_to!!(x, [1.0, 2.0, 3.0, 4.0]) diff --git a/test/run_extras.jl b/test/run_extras.jl index 4b40e320..0e71bbcd 100644 --- a/test/run_extras.jl +++ b/test/run_extras.jl @@ -6,8 +6,7 @@ const VALID_LABELS = ("ext/logdensityproblems",) label = get(ENV, "LABEL", nothing) -label === nothing && error("Set LABEL to one of: $(join(VALID_LABELS, ", "))") label in VALID_LABELS || - error("Unknown LABEL=$label. Valid options: $(join(VALID_LABELS, ", "))") + error("Set LABEL to one of: $(join(VALID_LABELS, ", ")) (got `$label`).") include(joinpath(@__DIR__, label, "main.jl"))