diff --git a/HISTORY.md b/HISTORY.md index 12112c61..b42d692c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,17 @@ +## 0.15.3 + +Added the `of` type system: a self-contained, declarative way to specify the shape, element type, and support of model variables. Construct specifications with the exported `of` function or the exported `@of` macro: + +```julia +using AbstractPPL +s = of(Real, 0, 1) # a real in [0, 1] +RegressionParams = @of( + y = of(Array, Float64, 100), beta = of(Array, Float64, 3), sigma = of(Real, 0, nothing), +) +``` + +`of`-types support `rand`/`zero` (drawing or zeroing a value of the declared shape) and `size`/`length` (querying the declared shape), and can be `flatten`ed to / `unflatten`ed from a flat numeric vector. `rand` accepts an optional `AbstractRNG` for reproducible draws. `flatten` returns a vector whose element type is the promotion of the declared leaf types, and `unflatten` is automatic-differentiation transparent (float leaves take `promote_type(declared, eltype(flat))`, so `ForwardDiff.Dual` numbers flow through); both are type-stable. These are narrow methods dispatched on AbstractPPL-owned `OfType` subtypes. The supporting types (`OfType` and its subtypes) and the inspection/`flatten`/`unflatten` helpers are marked `public` for downstream use without being exported. + ## 0.15.2 Added `AbstractPPLForwardDiffExt`, a direct ForwardDiff path for `AutoForwardDiff` (gradient, Jacobian, Hessian, `context`, chunk size, custom `tag`). diff --git a/Project.toml b/Project.toml index 57c13443..f35fb986 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probabilistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.15.2" +version = "0.15.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/make.jl b/docs/make.jl index ead3820f..492c4d09 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,6 @@ using Documenter using AbstractPPL +using Random # for the `Base.rand(::Random.AbstractRNG, ...)` signature in of.md's @docs block # trigger DistributionsExt loading using Distributions, LinearAlgebra @@ -9,7 +10,7 @@ DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive= makedocs(; sitename="AbstractPPL", modules=[AbstractPPL, Base.get_extension(AbstractPPL, :AbstractPPLDistributionsExt)], - pages=["index.md", "varname.md", "pplapi.md", "evaluators.md", "interface.md"], + pages=["index.md", "varname.md", "of.md", "pplapi.md", "evaluators.md", "interface.md"], checkdocs=:exports, doctest=false, ) diff --git a/docs/src/of.md b/docs/src/of.md new file mode 100644 index 00000000..3cb79712 --- /dev/null +++ b/docs/src/of.md @@ -0,0 +1,315 @@ +# The `of` Type System + +## Overview + +The `of` type system provides a declarative way to specify parameter **types** for +probabilistic programming. It is a lightweight, framework-agnostic type-annotation +system that: + + - Returns schema types (not instances) for downstream annotation systems + - Encodes specifications (dimensions, bounds) in type parameters + - Provides utilities for parameter manipulation (`rand`, `zero`, `flatten`, `unflatten`) + +It lives in AbstractPPL so that downstream packages can share a common vocabulary for +describing the shape, element type, and support of model variables. JuliaBUGS, for +example, uses it for `@model` parameter annotations. + +The examples on this page are executed when the documentation is built. The imports are +brought into scope here; later examples reuse them. + +```@setup of +using AbstractPPL +using AbstractPPL: flatten, unflatten +using Random +``` + +```@example of +using AbstractPPL +using AbstractPPL: flatten, unflatten +using Random +nothing # hide +``` + +## Core Concepts + +### 1. Type-Based Design + +The `of` function returns types with specifications encoded in type parameters: + + - `of(Array, dims...)` → `OfArray{Float64, N, (dim1, dim2, ...)}` - Arrays with specified dimensions + - `of(Array, T, dims...)` → `OfArray{T, N, (dim1, dim2, ...)}` - Typed numeric arrays (`T <: Number`) + - `of(Float64)` → `OfReal{Float64, Nothing, Nothing}` - Unbounded 64-bit floating point numbers + - `of(Float32)` → `OfReal{Float32, Nothing, Nothing}` - Unbounded 32-bit floating point numbers + - `of(Float64, lower, upper)` → `OfReal{Float64, lower, upper}` - Bounded 64-bit floats + - `of(Float32, lower, upper)` → `OfReal{Float32, lower, upper}` - Bounded 32-bit floats + - `of(Real)` → `OfReal{Float64, Nothing, Nothing}` - Unbounded real numbers (defaults to Float64) + - `of(Real, lower, upper)` → `OfReal{Float64, lower, upper}` - Bounded real numbers (defaults to Float64) + - `of(Int)` → `OfInt{Nothing, Nothing}` - Unbounded integers + - `of(Int, lower, upper)` → `OfInt{lower, upper}` - Bounded integers + - `@of(field1=..., field2=...)` → `OfNamedTuple{(:field1, :field2), Tuple{Type1, Type2}}` - Named tuples (use `@of` macro) + - `of(...; constant=true)` → `OfConstantWrapper{T}` - Marks a type as constant/hyperparameter (supported for float types and `Int`) + +A few `of(...)` calls and the concrete types they return: + +```@example of +of(Float64, 0, 1) +``` + +```@example of +of(Array, 3, 4) +``` + +```@example of +of(Int; constant=true) +``` + +### 2. Type Parameter Encoding + +The system encodes extra useful information into type parameters: + + - **Dimensions**: Stored as tuple type parameters (e.g., `(3, 4)` for a 3×4 matrix) + - **Bounds**: Numeric literals stored directly as type parameters (e.g., `0.0`, `1.0`), or `Nothing` for unbounded + - **Symbolic references**: Encoded using `SymbolicRef{:symbol}` for referencing earlier constant fields + - **Arithmetic expressions**: Encoded using `SymbolicExpr{expr}` for expressions like `n+1`, `2*n`, etc. Division operations must result in integers for array dimensions. + - **Field names**: Stored as a tuple of symbols in `OfNamedTuple` + - **Element types**: Preserved as type parameters for numeric arrays and nested structures + +### 3. Operations on Types + + - `T(; kwargs...)` where `T<:OfNamedTuple` — Create instances with specified constants (returns values, not types). Uses `zero()` as the default for missing values. + + - `T(default_value; kwargs...)` where `T<:OfNamedTuple` — Create instances with specified constants and initialise all element values to `default_value`, e.g. `T(missing; kwargs...)` initialises all element values to `missing`. `T(...)` returns instances, not types. + - `of(T; kwargs...)` where `T<:OfType` — Create concrete types by resolving constants + - `rand([rng], T::Type{<:OfType})` — Generate random values matching the type specification (pass an `AbstractRNG` for reproducible draws) + - `zero(T::Type{<:OfType})` — Generate zero/default values + - `size(T::Type{<:OfType})` — Get the dimensions/shape of the type + - `length(T::Type{<:OfType})` — Get the total number of elements when flattened + - `flatten(T::Type{<:OfType}, values)` — Convert structured values to a flat vector (element type is the promotion of the declared leaf types) + - `unflatten(T::Type{<:OfType}, vec)` — Reconstruct structured values from a flat vector (float leaves take `promote_type(declared, eltype(vec))`, so AD numbers flow through) + - `unflatten(T::Type{<:OfType}, missing)` — Create instances where element values are initialised to `missing` + +Only `of` and `@of` are exported. `flatten`, `unflatten`, the `OfType` subtypes, and the +inspection helpers are `public` but not exported, so qualify them (`AbstractPPL.flatten`) or +bring them into scope with `using AbstractPPL: flatten, unflatten`. + +### 4. The `@of` Macro + +The `@of` macro provides cleaner syntax by automatically converting references to earlier +constant fields to symbols. Here `n` in the array dimension is automatically converted to +the symbol `:n`: + +```@example of +T = @of( + n = of(Int; constant=true), + data = of(Array, n, 2) # 'n' is automatically converted to :n +) +``` + +### 5. Symbolic Dimensions and Bounds + +For cases where dimensions need to be specified at runtime, declare the dimensions as +constants and reference them in the array specifications: + +```@example of +MatrixType = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + data = of(Array, rows, cols), +) +``` + +Resolving the constants with `of(MatrixType; ...)` produces a concrete type with the +symbolic dimensions filled in: + +```@example of +ConcreteType = of(MatrixType; rows=3, cols=4) +``` + +The concrete type works with [`rand`](@ref) and [`zero`](@ref). The draw uses a seeded RNG +so the rendered output is reproducible: + +```@example of +rand(MersenneTwister(0), ConcreteType) # random 3×4 matrix wrapped in a NamedTuple +``` + +```@example of +zero(ConcreteType) # zero 3×4 matrix wrapped in a NamedTuple +``` + +Concretization can be partial. Resolving only `rows` leaves `cols` symbolic +(semiconcretized): + +```@example of +SemiConcreteType = of(MatrixType; rows=3) +``` + +Calling the type as a constructor builds an instance. With all constants provided, the +non-constant `data` field defaults to zeros: + +```@example of +MatrixType(; rows=3, cols=4) +``` + +Passing `missing` initialises element values to `missing`: + +```@example of +MatrixType(missing; rows=3, cols=4) +``` + +Specific data can be provided directly for non-constant fields: + +```@example of +MatrixType(; rows=3, cols=4, data=ones(3, 4)) +``` + +A concrete type can be flattened and reconstructed. Here we flatten a `3×4` instance and +recover it (`flatten`/`unflatten` are public, not exported): + +```@example of +instance = MatrixType(; rows=3, cols=4) +flat = flatten(ConcreteType, instance) +``` + +```@example of +reconstructed = unflatten(ConcreteType, flat) +``` + +`rand` and `zero` also work directly on a concretized type: + +```@example of +rand(MersenneTwister(0), of(MatrixType; rows=3, cols=4)) # random instance +``` + +```@example of +zero(of(MatrixType; rows=10, cols=5)) # zero instance +``` + +Operations that still need unresolved information error. Constructing with a missing +constant throws, so we catch and display the message: + +```@example of +try + MatrixType(; rows=3) # `cols` is required but not provided +catch err + showerror(stdout, err) +end +``` + +Likewise, drawing from a type with unresolved symbolic dimensions throws: + +```@example of +try + rand(MatrixType) # symbolic dimensions are unresolved +catch err + showerror(stdout, err) +end +``` + +#### Arithmetic expressions in dimensions + +Dimensions may be arithmetic expressions of constant fields. Division operations must +result in integers for array dimensions: + +```@example of +ExpandedMatrixType = @of( + n = of(Int; constant=true), + original = of(Array, n, n), + padded = of(Array, n + 1, n + 1), + doubled = of(Array, 2 * n, n), + halved = of(Array, n / 2, n), +) +``` + +Creating an instance with `n=10` evaluates each expression: `original` is `10×10`, +`padded` is `11×11`, `doubled` is `20×10`, and `halved` is `5×10`. Non-constant fields +default to zero. We display each field's shape: + +```@example of +instance = ExpandedMatrixType(; n=10) +map(size, instance) +``` + +A custom default value fills every matrix instead of using zeros: + +```@example of +instance = ExpandedMatrixType(1.0; n=10) +instance.original +``` + +If a division does not yield an integer dimension, instantiation throws. With `n=9`, +`n / 2 = 4.5` is not an integer: + +```@example of +try + ExpandedMatrixType(; n=9) # n / 2 = 4.5 is not an integer +catch err + showerror(stdout, err) +end +``` + +## Flattening parameters + +`flatten`/`unflatten` are useful for code that needs a flat parameter vector (for +example, an optimiser or a sampler) while keeping a structured view of the parameters. +We define a small parameter specification: + +```@example of +Params = @of(mu = of(Real), sigma = of(Real, 0, nothing), beta = of(Array, Float64, 3)) +``` + +The total flattened length is `length(Params)`: + +```@example of +length(Params) +``` + +Flattening a structured value produces a flat vector: + +```@example of +values = (mu=0.5, sigma=1.2, beta=[0.1, 0.2, 0.3]) +flat = flatten(Params, values) +``` + +`unflatten` reconstructs the original `(mu, sigma, beta)` NamedTuple: + +```@example of +reconstructed = unflatten(Params, flat) +``` + +`flatten` returns a vector whose element type is the promotion of the declared leaf types, +and `unflatten` is automatic-differentiation transparent: floating-point leaves take +`promote_type(declared, eltype(flat))`, so `ForwardDiff.Dual` (or `BigFloat`, …) numbers in +the flat vector flow through to the reconstructed structure. This makes the pair suitable for +gradient-based samplers and optimisers that differentiate through `unflatten`. + +Constants (fields wrapped with `constant=true`) are excluded from the flattened +representation and must be resolved with `of(T; kwargs...)` before flattening. + +## Use in models + +Because `of` returns schema types, downstream packages can use those types in their own +annotation systems. JuliaBUGS, for instance, accepts an `of` type as the parameter +annotation of a `@model`'s argument destructuring, e.g. `(; mu, beta, sigma)::ParamsType`. +These schema types are not supertypes of raw values, so `1.0 isa of(Float64)` is false; +see the downstream package documentation for the modelling integration. + +## API Reference + +```@docs +of +@of +AbstractPPL.flatten +AbstractPPL.unflatten +Base.rand(::Random.AbstractRNG, ::Type{<:AbstractPPL.OfType}) +Base.zero +Base.size +Base.length +AbstractPPL.OfType +AbstractPPL.OfReal +AbstractPPL.OfInt +AbstractPPL.OfArray +AbstractPPL.OfNamedTuple +AbstractPPL.OfConstantWrapper +AbstractPPL.SymbolicRef +AbstractPPL.SymbolicExpr +``` diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 78f0748f..3b00f186 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -95,4 +95,15 @@ export AbstractOptic, using Accessors: set export set +include("of.jl") +export of, @of +@static if VERSION >= v"1.11.0" + eval( + Meta.parse( + "public OfType, OfReal, OfInt, OfArray, OfNamedTuple, OfConstantWrapper, " * + "flatten, unflatten", + ), + ) +end + end # module diff --git a/src/of.jl b/src/of.jl new file mode 100644 index 00000000..e52ede34 --- /dev/null +++ b/src/of.jl @@ -0,0 +1,2266 @@ +using Random: AbstractRNG, default_rng, randexp + +""" + OfType + +Abstract base type for all types in the `of` type system. + +The `of` type system provides a declarative way to specify parameter types for +probabilistic programming. All `of` types encode their specifications (dimensions, +bounds, etc.) in type parameters so downstream libraries can use them as schema +types in their own annotation systems. + +# Subtypes +- `OfReal{T,Lower,Upper}`: Bounded or unbounded floating-point numbers +- `OfInt{Lower,Upper}`: Bounded or unbounded integers +- `OfArray{T,N,Dims}`: Arrays with specified element type and dimensions +- `OfNamedTuple{Names,Types}`: Named tuples with typed fields +- `OfConstantWrapper{T}`: Wrapper marking a type as constant/hyperparameter + +# See also +[`of`](@ref), [`@of`](@ref) +""" +abstract type OfType end + +""" + SymbolicRef{S} + +Wrapper type for symbolic references in bounds and dimensions. + +Used internally to encode references to earlier constant fields when specifying bounds or dimensions. For example, when using `@of(n=of(Int; constant=true), +data=of(Array, n, 2))`, the reference to `n` in the array dimension is encoded as +`SymbolicRef{:n}`. + +# Type Parameters +- `S`: The symbol being referenced + +# See also +[`@of`](@ref), [`of`](@ref) +""" +struct SymbolicRef{S} end + +""" + SymbolicExpr{E} + +Wrapper type for symbolic expressions in dimensions. + +Used internally to encode arithmetic expressions involving earlier constant fields. For example, +when using `@of(n=of(Int; constant=true), padded=of(Array, n+1, n+1))`, the expression +`n+1` is encoded as `SymbolicExpr{(:+, :n, 1)}`. + +Supported operations: `+`, `-`, `*`, `/`. Division operations must result in integers +when used for array dimensions. + +# Type Parameters +- `E`: A tuple representing the expression in prefix notation + +# See also +[`@of`](@ref), [`of`](@ref) +""" +struct SymbolicExpr{E} end + +""" + OfReal{T<:AbstractFloat,Lower,Upper} + +Type specification for bounded or unbounded floating-point numbers. + +# Type Parameters +- `T<:AbstractFloat`: The concrete floating-point type (e.g., `Float64`, `Float32`) +- `Lower`: Lower bound (numeric value, `Nothing` for unbounded, or `SymbolicRef`) +- `Upper`: Upper bound (numeric value, `Nothing` for unbounded, or `SymbolicRef`) + +# Examples +```julia +of(Float64) # OfReal{Float64, Nothing, Nothing} +of(Float32, 0.0, 1.0) # OfReal{Float32, 0.0, 1.0} +of(Real, 0, nothing) # OfReal{Float64, 0, Nothing} (defaults to Float64) +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +struct OfReal{T<:AbstractFloat,Lower,Upper} <: OfType + function OfReal{T,L,U}() where {T<:AbstractFloat,L,U} + return error( + "OfReal is a type specification, not an instantiable object. Use of(Float64, ...) or of(Float32, ...) to create the type.", + ) + end +end + +""" + OfInt{Lower,Upper} + +Type specification for bounded or unbounded integers. + +# Type Parameters +- `Lower`: Lower bound (integer value, `Nothing` for unbounded, or `SymbolicRef`) +- `Upper`: Upper bound (integer value, `Nothing` for unbounded, or `SymbolicRef`) + +# Examples +```julia +of(Int) # OfInt{Nothing, Nothing} +of(Int, 1, 10) # OfInt{1, 10} +of(Int, 0, nothing) # OfInt{0, Nothing} +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +struct OfInt{Lower,Upper} <: OfType + function OfInt{L,U}() where {L,U} + return error( + "OfInt is a type specification, not an instantiable object. Use of(Int, ...) to create the type.", + ) + end +end + +""" + OfArray{T,N,Dims} + +Type specification for arrays with fixed element type and dimensions. + +# Type Parameters +- `T`: Element type of the array +- `N`: Number of dimensions +- `Dims`: Tuple type encoding the size of each dimension (can include `SymbolicRef` or `SymbolicExpr`) + +# Examples +```julia +of(Array, 3, 4) # OfArray{Float64, 2, (3, 4)} +of(Array, Float32, 10) # OfArray{Float32, 1, (10,)} +@of(n=of(Int; constant=true), data=of(Array, n, 2)) # Symbolic dimension +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +struct OfArray{T,N,Dims} <: OfType + function OfArray{T,N,D}() where {T,N,D} + return error( + "OfArray is a type specification, not an instantiable object. Use of(Array, ...) to create the type.", + ) + end +end + +""" + OfNamedTuple{Names,Types<:Tuple} + +Type specification for named tuples with typed fields. + +# Type Parameters +- `Names`: Tuple of field names as symbols +- `Types<:Tuple`: Tuple of field types (each must be an `OfType`) + +# Examples +```julia +@of(mu=of(Real), tau=of(Real, 0, nothing)) +of((a=of(Int), b=of(Array, 3, 3))) +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +struct OfNamedTuple{Names,Types<:Tuple} <: OfType + function OfNamedTuple{Names,Types}() where {Names,Types} + return error( + "OfNamedTuple is a type specification, not an instantiable object. Use of(...) to create the type.", + ) + end +end + +""" + OfConstantWrapper{T<:OfType} + +Wrapper type marking a field as a constant/hyperparameter. + +Constants are not included in flattened representations and must be provided +when creating instances or concretizing types with symbolic dimensions. + +# Type Parameters +- `T<:OfType`: The wrapped type specification + +# Examples +```julia +of(Int; constant=true) # OfConstantWrapper{OfInt{Nothing, Nothing}} +of(Real; constant=true) # OfConstantWrapper{OfReal{Float64, Nothing, Nothing}} +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +struct OfConstantWrapper{T<:OfType} <: OfType + function OfConstantWrapper{T}() where {T<:OfType} + return error( + "OfConstantWrapper is a type specification, not an instantiable object. Use of(...; constant=true) to create the type.", + ) + end +end + +get_lower(::Type{OfReal{T,L,U}}) where {T,L,U} = L +get_upper(::Type{OfReal{T,L,U}}) where {T,L,U} = U +get_element_type(::Type{OfReal{T,L,U}}) where {T,L,U} = T +get_lower(::Type{OfInt{L,U}}) where {L,U} = L +get_upper(::Type{OfInt{L,U}}) where {L,U} = U +get_element_type(::Type{OfArray{T,N,D}}) where {T,N,D} = T +get_ndims(::Type{OfArray{T,N,D}}) where {T,N,D} = N +function get_dims(::Type{OfArray{T,N,D}}) where {T,N,D} + return D isa DataType && D <: Tuple ? tuple(D.parameters...) : D +end +get_names(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} = Names +get_types(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} = Types +get_wrapped_type(::Type{OfConstantWrapper{T}}) where {T} = T + +# A dimension is symbolic when it is a bare field symbol or a `SymbolicExpr` parameter. +_is_symbolic_dim(d) = d isa Symbol || (d isa Type && d <: SymbolicExpr) +# A bound is symbolic when it is a `SymbolicRef` or `SymbolicExpr` parameter. +_is_symbolic_bound(b) = b isa Type && (b <: SymbolicRef || b <: SymbolicExpr) + +_is_symbolic_expr_tuple(x) = x isa Tuple && !isempty(x) && x[1] in (:+, :-, :*, :/) + +function _assert_valid_dimension(d) + if _is_symbolic_dim(d) + return nothing + end + d isa Integer || error("Array dimension $d must be an integer or symbolic reference.") + d >= 0 || error("Array dimension $d must be nonnegative.") + return nothing +end + +function _assert_valid_dimensions(dims) + foreach(_assert_valid_dimension, dims) + return nothing +end + +function _assert_valid_bound_spec(lower, upper; kind::AbstractString="bounds") + if !_is_symbolic_bound(lower) && !_is_symbolic_bound(upper) + lo = type_to_bound(lower) + hi = type_to_bound(upper) + if !isnothing(lo) && !isnothing(hi) && lo > hi + error("Invalid $kind: lower bound $lo is greater than upper bound $hi.") + end + end + return nothing +end + +function _normalise_int_bound(bound) + if bound === Nothing || _is_symbolic_bound(bound) + return bound + end + value = type_to_bound(bound) + if value isa Integer + return Int(value) + elseif value isa Real && isinteger(value) + return Int(value) + else + error("Int bound $value must be an integer or symbolic reference.") + end +end + +# Fail fast (with a clear message) before a symbolic bound reaches arithmetic/comparison. +function _assert_concrete_bounds(::Type{T}) where {T<:OfType} + if _is_symbolic_bound(get_lower(T)) || _is_symbolic_bound(get_upper(T)) + error( + "Cannot instantiate a type with symbolic bounds. Resolve them with of(T; name=value) first.", + ) + end + return nothing +end + +is_leaf(::Type{<:OfArray}) = true +is_leaf(::Type{<:OfReal}) = true +is_leaf(::Type{<:OfInt}) = true +is_leaf(::Type{<:OfNamedTuple}) = false +is_leaf(::Type{<:OfConstantWrapper}) = true + +bound_to_type(::Nothing) = Nothing +bound_to_type(x::Real) = x +bound_to_type(s::Symbol) = SymbolicRef{s} +bound_to_type(s::QuoteNode) = SymbolicRef{s.value} + +type_to_bound(::Type{Nothing}) = nothing +type_to_bound(::Type{x}) where {x<:Real} = x +type_to_bound(::Type{SymbolicRef{S}}) where {S} = S +type_to_bound(s::Symbol) = s +type_to_bound(x::Real) = x + +function _eval_symbolic_op(op, args) + if op == :+ + return sum(args) + elseif op == :- + length(args) in (1, 2) || error("Subtraction requires 1 or 2 arguments") + return length(args) == 1 ? -args[1] : args[1] - args[2] + elseif op == :* + return prod(args) + elseif op == :/ + length(args) == 2 || error("Division requires exactly 2 arguments") + result = args[1] / args[2] + # Array dimensions must be integers; guard the result itself, not just the dividend. + isinteger(result) || error( + "Division $(args[1]) / $(args[2]) = $result is not an integer. Array dimensions must be integers.", + ) + return Int(result) + end +end + +function _check_symbolic_expr_header(expr::Tuple) + if length(expr) < 2 + error("Invalid expression format: $expr") + end + + op = expr[1] + if !(op in (:+, :-, :*, :/)) + error("Unsupported operation: $op. Only +, -, *, / are supported.") + end + return op +end + +function eval_symbolic_expr(expr::Tuple, bindings::NamedTuple) + op = _check_symbolic_expr_header(expr) + + args = map(expr[2:end]) do arg + if arg isa Symbol + if haskey(bindings, arg) + bindings[arg] + else + error("Symbol '$arg' not found in bindings") + end + elseif arg isa Tuple + eval_symbolic_expr(arg, bindings) + else + arg + end + end + + return _eval_symbolic_op(op, args) +end + +function substitute_symbolic_expr(expr::Tuple, bindings::NamedTuple) + op = _check_symbolic_expr_header(expr) + unresolved = false + + args = map(expr[2:end]) do arg + if arg isa Symbol + if haskey(bindings, arg) + bindings[arg] + else + unresolved = true + arg + end + elseif _is_symbolic_expr_tuple(arg) + substituted = substitute_symbolic_expr(arg, bindings) + if _is_symbolic_expr_tuple(substituted) + unresolved = true + end + substituted + else + arg + end + end + + return unresolved ? Tuple((op, args...)) : _eval_symbolic_op(op, args) +end + +# Resolve bound references during type concretization +function resolve_bound(::Type{Nothing}, replacements::NamedTuple) + return Nothing +end + +function resolve_bound(::Type{x}, replacements::NamedTuple) where {x<:Real} + return x +end + +function resolve_bound(::Type{SymbolicRef{S}}, replacements::NamedTuple) where {S} + if haskey(replacements, S) + return bound_to_type(replacements[S]) + else + return SymbolicRef{S} + end +end + +function resolve_bound(::Type{SymbolicExpr{E}}, replacements::NamedTuple) where {E} + evaluated = substitute_symbolic_expr(E, replacements) + return if _is_symbolic_expr_tuple(evaluated) + SymbolicExpr{evaluated} + else + bound_to_type(evaluated) + end +end + +function resolve_bound(T::Type, ::NamedTuple) + return T +end + +function resolve_bound(x::Real, ::NamedTuple) + return x +end + +function process_array_dimensions(dims) + # Unwrap quoted field references; keep everything else (including SymbolicExpr{...}) as-is. + processed_dims = map(d -> d isa QuoteNode ? d.value : d, dims) + _assert_valid_dimensions(processed_dims) + if length(processed_dims) == 1 + Tuple{processed_dims[1]} + else + Tuple{processed_dims...} + end +end + +function process_bounds(lower, upper) + L = if lower isa Type + lower + else + bound_to_type(lower) + end + U = if upper isa Type + upper + else + bound_to_type(upper) + end + _assert_valid_bound_spec(L, U) + return L, U +end + +""" + of(T, args...; constant::Bool=false) + +Create an `OfType` specification from various inputs. + +# Main Methods + +## Arrays +```julia +of(Array, dims...) # Float64 array with given dimensions +of(Array, T, dims...) # Array with numeric element type T and given dimensions +``` + +## Real Numbers +```julia +of(Float64) # Unbounded Float64 +of(Float64, lower, upper) # Bounded Float64 +of(Float32) # Unbounded Float32 +of(Float32, lower, upper) # Bounded Float32 +of(Real) # Unbounded Real (defaults to Float64) +of(Real, lower, upper) # Bounded Real (defaults to Float64) +``` + +## Integers +```julia +of(Int) # Unbounded integer +of(Int, lower, upper) # Bounded integer +``` + +## Named Tuples +```julia +of((;field1=spec1, field2=spec2, ...)) # NamedTuple with typed fields +``` + +## From Values (Type Inference) +```julia +of(1.0) # Infers of(Float64) +of([1, 2, 3]) # Infers of(Array, Int, 3) +of((a=1, b=2.0)) # Infers OfNamedTuple +``` + +# Arguments +- `T`: Type to create specification for +- `args...`: Type-specific arguments (bounds, dimensions, etc.) +- `constant`: Mark type as constant/hyperparameter (default: false) + +# Returns +An `OfType` subtype encoding the specification in its type parameters. + +# Examples +```julia +# Basic types +T1 = of(Float64, 0, 1) # OfReal{Float64, 0, 1} +T2 = of(Array, 3, 4) # OfArray{Float64, 2, (3, 4)} +T3 = of(Int; constant=true) # OfConstantWrapper{OfInt{Nothing, Nothing}} + +# With @of macro for cleaner syntax +T4 = @of( + n = of(Int; constant=true), + data = of(Array, n, 2) # Symbolic dimension +) + +# Type concretization +T5 = of(T4; n=10) # Concrete type with n=10 +``` + +# See also +[`@of`](@ref), [`OfType`](@ref) +""" +function of(::Type{Array}, dims...; constant::Bool=false) + if constant + error("constant=true is only supported for Int and Real types, not Array") + end + # Default to Float64 for unspecified array types + dims_tuple = process_array_dimensions(dims) + return OfArray{Float64,length(dims),dims_tuple} +end + +function of(::Type{Array}, T::Type, dims...; constant::Bool=false) + # Check if T is a symbolic expression type (which should be treated as a dimension) + if T <: SymbolicExpr + # This is actually a dimension, not an element type + # Construct the array type directly with Float64 as element type + if constant + error("constant=true is only supported for Int and Real types, not Array") + end + all_dims = (T, dims...) + dims_tuple = process_array_dimensions(all_dims) + return OfArray{Float64,length(all_dims),dims_tuple} + end + + if constant + error("constant=true is only supported for Int and Real types, not Array") + end + T <: Number || error("Array element type must be a subtype of Number, got $T") + dims_tuple = process_array_dimensions(dims) + return OfArray{T,length(dims),dims_tuple} +end + +function of(::Type{Int}; constant::Bool=false) + base_type = OfInt{Nothing,Nothing} + return constant ? OfConstantWrapper{base_type} : base_type +end + +function of( + ::Type{Int}, + lower::Union{Int,Nothing,Symbol,Type}, + upper::Union{Int,Nothing,Symbol,Type}; + constant::Bool=false, +) + L, U = process_bounds(lower, upper) + base_type = OfInt{L,U} + return constant ? OfConstantWrapper{base_type} : base_type +end + +function of(::Type{T}; constant::Bool=false) where {T<:AbstractFloat} + base_type = OfReal{T,Nothing,Nothing} + return constant ? OfConstantWrapper{base_type} : base_type +end + +function of( + ::Type{T}, + lower::Union{Real,Nothing,Symbol,Type}, + upper::Union{Real,Nothing,Symbol,Type}; + constant::Bool=false, +) where {T<:AbstractFloat} + L, U = process_bounds(lower, upper) + base_type = OfReal{T,L,U} + return constant ? OfConstantWrapper{base_type} : base_type +end + +# of(Real) creates Float64 +function of(::Type{Real}; constant::Bool=false) + base_type = OfReal{Float64,Nothing,Nothing} + return constant ? OfConstantWrapper{base_type} : base_type +end + +function of( + ::Type{Real}, + lower::Union{Real,Nothing,Symbol,Type}, + upper::Union{Real,Nothing,Symbol,Type}; + constant::Bool=false, +) + L, U = process_bounds(lower, upper) + base_type = OfReal{Float64,L,U} + return constant ? OfConstantWrapper{base_type} : base_type +end + +# Infer OfType from concrete values +function of(value::T) where {T<:AbstractFloat} + return of(T) +end + +function of(value::Integer) + return of(Int) +end + +# Fallback for other Real types +function of(value::Real) + return of(Float64) +end + +function of(value::AbstractArray{T,N}) where {T,N} + return of(Array, T, size(value)...) +end + +function of(value::NamedTuple{names}) where {names} + # Check if all values are already OfType types + vals = values(value) + if all(v -> v isa Type && v <: OfType, vals) + # This is a NamedTuple of types, not values + return OfNamedTuple{names,Tuple{vals...}} + else + # This is a NamedTuple of values, infer types + of_types = map(of, vals) + return OfNamedTuple{names,Tuple{of_types...}} + end +end + +function resolve_bounded_type(::Type{T}, replacements::NamedTuple) where {T<:OfType} + if !(T <: OfReal || T <: OfInt) + return T + end + + lower = get_lower(T) + upper = get_upper(T) + new_lower = resolve_bound(lower, replacements) + new_upper = resolve_bound(upper, replacements) + + if new_lower !== lower || new_upper !== upper + if T <: OfReal + _assert_valid_bound_spec(new_lower, new_upper) + elem_type = get_element_type(T) + return OfReal{elem_type,new_lower,new_upper} + elseif T <: OfInt + new_lower = _normalise_int_bound(new_lower) + new_upper = _normalise_int_bound(new_upper) + _assert_valid_bound_spec(new_lower, new_upper) + return OfInt{new_lower,new_upper} + end + else + return T + end +end + +""" + of(::Type{T}, replacements::NamedTuple) where T<:OfType + of(::Type{T}; kwargs...) where T<:OfType + of(::Type{T}, pairs::Pair{Symbol}...) where T<:OfType + +Create a concrete type by resolving symbolic dimensions and removing constants. + +This function takes an `OfType` with symbolic dimensions or constants and creates +a new type with some or all symbols resolved to concrete values. Constants that +are provided are removed from the resulting type. + +# Arguments +- `T<:OfType`: The type to concretize +- `replacements`: Named tuple or keyword arguments mapping symbols to values + +# Returns +A new `OfType` with symbols replaced and constants removed. + +# Examples +```julia +# Define type with symbolic dimensions +T = @of( + n = of(Int; constant=true), + data = of(Array, n, 2) +) + +# Create concrete type +ConcreteT = of(T; n=10) # @of(data=of(Array, 10, 2)) + +# Partial concretization +T2 = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + matrix = of(Array, rows, cols) +) +Partial = of(T2; rows=5) # @of(cols=of(Int; constant=true), matrix=of(Array, 5, :cols)) +``` + +# See also +[`of`](@ref), [`@of`](@ref) +""" +function of(::Type{T}, pairs::Pair{Symbol}...) where {T<:OfType} + return of(T, NamedTuple(pairs)) +end + +function of(::Type{T}; kwargs...) where {T<:OfType} + return of(T, NamedTuple(kwargs)) +end + +function _resolve_dimensions(dims, replacements::NamedTuple) + return map(dims) do d + if d isa Symbol && haskey(replacements, d) + resolved = replacements[d] + _assert_valid_dimension(resolved) + resolved + elseif d isa Type && d <: SymbolicExpr + expr = d.parameters[1] + resolved = substitute_symbolic_expr(expr, replacements) + if _is_symbolic_expr_tuple(resolved) + SymbolicExpr{resolved} + else + _assert_valid_dimension(resolved) + resolved + end + else + d + end + end +end + +function _normalise_constant_replacements( + ::Type{OfNamedTuple{Names,Types}}, replacements::NamedTuple +) where {Names,Types} + normalised = (;) + for i in 1:length(Names) + name = Names[i] + field_type = Types.parameters[i] + if field_type <: OfConstantWrapper && haskey(replacements, name) + wrapped = get_wrapped_type(field_type) + resolved = of(wrapped, normalised) + has_symbolic_dims(resolved) && error( + "Cannot validate constant `$name` while its type still has unresolved symbols.", + ) + value = _validate_leaf(resolved, replacements[name]) + normalised = merge(normalised, NamedTuple{(name,)}((value,))) + end + end + return merge(replacements, normalised) +end + +function _process_field_for_concretization( + name::Symbol, field_type::Type, replacements::NamedTuple +) + # Case 1: Constant field that has been resolved and validated - skip it + if field_type <: OfConstantWrapper && haskey(replacements, name) + return nothing + end + + # Case 2: Array with potentially symbolic dimensions + if field_type <: OfArray + dims = get_dims(field_type) + new_dims = _resolve_dimensions(dims, replacements) + + if new_dims != dims + T = get_element_type(field_type) + return (name, of(Array, T, new_dims...)) + else + return (name, field_type) + end + end + + # Case 3: Nested NamedTuple - recursively concretize + if field_type <: OfNamedTuple + nested = _concretize_namedtuple(field_type, replacements; allow_empty=true) + return isnothing(nested) ? nothing : (name, nested) + end + + # Case 4: Bounded types (Real/Int) with potentially symbolic bounds + if field_type <: OfReal || field_type <: OfInt + return (name, resolve_bounded_type(field_type, replacements)) + end + + # Case 5: Constant wrapper without replacement - resolve wrapped type + if field_type <: OfConstantWrapper + wrapped = get_wrapped_type(field_type) + resolved = resolve_bounded_type(wrapped, replacements) + + if resolved !== wrapped + return (name, OfConstantWrapper{resolved}) + else + return (name, field_type) + end + end + + # Case 6: Other types - pass through unchanged + return (name, field_type) +end + +function _concretize_namedtuple( + ::Type{OfNamedTuple{Names,Types}}, replacements::NamedTuple; allow_empty::Bool=false +) where {Names,Types} + replacements = _normalise_constant_replacements(OfNamedTuple{Names,Types}, replacements) + processed_fields = [] + + for i in 1:length(Names) + name = Names[i] + field_type = Types.parameters[i] + + result = _process_field_for_concretization(name, field_type, replacements) + + if !isnothing(result) + push!(processed_fields, result) + end + end + + # Check if any fields remain + if isempty(processed_fields) + allow_empty && return nothing + error("All fields were constants and have been resolved. No fields remain.") + end + + # Extract names and types from processed fields + remaining_names = [field[1] for field in processed_fields] + remaining_types = [field[2] for field in processed_fields] + + return OfNamedTuple{Tuple(remaining_names),Tuple{remaining_types...}} +end + +function of(::Type{OfNamedTuple{Names,Types}}, replacements::NamedTuple) where {Names,Types} + return _concretize_namedtuple(OfNamedTuple{Names,Types}, replacements) +end + +function of(::Type{OfArray{T,N,D}}, replacements::NamedTuple) where {T,N,D} + # Replace symbolic dimensions in array types + dims = get_dims(OfArray{T,N,D}) + new_dims = _resolve_dimensions(dims, replacements) + return of(Array, T, new_dims...) +end + +function of(::Type{OfReal{T,L,U}}, replacements::NamedTuple) where {T,L,U} + return resolve_bounded_type(OfReal{T,L,U}, replacements) +end + +function of(::Type{OfInt{L,U}}, replacements::NamedTuple) where {L,U} + return resolve_bounded_type(OfInt{L,U}, replacements) +end + +function of(::Type{OfConstantWrapper{T}}, replacements::NamedTuple) where {T} + resolved = of(T, replacements) + return resolved === T ? OfConstantWrapper{T} : OfConstantWrapper{resolved} +end + +function of(::Type{T}, replacements::NamedTuple) where {T<:OfType} + return T +end + +function _create_with_default(::Type{OfArray{T,N,D}}, default_value) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + if any(d -> d isa Symbol || (d isa Type && d <: SymbolicExpr), dims) + error( + "Cannot create array with symbolic dimensions. Use T(default_value; kwargs...) with dimension values.", + ) + end + # Handle missing values specially - create array of Union{T,Missing} + if default_value === missing + return fill(missing, dims) + else + return fill(convert(T, default_value), dims) + end +end + +function _create_with_default(::Type{OfReal{T,L,U}}, default_value) where {T,L,U} + if default_value === missing + return missing + end + val = convert(T, default_value) + lower = type_to_bound(L) + upper = type_to_bound(U) + validate_bounds(val, lower, upper; kind="Real value") + return val +end + +function _create_with_default(::Type{OfInt{L,U}}, default_value) where {L,U} + if default_value === missing + return missing + end + # Use round for converting floats to ints + val = if isa(default_value, Integer) + convert(Int, default_value) + else + round(Int, default_value) + end + lower = type_to_bound(L) + upper = type_to_bound(U) + validate_bounds(val, lower, upper; kind="Int value") + return val +end + +function _create_with_default( + ::Type{OfNamedTuple{Names,Types}}, default_value +) where {Names,Types} + values = Tuple( + _create_with_default(Types.parameters[i], default_value) for i in 1:length(Names) + ) + return NamedTuple{Names}(values) +end + +function _create_with_default(::Type{OfConstantWrapper{T}}, default_value) where {T} + return error( + "Cannot create values for constants. Provide the constant value in T(; const_name=value).", + ) +end + +function _create_instance_impl( + ::Type{T}, value_generator::Function, kwargs +) where {T<:OfNamedTuple} + names = get_names(T) + types = get_types(T) + + unknown = setdiff(collect(keys(kwargs)), collect(names)) + isempty(unknown) || error("Unknown field(s): $(join(unknown, ", "))") + + constants = Dict{Symbol,Any}() + values = Dict{Symbol,Any}() + + for (key, val) in pairs(kwargs) + idx = findfirst(==(key), names) + if idx !== nothing && types.parameters[idx] <: OfConstantWrapper + constants[key] = val + else + values[key] = val + end + end + + for (idx, name) in enumerate(names) + if types.parameters[idx] <: OfConstantWrapper && !haskey(constants, name) + error("Constant `$name` is required but not provided") + end + end + + # First concretize with constants + concrete_type = of(T, NamedTuple(constants)) + + if has_symbolic_dims(concrete_type) + missing_symbols = get_unresolved_symbols(concrete_type) + error("Missing values for symbolic dimensions: $(join(missing_symbols, ", "))") + end + + # Get the names and types from the concrete type (constants removed) + concrete_names = get_names(concrete_type) + concrete_types = get_types(concrete_type) + + # Build the result with provided values or defaults + result_values = Any[] + for (idx, name) in enumerate(concrete_names) + field_type = concrete_types.parameters[idx] + + if haskey(values, name) + try + push!(result_values, _validate(field_type, values[name])) + catch e + error("Validation failed for field $name: $(sprint(showerror, e))") + end + else + push!(result_values, value_generator(field_type)) + end + end + + return NamedTuple{concrete_names}(Tuple(result_values)) +end + +""" + (::Type{T})(; kwargs...) where {T<:OfNamedTuple} + (::Type{T})(default_value; kwargs...) where {T<:OfNamedTuple} + +Create an instance of an `OfNamedTuple` type with specified constant values. + +This constructor creates actual values (instances) from an `OfNamedTuple` specification. +It requires all constants to be provided and initializes non-constant fields +to zero or a specified default value. Only `OfNamedTuple` types are constructible this +way; the leaf types (`OfReal`, `OfInt`, …) are specifications, not instantiable objects. + +# Arguments +- `T<:OfNamedTuple`: The named-tuple specification to instantiate +- `default_value`: Optional value to initialize all non-constant fields (default: appropriate zero) +- `kwargs...`: Constant values and optionally non-constant field values + +# Returns +A `NamedTuple` instance with all fields initialized. + +# Examples +```julia +# Define type with constants +T = @of( + n = of(Int; constant=true), + mu = of(Real), + data = of(Array, n, 2) +) + +# Create instance with constants (non-constants default to zero) +instance = T(; n=5) +# Returns: (mu = 0.0, data = zeros(5, 2)) + +# Create instance with custom default +instance = T(1.0; n=5) +# Returns: (mu = 1.0, data = ones(5, 2)) + +# Create instance with missing values +instance = T(missing; n=5) +# Returns: (mu = missing, data = 5×2 Array{Missing}) + +# Provide specific values for non-constants +instance = T(; n=5, mu=2.5, data=rand(5, 2)) +# Returns: (mu = 2.5, data = ) +``` + +# Errors +- Throws error if required constants are not provided + +# See also +[`of`](@ref), [`@of`](@ref) +""" +function (::Type{T})(; kwargs...) where {T<:OfNamedTuple} + return _create_instance_impl(T, zero, kwargs) +end + +function (::Type{T})(default_value; kwargs...) where {T<:OfNamedTuple} + return _create_instance_impl( + T, field_type -> _create_with_default(field_type, default_value), kwargs + ) +end + +""" + @of(field1=spec1, field2=spec2, ...) + +Create an `OfNamedTuple` type with cleaner syntax for field references. + +The `@of` macro provides a more intuitive syntax for creating named tuple types +where fields can reference each other. Field names used in dimensions or bounds +are automatically converted to symbolic references. + +# Syntax +```julia +@of( + field_name = of_specification, + ... +) +``` + +# Features +- Direct references to earlier constant fields without quoting (e.g., `n` instead of `:n`) +- Support for arithmetic expressions in dimensions (e.g., `n+1`, `2*n`) +- Automatic conversion to appropriate `OfNamedTuple` type +- Fields are processed in order, allowing later fields to reference earlier constants + +# Examples +```julia +# Basic usage with constants and arrays +T = @of( + n = of(Int; constant=true), + mu = of(Real), + data = of(Array, n, 2) # 'n' automatically converted to symbolic reference +) + +# With arithmetic expressions +T = @of( + n = of(Int; constant=true), + original = of(Array, n, n), + padded = of(Array, n+1, n+1), + doubled = of(Array, 2*n, n) +) + +# Nested structures: field references resolve within the @of that declares them, so keep +# a dimension's constants in the same block (cross-level paths like `dims.rows` are not +# supported), then concretize the whole tree at once. +Inner = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + matrix = of(Array, rows, cols) +) +T = @of(block = Inner) +CT = of(T; rows=3, cols=4) +``` + +# See also +[`of`](@ref), [`OfNamedTuple`](@ref) +""" +macro of(args...) + # Parse the arguments to extract field specifications + fields = Dict{Symbol,Any}() + field_order = Symbol[] + + for arg in args + if !(arg isa Expr && arg.head == :(=) && length(arg.args) == 2) + error("@of expects keyword arguments like field=spec") + end + + field_name = arg.args[1] + field_spec = arg.args[2] + + if field_name isa Symbol + fields[field_name] = field_spec + push!(field_order, field_name) + else + error("Field name must be a symbol, got $(field_name)") + end + end + + processed_fields = Dict{Symbol,Any}() + available_constants = Symbol[] + + for field_name in field_order + spec = fields[field_name] + processed_spec = process_of_spec(spec, available_constants, field_order) + processed_fields[field_name] = processed_spec + if is_constant_spec(spec) + push!(available_constants, field_name) + end + end + + nt_expr = Expr(:tuple) + for field_name in field_order + push!(nt_expr.args, Expr(:(=), field_name, processed_fields[field_name])) + end + + return esc(:($(GlobalRef(@__MODULE__, :of))($nt_expr))) +end + +function is_constant_spec(spec) + spec isa Expr || return false + spec.head == :call || return false + func = spec.args[1] + is_of_call = func === :of || (func isa GlobalRef && func.name === :of) + is_of_call || return false + for arg in spec.args[2:end] + if arg isa Expr && arg.head == :parameters + any(_is_constant_kw, arg.args) && return true + elseif _is_constant_kw(arg) + return true + end + end + return false +end + +function _is_constant_kw(arg) + return arg isa Expr && + arg.head == :kw && + arg.args[1] === :constant && + arg.args[2] === true +end + +# Process an of specification, converting field references to symbols +function process_of_spec( + spec::Expr, available_constants::Vector{Symbol}, all_fields::Vector{Symbol} +) + if spec.head == :call && length(spec.args) >= 1 + func = spec.args[1] + + # Check if this is an of(...) call + if func === :of || (func isa GlobalRef && func.name === :of) + # Process the arguments + new_args = Any[GlobalRef(@__MODULE__, :of)] + + # Separate positional and keyword arguments + pos_args = [] + kw_args = [] + + for arg in spec.args[2:end] + if arg isa Expr && arg.head == :parameters + # Handle parameters block (e.g., f(x; a=1, b=2)) + for param in arg.args + push!(kw_args, param) + end + elseif arg isa Expr && arg.head == :kw + # Handle individual keyword argument + push!(kw_args, arg) + else + push!(pos_args, arg) + end + end + + # Process positional arguments + for arg in pos_args + processed_arg = process_dimension_arg(arg, available_constants, all_fields) + push!(new_args, processed_arg) + end + + if !isempty(kw_args) + params_expr = Expr(:parameters, kw_args...) + insert!(new_args, 2, params_expr) + end + + return Expr(:call, new_args...) + else + # Not an of call, process recursively + return Expr( + spec.head, + [ + process_of_spec(arg, available_constants, all_fields) for + arg in spec.args + ]..., + ) + end + else + return spec + end +end + +process_of_spec(x, ::Vector{Symbol}, ::Vector{Symbol}) = x + +# Process a dimension/bound argument, converting field references to symbols +function process_dimension_arg( + arg, available_constants::Vector{Symbol}, all_fields::Vector{Symbol} +) + if arg isa Symbol && arg in available_constants + # Convert field reference to symbol + return QuoteNode(arg) + elseif arg isa Symbol && arg in all_fields + error( + "Field reference `$arg` is not available here. Symbolic dimensions and bounds may only refer to earlier fields declared with constant=true.", + ) + elseif arg isa QuoteNode && arg.value isa Symbol && arg.value in available_constants + return arg + elseif arg isa QuoteNode && arg.value isa Symbol && arg.value in all_fields + error( + "Field reference `$(arg.value)` is not available here. Symbolic dimensions and bounds may only refer to earlier fields declared with constant=true.", + ) + elseif arg isa Expr + return process_expression_refs(arg, available_constants, all_fields) + else + return arg + end +end + +function _reject_field_refs_in_unsupported_expr(x, all_fields::Vector{Symbol}) + if x isa Symbol && x in all_fields + error( + "Field reference `$x` appears in an unsupported expression. Symbolic references may only be bare symbols or use +, -, *, / arithmetic.", + ) + elseif x isa QuoteNode && x.value isa Symbol && x.value in all_fields + error( + "Field reference `$(x.value)` appears in an unsupported expression. Symbolic references may only be bare symbols or use +, -, *, / arithmetic.", + ) + elseif x isa Expr + for arg in x.args + _reject_field_refs_in_unsupported_expr(arg, all_fields) + end + end + return nothing +end + +# Check if a processed expression is a `SymbolicExpr{...}` type. The head may be the bare +# symbol or the `GlobalRef` the macro emits for hygiene, so accept both. +function _is_symbolic_expr_type(expr::Expr) + (expr.head == :curly && length(expr.args) >= 2) || return false + head = expr.args[1] + return head === :SymbolicExpr || (head isa GlobalRef && head.name === :SymbolicExpr) +end + +# Process a single argument in an arithmetic expression +function _process_arithmetic_arg( + arg, available_constants::Vector{Symbol}, all_fields::Vector{Symbol} +) + if arg isa Symbol && arg in available_constants + return (QuoteNode(arg), true) + elseif arg isa Symbol && arg in all_fields + error( + "Field reference `$arg` is not available here. Symbolic dimensions and bounds may only refer to earlier fields declared with constant=true.", + ) + elseif arg isa Expr + processed = process_expression_refs(arg, available_constants, all_fields) + if _is_symbolic_expr_type(processed) + # Extract the tuple from SymbolicExpr{...} + return (processed.args[2], true) + else + return (processed, false) + end + else + return (arg, false) + end +end + +# Process an arithmetic expression, converting field references to symbols +function _process_arithmetic_expr( + expr::Expr, available_constants::Vector{Symbol}, all_fields::Vector{Symbol} +) + op = expr.args[1] + + # Build tuple representation: (op, arg1, arg2, ...) + tuple_args = Any[QuoteNode(op)] + has_field_ref = false + + for arg in expr.args[2:end] + processed_arg, has_ref = _process_arithmetic_arg( + arg, available_constants, all_fields + ) + push!(tuple_args, processed_arg) + has_field_ref |= has_ref + end + + if has_field_ref + # Emit a fully-qualified `SymbolicExpr` so the escaped `@of` expansion resolves + # even when the caller only does `using AbstractPPL` (the name is `public`, not exported). + tuple_expr = Expr(:tuple, tuple_args...) + return :($(GlobalRef(@__MODULE__, :SymbolicExpr)){$tuple_expr}) + else + return expr + end +end + +# Process an expression, converting field references to symbols in expressions +function process_expression_refs( + expr::Expr, available_constants::Vector{Symbol}, all_fields::Vector{Symbol} +) + # Check if this is an arithmetic call expression + if expr.head == :call && length(expr.args) >= 2 + op = expr.args[1] + if op in [:+, :-, :*, :/] + return _process_arithmetic_expr(expr, available_constants, all_fields) + end + end + + _reject_field_refs_in_unsupported_expr(expr, all_fields) + return expr +end + +""" + rand([rng::AbstractRNG], ::Type{T}) where T<:OfType + +Generate random values matching the type specification. + +Creates random instances that satisfy the constraints encoded in the `OfType`. +Arrays are filled with random values, named tuples recurse over their fields, and +bounded scalars respect their bounds. Pass an `rng` for reproducible draws; the +method without one uses `Random.default_rng()`. + +The bounded/unbounded distributions are deliberate but unspecified conveniences: + +- `OfReal`: uniform on `[lower, upper]`; for a single bound, a (reflected) shifted + exponential; standard normal when unbounded. +- `OfInt`: uniform on `lower:upper`; for a single bound, an arbitrary 100-wide window + anchored at it; `-100:100` when unbounded. +- `OfArray`: each element drawn as for its element type. +- `OfNamedTuple`: each field drawn recursively. +- `OfConstantWrapper`: not supported (throws). + +# Examples +```julia +using Random + +rand(of(Float64, 0, 1)) # a Float64 in [0, 1] +rand(of(Array, 3, 4)) # a 3×4 Matrix{Float64} +rand(StableRNG(1), @of(x=of(Real, 0, 1))) # reproducible NamedTuple draw + +T = @of(n=of(Int; constant=true), data=of(Array, n, 2)) +rand(of(T; n=5)) # (data = ,) +``` + +# Errors +- Throws for types with unresolved symbolic dimensions or bounds. +- Throws for constant wrapper types. + +# See also +[`zero`](@ref), [`of`](@ref) +""" +Base.rand(::AbstractRNG, ::Type{<:OfType}) + +# Each type gets an explicit-RNG method (the contract downstream samplers rely on); the +# convenience method without an RNG forwards to `Random.default_rng()`. +function Base.rand(rng::AbstractRNG, ::Type{OfArray{T,N,D}}) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + any(_is_symbolic_dim, dims) && error( + "Cannot generate random array with symbolic dimensions. Resolve them with of(T; name=value) first.", + ) + return rand(rng, T, dims) +end +Base.rand(::Type{OfArray{T,N,D}}) where {T,N,D} = rand(default_rng(), OfArray{T,N,D}) + +function Base.rand(rng::AbstractRNG, ::Type{OfReal{T,L,U}}) where {T,L,U} + _assert_concrete_bounds(OfReal{T,L,U}) + lower = type_to_bound(L) + upper = type_to_bound(U) + if !isnothing(lower) && !isnothing(upper) + return T(lower + rand(rng) * (upper - lower)) + elseif !isnothing(lower) + # Lower bound only: draw from a shifted exponential on [lower, ∞). + return T(lower + randexp(rng)) + elseif !isnothing(upper) + # Upper bound only: draw from a reflected shifted exponential on (-∞, upper]. + return T(upper - randexp(rng)) + else + return T(randn(rng)) + end +end +Base.rand(::Type{OfReal{T,L,U}}) where {T,L,U} = rand(default_rng(), OfReal{T,L,U}) + +function Base.rand(rng::AbstractRNG, ::Type{OfInt{L,U}}) where {L,U} + _assert_concrete_bounds(OfInt{L,U}) + lower = type_to_bound(L) + upper = type_to_bound(U) + if !isnothing(lower) && !isnothing(upper) + return rand(rng, lower:upper) + elseif !isnothing(lower) + # Lower bound only: an arbitrary but reasonable [lower, lower+100] window. + return rand(rng, lower:(lower + 100)) + elseif !isnothing(upper) + return rand(rng, (upper - 100):upper) + else + return rand(rng, -100:100) + end +end +Base.rand(::Type{OfInt{L,U}}) where {L,U} = rand(default_rng(), OfInt{L,U}) + +# `@generated` so the per-field draws unroll and the NamedTuple eltypes stay inferable. +@generated function Base.rand( + rng::AbstractRNG, ::Type{OfNamedTuple{Names,Types}} +) where {Names,Types} + draws = [:(rand(rng, $P)) for P in Types.parameters] + return :(NamedTuple{Names}(($(draws...),))) +end +function Base.rand(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + return rand(default_rng(), OfNamedTuple{Names,Types}) +end + +function Base.rand(::AbstractRNG, ::Type{OfConstantWrapper{T}}) where {T} + return error( + "Cannot generate random values for constants. Use rand(of(T; const_name=value)) after providing the constant value.", + ) +end +function Base.rand(::Type{OfConstantWrapper{T}}) where {T} + return rand(default_rng(), OfConstantWrapper{T}) +end + +""" + zero(::Type{T}) where T<:OfType + +Generate zero/default values matching the type specification. + +Creates instances initialized to appropriate zero values that satisfy the +constraints. For bounded types where zero is outside the bounds, returns +the nearest bound value. + +# Behavior by Type +- `OfReal`: Returns 0.0 if within bounds, otherwise nearest bound +- `OfInt`: Returns 0 if within bounds, otherwise nearest bound +- `OfArray`: Returns array filled with zeros +- `OfNamedTuple`: Recursively generates zero values for all fields +- `OfConstantWrapper`: Not supported (throws error) + +# Examples +```julia +# Unbounded types +zero(of(Float64)) # 0.0 +zero(of(Int)) # 0 +zero(of(Array, 3, 2)) # 3×2 matrix of zeros + +# Bounded types respect bounds +zero(of(Real, 1.0, 2.0)) # 1.0 (lower bound since 0 is outside) +zero(of(Int, -10, -5)) # -5 (upper bound since 0 is outside) + +# Named tuples +T = @of(x=of(Real), y=of(Array, 2, 2)) +zero(T) # (x=0.0, y=[0.0 0.0; 0.0 0.0]) + +# With resolved constants +T = @of(n=of(Int; constant=true), data=of(Array, n, n)) +zero(of(T; n=3)) # (data = 3×3 zero matrix) +``` + +# Errors +- Throws error for types with unresolved symbolic dimensions +- Throws error for constant wrapper types + +# See also +[`rand`](@ref), [`of`](@ref) +""" +function Base.zero(::Type{OfArray{T,N,D}}) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + any(_is_symbolic_dim, dims) && error( + "Cannot create zero array with symbolic dimensions. Resolve them with of(T; name=value) first.", + ) + return zeros(T, dims) +end + +function Base.zero(::Type{OfReal{T,L,U}}) where {T,L,U} + _assert_concrete_bounds(OfReal{T,L,U}) + lower = type_to_bound(L) + upper = type_to_bound(U) + if !isnothing(lower) && lower > 0 + return T(lower) + elseif !isnothing(upper) && upper < 0 + return T(upper) + else + return zero(T) + end +end + +function Base.zero(::Type{OfInt{L,U}}) where {L,U} + _assert_concrete_bounds(OfInt{L,U}) + lower = type_to_bound(L) + upper = type_to_bound(U) + if !isnothing(lower) && lower > 0 + return lower + elseif !isnothing(upper) && upper < 0 + return upper + else + return 0 + end +end + +@generated function Base.zero(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + zeros_ = [:(zero($P)) for P in Types.parameters] + return :(NamedTuple{Names}(($(zeros_...),))) +end + +function Base.zero(::Type{OfConstantWrapper{T}}) where {T} + return error( + "Cannot generate zero values for constants. Use zero(of(T; const_name=value)) after providing the constant value.", + ) +end + +""" + size(::Type{T}) where T<:OfType + +Get the dimensions/shape of an `OfType` specification. + +Returns the size information encoded in the type. For arrays, returns a tuple +of dimensions. For scalars, returns an empty tuple. For named tuples, returns +a named tuple with the size of each field. + +# Returns +- `OfArray`: Tuple of dimensions +- `OfReal`, `OfInt`: Empty tuple `()` +- `OfNamedTuple`: Named tuple with sizes of each field +- `OfConstantWrapper`: Delegates to wrapped type + +# Examples +```julia +size(of(Array, 3, 4)) # (3, 4) +size(of(Float64)) # () +size(of(Int, 0, 10)) # () + +T = @of(x=of(Real), y=of(Array, 2, 3)) +size(T) # (x=(), y=(2, 3)) +``` + +# Errors +- Throws error for arrays with unresolved symbolic dimensions + +# See also +[`length`](@ref), [`of`](@ref) +""" +function Base.size(::Type{OfArray{T,N,D}}) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + any(_is_symbolic_dim, dims) && + error("Cannot get size of array with symbolic dimensions.") + return dims +end + +Base.size(::Type{OfReal{T,L,U}}) where {T,L,U} = () +Base.size(::Type{OfInt{L,U}}) where {L,U} = () + +function Base.size(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + dims = ntuple(i -> size(Types.parameters[i]), length(Names)) + return NamedTuple{Names}(dims) +end + +function Base.size(::Type{OfConstantWrapper{T}}) where {T} + return size(T) +end + +""" + length(::Type{T}) where T<:OfType + +Get the total number of elements when the type is flattened. + +Returns the total count of numerical values that would be in a flattened +representation. Arrays contribute their total element count, scalars +contribute 1, and named tuples sum the lengths of all fields. Constants +(wrapped in `OfConstantWrapper`) contribute 0 as they are not part of +the flattened representation. + +# Returns +- `OfArray`: Product of dimensions (total elements) +- `OfReal`, `OfInt`: 1 +- `OfNamedTuple`: Sum of lengths of all fields +- `OfConstantWrapper`: 0 (constants excluded from flattening) + +# Examples +```julia +length(of(Array, 3, 4)) # 12 +length(of(Float64)) # 1 +length(of(Int, 0, 10)) # 1 + +T = @of(x=of(Real), y=of(Array, 2, 3)) +length(T) # 7 (1 + 6) + +# Constants contribute 0; concrete fields still count +T2 = @of(n=of(Int; constant=true), data=of(Array, 3, 3)) +length(T2) # 9 (n contributes 0, data contributes 9) +length(of(T2; n=5)) # 9 (n removed; data unchanged) +``` + +# Errors +- Throws error for arrays with unresolved symbolic dimensions + +# See also +[`size`](@ref), [`flatten`](@ref), [`unflatten`](@ref) +""" +function Base.length(::Type{OfArray{T,N,D}}) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + any(_is_symbolic_dim, dims) && + error("Cannot get length of array with symbolic dimensions.") + return prod(dims)::Int +end + +Base.length(::Type{OfReal{T,L,U}}) where {T,L,U} = 1 +Base.length(::Type{OfInt{L,U}}) where {L,U} = 1 + +# `@generated` so the total folds to a compile-time constant; a plain recursive `sum` +# widens to `Any` once nesting/field-count grows (poisoning the flatten/unflatten path). +@generated function Base.length(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + total = sum(Int[length(P) for P in Types.parameters]; init=0) + return :($total) +end + +function Base.length(::Type{OfConstantWrapper{T}}) where {T} + return 0 # Constants are not part of the flattened representation +end + +# Check if a type still carries unresolved symbols: symbolic array dimensions, symbolic +# bounds, or constants (an `OfConstantWrapper`, whether nested or at the top level). +function has_symbolic_dims(::Type{T}) where {T<:OfType} + if T <: OfArray + return any(_is_symbolic_dim, get_dims(T)) + elseif T <: Union{OfReal,OfInt} + return _is_symbolic_bound(get_lower(T)) || _is_symbolic_bound(get_upper(T)) + elseif T <: OfConstantWrapper + return true + elseif T <: OfNamedTuple + types = get_types(T) + for i in 1:length(types.parameters) + has_symbolic_dims(types.parameters[i]) && return true + end + return false + else + return false + end +end + +# Get list of unresolved symbols in a type +function get_unresolved_symbols(::Type{T}) where {T<:OfType} + symbols = Symbol[] + + function collect_symbols(oft_type::Type, path::String="") + if oft_type <: OfArray + dims = get_dims(oft_type) + for d in dims + if d isa Symbol + push!(symbols, d) + elseif d isa Type && d <: SymbolicExpr + extract_symbols_from_expr(d.parameters[1]) + end + end + elseif oft_type <: Union{OfReal,OfInt} + for b in (get_lower(oft_type), get_upper(oft_type)) + if b isa Type && b <: SymbolicRef + push!(symbols, type_to_bound(b)) + elseif b isa Type && b <: SymbolicExpr + extract_symbols_from_expr(b.parameters[1]) + end + end + elseif oft_type <: OfConstantWrapper + collect_symbols(get_wrapped_type(oft_type), path) + elseif oft_type <: OfNamedTuple + names = get_names(oft_type) + types = get_types(oft_type) + for (i, name) in enumerate(names) + field_type = types.parameters[i] + new_path = isempty(path) ? string(name) : "$path.$name" + if field_type <: OfConstantWrapper + push!(symbols, name) + else + collect_symbols(field_type, new_path) + end + end + end + end + + function extract_symbols_from_expr(expr::Tuple) + for arg in expr[2:end] # Skip operator + if arg isa Symbol + push!(symbols, arg) + elseif arg isa Tuple + extract_symbols_from_expr(arg) + end + end + end + + collect_symbols(T) + return unique(symbols) +end + +# Validate that a value is within bounds. `kind` only labels the error message. +function validate_bounds(value, lower, upper; kind::AbstractString="value") + if !isnothing(lower) && !isnothing(upper) && lower > upper + error("Invalid $kind bounds: lower bound $lower is greater than upper bound $upper") + end + if !isnothing(lower) && value < lower + error("$kind $value is below lower bound $lower") + end + if !isnothing(upper) && value > upper + error("$kind $value is above upper bound $upper") + end +end + +function _validate(::Type{T}, value) where {T<:OfType} + if is_leaf(T) + return _validate_leaf(T, value) + else + return _validate_container(T, value) + end +end + +function _validate_leaf(::Type{OfArray{T,N,D}}, value) where {T,N,D} + value isa AbstractArray || error("Expected Array for OfArray, got $(typeof(value))") + + dims = get_dims(OfArray{T,N,D}) + any(d -> d isa Symbol || (d isa Type && d <: SymbolicExpr), dims) && error( + "Cannot validate array with symbolic dimensions. Use the parameterized constructor.", + ) + + # Check dimensions before conversion + ndims(value) == N || + error("Array dimension mismatch: expected $N dimensions, got $(ndims(value))") + size(value) == Tuple(dims) || + error("Array size mismatch: expected $(Tuple(dims)), got $(size(value))") + + arr = convert(Array{T,N}, value) + return arr +end + +function _validate_leaf(::Type{OfReal{T,L,U}}, value) where {T,L,U} + if value isa Real + val = convert(T, value) + lower = type_to_bound(L) + upper = type_to_bound(U) + validate_bounds(val, lower, upper; kind="Real value") + return val + else + error("Expected Real for OfReal, got $(typeof(value))") + end +end + +function _validate_leaf(::Type{OfInt{L,U}}, value) where {L,U} + if value isa Integer + val = convert(Int, value) + lower = type_to_bound(L) + upper = type_to_bound(U) + validate_bounds(val, lower, upper; kind="Int value") + return val + elseif value isa Real + # Allow conversion from Real to Int if it's a whole number + if isinteger(value) + return _validate_leaf(OfInt{L,U}, Int(value)) + else + error("Expected Integer for OfInt, got non-integer Real: $value") + end + else + error("Expected Integer for OfInt, got $(typeof(value))") + end +end + +function _validate_container(::Type{OfNamedTuple{Names,Types}}, value) where {Names,Types} + value isa NamedTuple || + error("Expected NamedTuple for OfNamedTuple, got $(typeof(value))") + + value_names = fieldnames(typeof(value)) + for name in Names + if !(name in value_names) + error("Missing required field: $name. Got fields: $(join(value_names, ", "))") + end + end + extra = setdiff(collect(value_names), collect(Names)) + isempty(extra) || error("Unexpected field(s): $(join(extra, ", "))") + + vals = ntuple(length(Names)) do i + field_name = Names[i] + field_type = Types.parameters[i] + return _validate(field_type, getproperty(value, field_name)) + end + return NamedTuple{Names}(vals) +end + +function _validate_leaf(::Type{OfConstantWrapper{T}}, value) where {T} + return _validate_leaf(T, value) +end + +""" + flatten(::Type{T}, values) where T<:OfType + +Convert structured values to a flat numeric vector. + +Walks `values` in field order, vectorising arrays (column-major) and recursing into +named tuples, and returns a flat vector whose element type is the promotion of the +declared leaf element types (so a pure-`Int` structure stays `Vector{Int}`, while any +float field widens the whole vector). This is the form an optimiser or sampler wants. +Constants are excluded by construction, since a flattenable type has none. + +# Returns +A `Vector{V}` where `V` is `promote_type` of the declared leaf element types. + +# Examples +```julia +flatten(of(Float64), 3.14) # [3.14] (Vector{Float64}) +flatten(of(Array, 2, 2), [1 2; 3 4]) # [1.0, 3.0, 2.0, 4.0] + +T = @of(x=of(Real), y=of(Array, 2, 2)) +flatten(T, (x=1.5, y=[1 2; 3 4])) # [1.5, 1.0, 3.0, 2.0, 4.0] +``` + +# Errors +- Throws if `values` do not match the specification (shape or bounds). +- Throws for types with unresolved symbolic dimensions, bounds, or constants. + +# See also +[`unflatten`](@ref), [`length`](@ref) +""" +function flatten(::Type{T}, values) where {T<:OfType} + has_symbolic_dims(T) && error( + "Cannot flatten a type with symbolic dimensions, symbolic bounds, or constants. Resolve them with of(T; name=value) first.", + ) + validated = _validate(T, values) + out = Vector{_flat_eltype(T)}(undef, length(T)) + _fill_flat!(out, 1, T, validated) + return out +end + +# Element type of the flat vector: promotion of the declared leaf element types. +_flat_eltype(::Type{OfReal{T,L,U}}) where {T,L,U} = T +_flat_eltype(::Type{<:OfInt}) = Int +_flat_eltype(::Type{OfArray{T,N,D}}) where {T,N,D} = T +function _flat_eltype(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + return promote_type(map(_flat_eltype, (Types.parameters...,))...) +end + +# Write a leaf/subtree into `out` starting at index `i`; return the next free index. +_fill_flat!(out, i, ::Type{<:OfReal}, x) = (out[i] = x; i + 1) +_fill_flat!(out, i, ::Type{<:OfInt}, x) = (out[i] = x; i + 1) +function _fill_flat!(out, i, ::Type{<:OfArray}, a) + copyto!(out, i, vec(a), 1, length(a)) + return i + length(a) +end +# `@generated` so the per-field recursion unrolls and stays type-stable on the AD path. +@generated function _fill_flat!( + out, i0, ::Type{OfNamedTuple{Names,Types}}, nt +) where {Names,Types} + body = quote + i = i0 + end + for (k, P) in enumerate(Types.parameters) + push!(body.args, :(i = _fill_flat!(out, i, $P, nt[$k]))) + end + push!(body.args, :(return i)) + return body +end + +""" + unflatten(::Type{T}, flat_values::AbstractVector{<:Real}) where T<:OfType + unflatten(::Type{T}, ::Missing) where T<:OfType + +Reconstruct structured values from a flat numeric vector (the inverse of [`flatten`](@ref)). + +Arrays are reshaped, named tuples are rebuilt in field order, and bounds are validated. +Floating-point leaves take `promote_type(declared, eltype(flat_values))`: the declared float +type acts as a precision floor, while wider numbers in `flat_values` — AD numbers +(`ForwardDiff.Dual`), `BigFloat` — flow through unchanged. Integer leaves are rounded to `Int`. + +The `missing` method builds a structure with every element set to `missing`. + +# Examples +```julia +unflatten(of(Float64), [3.14]) # 3.14 +unflatten(of(Array, 2, 2), [1, 3, 2, 4]) # [1.0 2.0; 3.0 4.0] + +T = @of(x=of(Real), y=of(Array, 2, 2)) +unflatten(T, [1.5, 1.0, 3.0, 2.0, 4.0]) # (x=1.5, y=[1.0 2.0; 3.0 4.0]) + +unflatten(T, missing) # (x=missing, y=[missing missing; missing missing]) +``` + +# Errors +- Throws if `length(flat_values)` differs from `length(T)`. +- Throws if values violate bounds. +- Throws for types with unresolved symbolic dimensions, bounds, or constants. + +# See also +[`flatten`](@ref), [`length`](@ref) +""" +function unflatten(::Type{T}, flat_values::AbstractVector{<:Real}) where {T<:OfType} + has_symbolic_dims(T) && error( + "Cannot unflatten a type with symbolic dimensions, symbolic bounds, or constants. Resolve them with of(T; name=value) first.", + ) + n = length(T) + length(flat_values) == n || + error("Length mismatch: type expects $n values, got $(length(flat_values)).") + value, _ = _unflat(T, flat_values, 1) + return value +end + +# Reconstruct one leaf/subtree from `v` starting at index `i`; return (value, next index). +# Float leaves take `promote_type(declared, eltype(v))`, so the declared type is a precision +# floor while AD numbers (`Dual`), `BigFloat`, etc. in `v` flow through. +function _unflat(::Type{OfReal{T,L,U}}, v, i) where {T,L,U} + x = convert(promote_type(T, eltype(v)), v[i]) + validate_bounds(x, type_to_bound(L), type_to_bound(U); kind="Real value") + return x, i + 1 +end +function _unflat(::Type{OfInt{L,U}}, v, i) where {L,U} + x = round(Int, v[i]) + validate_bounds(x, type_to_bound(L), type_to_bound(U); kind="Int value") + return x, i + 1 +end +function _unflat(::Type{OfArray{T,N,D}}, v, i) where {T,N,D} + dims = get_dims(OfArray{T,N,D}) + n = prod(dims) + arr = _to_array(T, @view(v[i:(i + n - 1)]), dims) + return arr, i + n +end +@generated function _unflat(::Type{OfNamedTuple{Names,Types}}, v, i0) where {Names,Types} + body = quote + i = i0 + end + syms = Symbol[] + for (k, P) in enumerate(Types.parameters) + s = Symbol(:val_, k) + push!(syms, s) + push!(body.args, :(($s, i) = _unflat($P, v, i))) + end + push!(body.args, :(return NamedTuple{Names}(($(syms...),)), i)) + return body +end + +# Integer element types round to that type; otherwise promote the declared element type with +# the flat vector's eltype (honour the declaration, but let AD/wider numbers through). `collect` +# copies, so the result never aliases the input vector. +_to_array(::Type{ET}, slice, dims) where {ET<:Integer} = reshape(round.(ET, slice), dims) +function _to_array(::Type{ET}, slice, dims) where {ET} + return reshape(collect(promote_type(ET, eltype(slice)), slice), dims) +end + +function unflatten(::Type{T}, ::Missing) where {T<:OfType} + has_symbolic_dims(T) && error( + "Cannot unflatten a type with symbolic dimensions, symbolic bounds, or constants. Resolve them with of(T; name=value) first.", + ) + return _unflat_missing(T) +end + +_unflat_missing(::Type{<:OfReal}) = missing +_unflat_missing(::Type{<:OfInt}) = missing +function _unflat_missing(::Type{OfArray{T,N,D}}) where {T,N,D} + return fill(missing, get_dims(OfArray{T,N,D})) +end +@generated function _unflat_missing(::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + vals = [:(_unflat_missing($P)) for P in Types.parameters] + return :(NamedTuple{Names}(($(vals...),))) +end + +# Format a bound type for display +function format_bound(bound_type, constant_fields, use_color) + if bound_type === Nothing + return "nothing" + elseif bound_type isa Real + # Numeric values + return string(bound_type) + elseif bound_type isa Type && bound_type <: SymbolicRef + sym = type_to_bound(bound_type) + if sym in constant_fields && use_color + return sprint() do io_inner + return printstyled(io_inner, string(sym); color=:cyan) + end + else + return string(sym) + end + else + return string(bound_type) + end +end + +# Show a bounded type (OfReal or OfInt) with proper formatting +function show_bounded_type(io::IO, type_name::String, L, U; constant::Bool=false) + use_color = get(io, :color, false) + constant_fields = get(io, :constant_fields, Symbol[]) + + if L === Nothing && U === Nothing + if constant + if use_color + printstyled(io, "of($type_name"; color=:cyan) + printstyled(io, "; constant=true"; color=:light_black) + printstyled(io, ")"; color=:cyan) + else + print(io, "of($type_name; constant=true)") + end + else + print(io, "of($type_name)") + end + else + lower_str = format_bound(L, constant_fields, use_color) + upper_str = format_bound(U, constant_fields, use_color) + + if constant && use_color + printstyled(io, "of($type_name, "; color=:cyan) + if L isa Type && L <: SymbolicRef && type_to_bound(L) in constant_fields + printstyled(io, string(type_to_bound(L)); color=:cyan) + else + printstyled(io, lower_str; color=:cyan) + end + printstyled(io, ", "; color=:cyan) + if U isa Type && U <: SymbolicRef && type_to_bound(U) in constant_fields + printstyled(io, string(type_to_bound(U)); color=:cyan) + else + printstyled(io, upper_str; color=:cyan) + end + printstyled(io, "; constant=true"; color=:light_black) + printstyled(io, ")"; color=:cyan) + else + print(io, "of($type_name, ") + if L isa Type && + L <: SymbolicRef && + type_to_bound(L) in constant_fields && + use_color + printstyled(io, string(type_to_bound(L)); color=:cyan) + else + print(io, lower_str) + end + print(io, ", ") + if U isa Type && + U <: SymbolicRef && + type_to_bound(U) in constant_fields && + use_color + printstyled(io, string(type_to_bound(U)); color=:cyan) + else + print(io, upper_str) + end + if constant + print(io, "; constant=true") + end + print(io, ")") + end + end +end + +# Helper to convert expression tuple back to string +function expr_tuple_to_string(expr::Tuple) + if length(expr) < 2 + return string(expr) + end + + op = expr[1] + if op in (:+, :-, :*, :/) && length(expr) == 3 + arg1_str = format_expr_arg(expr[2]) + arg2_str = format_expr_arg(expr[3]) + + # Add parentheses for multiplication and division if needed + if op in (:*, :/) && expr[2] isa Tuple + arg1_str = "($arg1_str)" + end + if op in (:*, :/) && expr[3] isa Tuple + arg2_str = "($arg2_str)" + end + + return "$arg1_str $op $arg2_str" + else + return string(expr) + end +end + +# Format a single expression argument +function format_expr_arg(arg) + if arg isa Symbol + string(arg) + elseif arg isa Tuple + expr_tuple_to_string(arg) + else + string(arg) + end +end + +# Show implementations. Each guards against non-concrete types (free TypeVars, e.g. when a +# method signature or stacktrace frame is rendered): touching the static params there throws +# UndefVarError, which is fatal mid-backtrace, so fall back to Base's generic Type printer. +function Base.show(io::IO, t::Type{OfArray{T,N,D}}) where {T,N,D} + isconcretetype(t) || return invoke(show, Tuple{IO,Type}, io, t) + use_color = get(io, :color, false) + constant_fields = get(io, :constant_fields, Symbol[]) + + # Process dimensions, highlighting those that reference constants + if use_color && !isempty(constant_fields) + print(io, "of(Array, ") + if T !== Float64 + print(io, T, ", ") + end + # D is a Tuple type, so we need to access its parameters + dims_list = get_dims(OfArray{T,N,D}) + for (i, d) in enumerate(dims_list) + if d isa Symbol && d in constant_fields + printstyled(io, string(d); color=:cyan) + elseif d isa Type && d <: SymbolicExpr + # This is an expression - format it nicely + expr = d.parameters[1] + expr_str = expr_tuple_to_string(expr) + # Check if any symbols in the expression are constants + has_constant = false + function check_expr(e::Tuple) + for arg in e[2:end] + if arg isa Symbol && arg in constant_fields + has_constant = true + elseif arg isa Tuple + check_expr(arg) + end + end + end + check_expr(expr) + if has_constant + printstyled(io, expr_str; color=:cyan) + else + print(io, expr_str) + end + else + print(io, string(d)) + end + if i < length(dims_list) + print(io, ", ") + end + end + print(io, ")") + return nothing + end + + # Non-color version. + dims_list = get_dims(OfArray{T,N,D}) + dims_str = join( + map(dims_list) do d + if d isa Type && d <: SymbolicExpr + expr_tuple_to_string(d.parameters[1]) + else + string(d) + end + end, + ", ", + ) + + prefix = T === Float64 ? "of(Array" : "of(Array, $T" + # Append dims only when present, so a 0-dim array prints `of(Array)` not `of(Array, )`. + return print(io, isempty(dims_str) ? "$prefix)" : "$prefix, $dims_str)") +end + +function Base.show(io::IO, t::Type{OfReal{T,L,U}}) where {T,L,U} + isconcretetype(t) || return invoke(show, Tuple{IO,Type}, io, t) + return show_bounded_type(io, string(T), L, U) +end + +function Base.show(io::IO, t::Type{OfInt{L,U}}) where {L,U} + isconcretetype(t) || return invoke(show, Tuple{IO,Type}, io, t) + return show_bounded_type(io, "Int", L, U) +end + +# Helper function to collect constant fields from a NamedTuple type +function _collect_constant_fields(Names, Types) + constant_fields = Symbol[] + for (name, T) in zip(Names, Types.parameters) + if T <: OfConstantWrapper + push!(constant_fields, name) + end + end + return constant_fields +end + +# Helper function to estimate output length for a NamedTuple type +function _estimate_namedtuple_length(Names, Types) + total_length = 4 # "@of(" + for (name, T) in zip(Names, Types.parameters) + total_length += length(string(name)) + 1 # name= + # Rough estimate of type string length + if T <: OfArray + dims = get_dims(T) + total_length += 15 + sum(d -> length(string(d)), dims; init=0) + else + total_length += 20 + end + total_length += 2 # ", " + end + return total_length +end + +# Helper function to show a field with appropriate styling +function _show_field_type(io::IO, T::Type, is_constant::Bool) + if is_constant && get(io, :color, false) + wrapped = get_wrapped_type(T) + if wrapped <: OfReal && + get_lower(wrapped) === Nothing && + get_upper(wrapped) === Nothing + elem_type = get_element_type(wrapped) + type_name = elem_type === Float64 ? "Real" : string(elem_type) + printstyled(io, "of($type_name"; color=:cyan) + printstyled(io, "; constant=true"; color=:light_black) + printstyled(io, ")"; color=:cyan) + elseif wrapped <: OfInt && + get_lower(wrapped) === Nothing && + get_upper(wrapped) === Nothing + printstyled(io, "of(Int"; color=:cyan) + printstyled(io, "; constant=true"; color=:light_black) + printstyled(io, ")"; color=:cyan) + else + show(io, T) + end + else + show(io, T) + end +end + +# Helper function to print a single field in NamedTuple +function _print_namedtuple_field( + io::IO, name::Symbol, T::Type, is_constant::Bool, separator::String +) + if is_constant && get(io, :color, false) + printstyled(io, name; color=:cyan, bold=true) + else + print(io, name) + end + + print(io, separator) + return _show_field_type(io, T, is_constant) +end + +function Base.show(io::IO, t::Type{OfNamedTuple{Names,Types}}) where {Names,Types} + isconcretetype(t) || return invoke(show, Tuple{IO,Type}, io, t) + # Collect constant fields to pass to child types + constant_fields = _collect_constant_fields(Names, Types) + io_with_constants = IOContext(io, :constant_fields => constant_fields) + + compact = get(io, :compact, false) + multiline = !compact && length(Names) > 3 + + # Check if single-line output would be too long + if !multiline && !compact + multiline = _estimate_namedtuple_length(Names, Types) > 80 + end + + print(io, "@of(") + + if multiline + println(io) + for (i, (name, T)) in enumerate(zip(Names, Types.parameters)) + print(io, " ") + is_constant = T <: OfConstantWrapper + _print_namedtuple_field(io_with_constants, name, T, is_constant, " = ") + + if i < length(Names) + println(io, ",") + else + println(io) + end + end + print(io, ")") + else + # Single line format + for (i, (name, T)) in enumerate(zip(Names, Types.parameters)) + is_constant = T <: OfConstantWrapper + _print_namedtuple_field(io_with_constants, name, T, is_constant, "=") + + if i < length(Names) + print(io, ", ") + end + end + print(io, ")") + end +end + +function Base.show(io::IO, t::Type{OfConstantWrapper{T}}) where {T} + isconcretetype(t) || return invoke(show, Tuple{IO,Type}, io, t) + # Show the wrapped type with constant=true + if T <: OfReal + elem_type = get_element_type(T) + # Use "Real" for backward compatibility when Float64 is the element type + type_name = elem_type === Float64 ? "Real" : string(elem_type) + show_bounded_type(io, type_name, get_lower(T), get_upper(T); constant=true) + elseif T <: OfInt + show_bounded_type(io, "Int", get_lower(T), get_upper(T); constant=true) + elseif T <: OfArray + # This case should not happen since constant=true is not allowed for Arrays + # But if it does, show it as a fallback + print(io, "OfConstantWrapper{", T, "}") + printstyled(io, " # Invalid: constant=true not supported for Arrays"; color=:red) + else + # Fallback + print(io, "OfConstantWrapper{", T, "}") + end +end diff --git a/test/of.jl b/test/of.jl new file mode 100644 index 00000000..cd501b72 --- /dev/null +++ b/test/of.jl @@ -0,0 +1,1265 @@ +using Test + +using AbstractPPL +using AbstractPPL: + OfType, + OfInt, + OfReal, + OfArray, + OfNamedTuple, + OfConstantWrapper, + SymbolicRef, + SymbolicExpr, + get_names, + get_types, + flatten, + unflatten, + has_symbolic_dims, + get_unresolved_symbols, + get_dims, + get_element_type, + get_ndims, + get_lower, + get_upper + +using Random: MersenneTwister + +@testset "Basic type creation" begin + @testset "Simple type creation" begin + @test of(Int) == OfInt{Nothing,Nothing} + @test of(Int, 0, 10) == OfInt{0,10} + @test of(Real) == OfReal{Float64,Nothing,Nothing} + @test of(Real, 0.0, 1.0) == OfReal{Float64,0.0,1.0} + + @test of(Array, 5) == OfArray{Float64,1,Tuple{5}} + @test of(Array, 3, 4) == OfArray{Float64,2,Tuple{3,4}} + @test of(Array, Int, 2, 2) == OfArray{Int,2,Tuple{2,2}} + end + + @testset "Symbolic bounds" begin + T1 = of(Real, :lower, :upper) + @test T1 == OfReal{Float64,SymbolicRef{:lower},SymbolicRef{:upper}} + + T2 = of(Int, 0, :max) + @test T2 == OfInt{0,SymbolicRef{:max}} + + T3 = of(Real, :min, :max; constant=true) + @test T3 == OfConstantWrapper{OfReal{Float64,SymbolicRef{:min},SymbolicRef{:max}}} + end + + @testset "Explicit float types" begin + @test of(Float64) == OfReal{Float64,Nothing,Nothing} + @test of(Float64, 0.0, 1.0) == OfReal{Float64,0.0,1.0} + @test of(Float64; constant=true) == + OfConstantWrapper{OfReal{Float64,Nothing,Nothing}} + + @test of(Float32) == OfReal{Float32,Nothing,Nothing} + @test of(Float32, -1.0f0, 1.0f0) == OfReal{Float32,-1.0f0,1.0f0} + @test of(Float32; constant=true) == + OfConstantWrapper{OfReal{Float32,Nothing,Nothing}} + + @test of(Real) == OfReal{Float64,Nothing,Nothing} + + @test rand(of(Float64)) isa Float64 + @test rand(of(Float32)) isa Float32 + @test rand(of(Real)) isa Float64 + + @test zero(of(Float64)) isa Float64 + @test zero(of(Float32)) isa Float32 + @test zero(of(Real)) isa Float64 + + val64 = rand(of(Float64, 0.0, 1.0)) + @test val64 isa Float64 + @test 0.0 <= val64 <= 1.0 + + val32 = rand(of(Float32, -1.0f0, 1.0f0)) + @test val32 isa Float32 + @test -1.0f0 <= val32 <= 1.0f0 + end +end + +@testset "@of macro tests" begin + @testset "Basic constant syntax" begin + T1 = of(Int; constant=true) + @test T1 == OfConstantWrapper{OfInt{Nothing,Nothing}} + @test string(T1) == "of(Int; constant=true)" + + T2 = of(Real; constant=true) + @test T2 == OfConstantWrapper{OfReal{Float64,Nothing,Nothing}} + @test string(T2) == "of(Real; constant=true)" + + T3 = of(Int) + @test T3 == OfInt{Nothing,Nothing} + + T4 = of(Real, 0, 10) + @test T4 == OfReal{Float64,0,10} + + @test_throws ErrorException of(Array, 10; constant=true) + @test_throws ErrorException of(Array, Float64, 5, 5; constant=true) + end + + @testset "Simple @of macro" begin + T = @of(mu = of(Real), sigma = of(Real, 0, nothing), data = of(Array, 10)) + + @test T <: OfNamedTuple + names = get_names(T) + @test names == (:mu, :sigma, :data) + + types = get_types(T) + @test types.parameters[1] == OfReal{Float64,Nothing,Nothing} + @test types.parameters[2] == OfReal{Float64,0,Nothing} + @test types.parameters[3] == OfArray{Float64,1,Tuple{10}} + end + + @testset "@of with constants and references" begin + T = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + data = of(Array, rows, cols) + ) + + @test T <: OfNamedTuple + names = get_names(T) + @test names == (:rows, :cols, :data) + + types = get_types(T) + @test types.parameters[1] == OfConstantWrapper{OfInt{Nothing,Nothing}} + @test types.parameters[2] == OfConstantWrapper{OfInt{Nothing,Nothing}} + @test types.parameters[3] == OfArray{Float64,2,Tuple{:rows,:cols}} + end + + @testset "@of with expressions" begin + # Arithmetic expressions in dimensions are supported and encode as SymbolicExpr. + T = @of(n = of(Int; constant=true), data = of(Array, n + 1, 2 * n)) + @test T <: OfNamedTuple + CT = of(T; n=5) + @test get_dims(get_types(CT).parameters[1]) == (6, 10) + end + + @testset "@of with float types" begin + T = @of( + f64_val = of(Float64), + f32_val = of(Float32, 0.0f0, 1.0f0), + real_val = of(Real; constant=true), + f64_array = of(Array, Float64, 3), + f32_array = of(Array, Float32, 2, 2) + ) + + types = get_types(T) + @test types.parameters[1] == OfReal{Float64,Nothing,Nothing} + @test types.parameters[2] == OfReal{Float32,0.0f0,1.0f0} + @test types.parameters[3] == OfConstantWrapper{OfReal{Float64,Nothing,Nothing}} + @test types.parameters[4] == OfArray{Float64,1,Tuple{3}} + @test types.parameters[5] == OfArray{Float32,2,Tuple{2,2}} + + instance = T(; real_val=5.0) + @test instance.f64_val isa Float64 + @test instance.f32_val isa Float32 + @test instance.f64_array isa Vector{Float64} + @test instance.f32_array isa Matrix{Float32} + end + + @testset "Concrete instance creation" begin + MatrixType = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + data = of(Array, rows, cols) + ) + + instance = MatrixType(; rows=3, cols=4) + + @test instance isa NamedTuple + @test keys(instance) == (:data,) + @test instance.data isa Matrix{Float64} + @test size(instance.data) == (3, 4) + @test all(instance.data .== 0.0) + + test_data = rand(3, 4) + instance2 = MatrixType(; rows=3, cols=4, data=test_data) + @test instance2.data ≈ test_data + end + + @testset "rand and zero with constants" begin + T = @of(n = of(Int; constant=true), data = of(Array, n)) + + CT = of(T; n=5) + val = rand(CT) + @test haskey(val, :data) + @test size(val.data) == (5,) + + CT2 = of(T; n=3) + val = zero(CT2) + @test haskey(val, :data) + @test size(val.data) == (3,) + @test all(val.data .== 0.0) + end + + @testset "flatten/unflatten preserves array types" begin + T = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + data = of(Array, rows, cols) + ) + ConcreteType = of(T; rows=2, cols=3) + + original = (data=rand(Float64, 2, 3),) + + flat = flatten(ConcreteType, original) + reconstructed = unflatten(ConcreteType, flat) + + @test typeof(reconstructed.data) == typeof(original.data) + @test typeof(reconstructed.data) == Matrix{Float64} + @test reconstructed.data ≈ original.data + end + + @testset "flatten/unflatten with concrete types" begin + T = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + scale = of(Real, 0.1, 10.0), + data = of(Array, rows, cols) + ) + + instance = T(; rows=3, cols=2) + @test instance isa NamedTuple + @test keys(instance) == (:scale, :data) + @test size(instance.data) == (3, 2) + + CT = of(T; rows=3, cols=2) + + flat = flatten(CT, instance) + @test length(flat) == 7 # 1 scale + 6 data elements + + reconstructed = unflatten(CT, flat) + @test reconstructed.scale ≈ instance.scale + @test reconstructed.data ≈ instance.data + + instance2 = (scale=2.5, data=rand(3, 2)) + flat2 = flatten(CT, instance2) + reconstructed2 = unflatten(CT, flat2) + @test reconstructed2.scale ≈ 2.5 + @test reconstructed2.data ≈ instance2.data + end +end + +@testset "Symbolic bounds tests" begin + @testset "Symbolic references in @of macro" begin + T = @of( + min = of(Real, 0, 10; constant=true), + max = of(Real, 20, 30; constant=true), + value = of(Real, min, max), + ) + + types = get_types(T) + @test types.parameters[1] == OfConstantWrapper{OfReal{Float64,0,10}} + @test types.parameters[2] == OfConstantWrapper{OfReal{Float64,20,30}} + @test types.parameters[3] == OfReal{Float64,SymbolicRef{:min},SymbolicRef{:max}} + end + + @testset "Symbolic bounds in named tuples" begin + T = @of( + lower_bound = of(Real, 0, nothing; constant=true), + upper_bound = of(Real, lower_bound, nothing; constant=true), + param = of(Real, lower_bound, upper_bound), + ) + + types = get_types(T) + @test types.parameters[1] == OfConstantWrapper{OfReal{Float64,0,Nothing}} + @test types.parameters[2] == + OfConstantWrapper{OfReal{Float64,SymbolicRef{:lower_bound},Nothing}} + @test types.parameters[3] == + OfReal{Float64,SymbolicRef{:lower_bound},SymbolicRef{:upper_bound}} + end + + @testset "@of macro with symbolic bounds" begin + Schema = @of( + min_val = of(Real, 0, 10; constant=true), + max_val = of(Real, min_val, 100; constant=true), + param = of(Real, min_val, max_val) + ) + + types = get_types(Schema) + @test types.parameters[1] == OfConstantWrapper{OfReal{Float64,0,10}} + @test types.parameters[2] == + OfConstantWrapper{OfReal{Float64,SymbolicRef{:min_val},100}} + @test types.parameters[3] == + OfReal{Float64,SymbolicRef{:min_val},SymbolicRef{:max_val}} + end + + @testset "Symbolic bounds with constants" begin + Schema = @of( + lower = of(Real, 0, nothing; constant=true), + upper = of(Real, lower, nothing; constant=true), + x = of(Real, lower, upper) + ) + + types = get_types(Schema) + @test types.parameters[1] == OfConstantWrapper{OfReal{Float64,0,Nothing}} + @test types.parameters[2] == + OfConstantWrapper{OfReal{Float64,SymbolicRef{:lower},Nothing}} + @test types.parameters[3] == OfReal{Float64,SymbolicRef{:lower},SymbolicRef{:upper}} + end + + @testset "Concrete instance creation with symbolic resolution" begin + Schema = @of( + min_bound = of(Real; constant=true), + max_bound = of(Real; constant=true), + value = of(Real, min_bound, max_bound) + ) + + instance = Schema(; min_bound=0.0, max_bound=1.0) + + @test instance isa NamedTuple + @test keys(instance) == (:value,) + @test instance.value == 0.0 + + instance2 = Schema(; min_bound=0.0, max_bound=1.0, value=0.5) + @test instance2.value == 0.5 + end + + @testset "Validation with symbolic bounds" begin + # A symbolic bound is resolved from the constant during construction, then the + # provided value is validated against the resolved bound. + Schema = @of( + threshold = of(Real, 0, nothing; constant=true), value = of(Real, 0, threshold) + ) + + instance = Schema(; threshold=10.0, value=5.0) + @test keys(instance) == (:value,) + @test instance.value == 5.0 + + # value above the resolved upper bound (threshold) must throw + @test_throws ErrorException Schema(; threshold=10.0, value=15.0) + # value below the lower bound must throw + @test_throws ErrorException Schema(; threshold=10.0, value=-1.0) + end +end + +@testset "Constant Elimination After Concretization" begin + @testset "Basic elimination" begin + T = @of( + n = of(Int; constant=true), m = of(Int; constant=true), data = of(Array, n, m) + ) + + CT = of(T; n=3, m=4) + + @test get_names(CT) == (:data,) + + types = get_types(CT) + @test types.parameters[1] == of(Array, 3, 4) + + instance = rand(CT) + @test instance isa NamedTuple + @test !haskey(instance, :n) + @test !haskey(instance, :m) + @test haskey(instance, :data) + @test size(instance.data) == (3, 4) + + @test length(CT) == 12 # 3×4 array + end + + @testset "Constants with bounds" begin + T = @of( + lower = of(Int, 1, 10; constant=true), + upper = of(Int, 50, 100; constant=true), + value = of(Real, lower, upper) + ) + + CT = of(T; lower=5, upper=75) + + types = get_types(CT) + @test get_names(CT) == (:value,) + @test types.parameters[1] == of(Real, 5, 75) + end + + @testset "Partial concretization" begin + T = @of( + a = of(Int; constant=true), b = of(Int; constant=true), data = of(Array, a, b) + ) + + CT = of(T; a=10) + + names = get_names(CT) + types = get_types(CT) + # 'a' is eliminated, 'b' still wrapped as constant, data uses concrete 'a' + @test :a ∉ names + @test :b ∈ names + @test :data ∈ names + + b_idx = findfirst(==(Symbol("b")), names) + @test types.parameters[b_idx] <: OfConstantWrapper + + data_idx = findfirst(==(Symbol("data")), names) + @test types.parameters[data_idx] == of(Array, 10, :b) + end + + @testset "Nested structures" begin + InnerT = @of(size = of(Int; constant=true), values = of(Array, size)) + + OuterT = @of(n = of(Int; constant=true), inner = InnerT) + + CT = of(OuterT; n=5, size=3) + + outer_names = get_names(CT) + @test :n ∉ outer_names + @test :inner ∈ outer_names + + outer_types = get_types(CT) + inner_type = outer_types.parameters[1] + + inner_names = get_names(inner_type) + @test :size ∉ inner_names + @test :values ∈ inner_names + + inner_types = get_types(inner_type) + @test inner_types.parameters[1] == of(Array, 3) + end + + @testset "Symbolic dimension checking" begin + T = @of(const_field = of(Int; constant=true), regular_field = of(Real, 0, 1)) + + @test has_symbolic_dims(T) == true + + CT = of(T; const_field=42) + @test has_symbolic_dims(CT) == false + + @test get_names(CT) == (:regular_field,) + end +end + +@testset "Multi-hop constant dependencies" begin + @testset "Chain dependencies" begin + T = @of( + a = of(Int, 1, 10; constant=true), + b = of(Int, a, 20; constant=true), + c = of(Int, b, 30; constant=true), + data = of(Array, c, c) + ) + + CT = of(T; a=5, b=10, c=15) + + @test get_names(CT) == (:data,) + types = get_types(CT) + @test types.parameters[1] == of(Array, 15, 15) + end + + @testset "Expression dependencies" begin + T = @of( + base = of(Int, 2, 5; constant=true), + width = of(Int, base, base * 2; constant=true), + height = of(Int, base, base * 3; constant=true), + volume = of(Array, width, height, base) + ) + + CT = of(T; base=3, width=5, height=7) + + @test get_names(CT) == (:volume,) + types = get_types(CT) + @test types.parameters[1] == of(Array, 5, 7, 3) + end +end + +@testset "Type operations" begin + @testset "length calculation" begin + @test length(of(Int)) == 1 + @test length(of(Real)) == 1 + @test length(of(Array, 5)) == 5 + @test length(of(Array, 3, 4)) == 12 + @test length(of(Array, Int, 2, 3, 4)) == 24 + + T = @of(a = of(Int), b = of(Real), c = of(Array, 3)) + @test length(T) == 5 # 1 + 1 + 3 + + T2 = @of(n = of(Int; constant=true), data = of(Array, n)) + CT = of(T2; n=10) + @test length(CT) == 10 # only data field remains (n is eliminated) + end + + @testset "rand generation" begin + @test rand(of(Int, 1, 10)) isa Int + @test rand(of(Real, 0.0, 1.0)) isa Float64 + arr = rand(of(Array, 5, 3)) + @test arr isa Matrix{Float64} + @test size(arr) == (5, 3) + + T = @of(x = of(Real), y = of(Int, 0, 100), z = of(Array, 2, 2)) + instance = rand(T) + @test instance isa NamedTuple + @test haskey(instance, :x) && instance.x isa Float64 + @test haskey(instance, :y) && instance.y isa Int + @test haskey(instance, :z) && instance.z isa Matrix{Float64} + end + + @testset "zero generation" begin + @test zero(of(Int)) == 0 + @test zero(of(Real)) == 0.0 + arr = zero(of(Array, 3, 2)) + @test arr isa Matrix{Float64} + @test all(arr .== 0.0) + + T = @of(a = of(Int), b = of(Real), c = of(Array, 2)) + instance = zero(T) + @test instance.a == 0 + @test instance.b == 0.0 + @test all(instance.c .== 0.0) + end + + @testset "flatten/unflatten" begin + T = @of( + int_val = of(Int), + real_val = of(Real), + vec = of(Array, 3), + mat = of(Array, 2, 2) + ) + + original = (int_val=42, real_val=3.14, vec=[1.0, 2.0, 3.0], mat=[4.0 5.0; 6.0 7.0]) + flat = flatten(T, original) + reconstructed = unflatten(T, flat) + + @test reconstructed.int_val == original.int_val + @test reconstructed.real_val ≈ original.real_val + @test reconstructed.vec ≈ original.vec + @test reconstructed.mat ≈ original.mat + + @test length(flat) == length(T) + end + + @testset "flatten promotes mixed element types" begin + # A flat vector has a single element type: the promotion of the declared leaf types. + # Mixing Float32 and Float64 fields therefore yields a Float64 vector, and unflatten + # reconstructs every float leaf at that promoted precision (the declared type is a + # floor, not a forced down-conversion — see the AD round-trip test below). + T = @of( + f64 = of(Float64, 0.0, 1.0), + f32 = of(Float32, -1.0f0, 1.0f0), + f64_vec = of(Array, Float64, 3), + f32_mat = of(Array, Float32, 2, 2) + ) + + original = ( + f64=0.5, f32=0.25f0, f64_vec=[0.1, 0.2, 0.3], f32_mat=Float32[0.1 0.2; 0.3 0.4] + ) + + flat = flatten(T, original) + @test flat isa Vector{Float64} # concrete promoted eltype, not Vector{Real} + + reconstructed = unflatten(T, flat) + @test reconstructed.f64 isa Float64 + @test reconstructed.f32 isa Float64 # promoted to the flat vector's eltype + @test reconstructed.f64_vec isa Vector{Float64} + @test reconstructed.f32_mat isa Matrix{Float64} + + @test reconstructed.f64 ≈ original.f64 + @test reconstructed.f32 ≈ original.f32 + @test reconstructed.f64_vec ≈ original.f64_vec + @test reconstructed.f32_mat ≈ original.f32_mat + end + + @testset "flatten preserves a uniform element type" begin + # When every leaf shares an element type, flatten/unflatten round-trip it exactly. + T32 = @of(a = of(Float32), v = of(Array, Float32, 2)) + flat32 = flatten(T32, (a=0.5f0, v=Float32[1, 2])) + @test flat32 isa Vector{Float32} + r32 = unflatten(T32, flat32) + @test r32.a isa Float32 + @test r32.v isa Vector{Float32} + + # A pure-integer structure stays integer-typed (no gratuitous Float64 coercion). + Tint = @of(i = of(Int), w = of(Array, Int, 2)) + flatint = flatten(Tint, (i=3, w=[4, 5])) + @test flatint isa Vector{Int} + @test flatint == [3, 4, 5] + end +end + +@testset "Array type specifications" begin + @testset "Different element types" begin + T1 = of(Array, Int, 5) + @test get_element_type(T1) == Int + @test get_ndims(T1) == 1 + @test get_dims(T1) == (5,) + + T2 = of(Array, Bool, 3, 3) + @test get_element_type(T2) == Bool + @test get_ndims(T2) == 2 + @test get_dims(T2) == (3, 3) + + T3 = of(Array, 10) + @test get_element_type(T3) == Float64 + end + + @testset "Symbolic dimensions in arrays" begin + T = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + matrix = of(Array, rows, cols), + tensor = of(Array, rows, cols, 3) + ) + + types = get_types(T) + mat_type = types.parameters[3] + @test get_dims(mat_type) == (:rows, :cols) + + tensor_type = types.parameters[4] + @test get_dims(tensor_type) == (:rows, :cols, 3) + + CT = of(T; rows=2, cols=4) + # rows and cols are eliminated, only matrix and tensor remain + @test get_names(CT) == (:matrix, :tensor) + ct_types = get_types(CT) + @test get_dims(ct_types.parameters[1]) == (2, 4) + @test get_dims(ct_types.parameters[2]) == (2, 4, 3) + end +end + +@testset "Constructor with default_value" begin + @testset "Basic default_value usage" begin + T = @of( + rows = of(Int; constant=true), + cols = of(Int; constant=true), + scale = of(Real, 0.1, 10.0), + data = of(Array, rows, cols) + ) + + # No positional arg: each leaf defaults to zero() (or its lower bound). + instance1 = T(; rows=3, cols=2) + @test instance1.scale == 0.1 + @test all(instance1.data .== 0.0) + + instance2 = T(1.5; rows=3, cols=2) + @test instance2.scale == 1.5 + @test all(instance2.data .== 1.5) + + instance3 = T(missing; rows=3, cols=2) + @test instance3.scale === missing + @test all(instance3.data .=== missing) + + instance4 = T(2.0; rows=3, cols=2, scale=5.0) + @test instance4.scale == 5.0 + @test all(instance4.data .== 2.0) + end + + @testset "Default value validation" begin + T = @of(n = of(Int; constant=true), bounded = of(Real, 0, 10), data = of(Array, n)) + + # Valid default_value within bounds + instance = T(5.0; n=3) + @test instance.bounded == 5.0 + @test all(instance.data .== 5.0) + + # Invalid default_value outside bounds should throw + @test_throws ErrorException T(15.0; n=3) # 15.0 > upper bound 10 + @test_throws ErrorException T(-5.0; n=3) # -5.0 < lower bound 0 + end + + @testset "Different types with default_value" begin + T = @of( + size = of(Int; constant=true), + int_val = of(Int, 1, 100), + real_val = of(Real), + vec = of(Array, size), + mat = of(Array, size, size) + ) + + instance1 = T(42; size=2) + @test instance1.int_val == 42 + @test instance1.real_val == 42.0 + @test all(instance1.vec .== 42.0) + @test all(instance1.mat .== 42.0) + + instance2 = T(3.14; size=2) + @test instance2.int_val == 3 # Should round to Int + @test instance2.real_val ≈ 3.14 + @test all(instance2.vec .≈ 3.14) + @test all(instance2.mat .≈ 3.14) + end + + @testset "Nested structures with default_value" begin + # For nested structures, we need a simpler example + # The inner structure's constants should be handled at the outer level + OuterT = @of(n = of(Int; constant=true), scale = of(Real), vec = of(Array, n)) + + instance = OuterT(7.0; n=5) + @test instance.scale == 7.0 + @test instance.vec isa Vector{Float64} + @test length(instance.vec) == 5 + @test all(instance.vec .== 7.0) + end + + @testset "Type stability of default_value" begin + T = @of(n = of(Int; constant=true), data = of(Array, n)) + + # The two constructor methods should be type-stable + CT = of(T; n=5) + + # Method 1: no positional argument + @inferred NamedTuple{(:data,),Tuple{Vector{Float64}}} T(; n=5) + + # Method 2: with positional default_value + @inferred NamedTuple{(:data,),Tuple{Vector{Float64}}} T(1.0; n=5) + end +end + +@testset "Edge cases and error handling" begin + @testset "Invalid bounds" begin + # Test that invalid bounds are caught during validation + T = @of(value = of(Real, 0, 10)) + @test_throws ErrorException T(value=15.0) # value > upper bound + @test_throws ErrorException T(value=-5.0) # value < lower bound + end + + @testset "Missing required constants" begin + T = @of(n = of(Int; constant=true), data = of(Array, n)) + + # Should throw when trying to use without providing constant + @test_throws ErrorException rand(T) + @test_throws ErrorException zero(T) + + # Should throw when trying to create instance without providing constant + @test_throws ErrorException T() + @test_throws ErrorException T(data=rand(5)) + end + + @testset "Type display" begin + @test string(of(Int)) == "of(Int)" + @test string(of(Int, 0, 10)) == "of(Int, 0, 10)" + @test string(of(Real, 0.0, nothing)) == "of(Float64, 0.0, nothing)" + @test string(of(Float64, 0.0, nothing)) == "of(Float64, 0.0, nothing)" + @test string(of(Float32, 0.0f0, nothing)) == "of(Float32, 0.0, nothing)" + @test string(of(Array, 5)) == "of(Array, 5)" + @test string(of(Array, Float32, 3, 3)) == "of(Array, Float32, 3, 3)" + + # Test constant wrapper display + @test string(of(Int; constant=true)) == "of(Int; constant=true)" + @test string(of(Real, 0, 1; constant=true)) == "of(Real, 0, 1; constant=true)" + @test string(of(Float64; constant=true)) == "of(Real; constant=true)" + @test string(of(Float32; constant=true)) == "of(Float32; constant=true)" + + # Test that types without bounds don't show "nothing" + T = @of(rows = of(Int), cols = of(Int), data = of(Array, 3, 4)) + str = string(T) + @test occursin("rows = of(Int)", str) + @test occursin("cols = of(Int)", str) + @test !occursin("nothing", str) + end + + @testset "Type inference from values" begin + @test of(1.0) == of(Float64) + @test of(1.0f0) == of(Float32) + @test of(1) == of(Int) + @test of(1//2) == of(Float64) # Rationals default to Float64 + @test of(big(1.0)) == of(BigFloat) + + # Test arrays + @test of([1.0, 2.0, 3.0]) == of(Array, Float64, 3) + @test of(Float32[1.0, 2.0]) == of(Array, Float32, 2) + @test of([1 2; 3 4]) == of(Array, Int, 2, 2) + end +end + +@testset "Show method for NamedTuple" begin + T = @of(x = of(Real), y = of(Int, 0, 10)) + str = string(T) + @test occursin("@of(", str) + @test occursin("x=of(Float64)", str) + @test occursin("y=of(Int, 0, 10)", str) + + # Test multiline display with many fields + T2 = @of(a = of(Real), b = of(Int), c = of(Array, 3, 4), d = of(Float32, 0.0, 1.0)) + str2 = string(T2) + @test occursin("@of(", str2) + + # Test with constants + T3 = @of(n = of(Int; constant=true), data = of(Array, n, 2)) + str3 = string(T3) + @test occursin("of(Int; constant=true)", str3) + @test occursin("of(Array, n, 2)", str3) +end + +@testset "Type concretization" begin + T = @of(n = of(Int; constant=true), data = of(Array, n)) + + ConcreteT = of(T; n=5) + @test ConcreteT <: OfNamedTuple + @test get_names(ConcreteT) == (:data,) + types = get_types(ConcreteT) + @test get_dims(types.parameters[1]) == (5,) + + # Test expression dimensions + T2 = @of( + n = of(Int; constant=true), + original = of(Array, n, n), + padded = of(Array, n + 1, n + 1), + doubled = of(Array, 2 * n, n) + ) + + ConcreteT2 = of(T2; n=10) + names = get_names(ConcreteT2) + @test names == (:original, :padded, :doubled) + + types2 = get_types(ConcreteT2) + @test get_dims(types2.parameters[1]) == (10, 10) + @test get_dims(types2.parameters[2]) == (11, 11) + @test get_dims(types2.parameters[3]) == (20, 10) + + # Test bounded types with symbolic references + T3 = @of( + lower = of(Real; constant=true), + upper = of(Real; constant=true), + param = of(Real, lower, upper) + ) + + ConcreteT3 = of(T3; lower=0.0, upper=1.0) + types3 = get_types(ConcreteT3) + @test get_lower(types3.parameters[1]) == 0.0 + @test get_upper(types3.parameters[1]) == 1.0 +end + +@testset "Expression processing in @of macro" begin + T = @of( + n = of(Int; constant=true), + data1 = of(Array, n + 1), + data2 = of(Array, n * 2), + data3 = of(Array, n - 1), + data4 = of(Array, n / 2) + ) + + ConcreteT = of(T; n=10) + types = get_types(ConcreteT) + @test get_dims(types.parameters[1]) == (11,) + @test get_dims(types.parameters[2]) == (20,) + @test get_dims(types.parameters[3]) == (9,) + @test get_dims(types.parameters[4]) == (5,) + + # Test nested expressions + T2 = @of( + a = of(Int; constant=true), + b = of(Int; constant=true), + data = of(Array, (a + b) * 2) + ) + + ConcreteT2 = of(T2; a=10, b=4) + types2 = get_types(ConcreteT2) + @test get_dims(types2.parameters[1]) == (28,) # (10+4)*2 + + # Test division that requires integer result + T3 = @of(n = of(Int; constant=true), data = of(Array, n / 3)) + + # Should error when n=10 since 10/3 is not an integer + @test_throws ErrorException of(T3; n=10) + + # But should work when n=9 + ConcreteT3 = of(T3; n=9) + types3 = get_types(ConcreteT3) + @test get_dims(types3.parameters[1]) == (3,) +end + +# A submodule that does ONLY `using AbstractPPL`, with no access to internal names. This is +# the real downstream scope; the testsets above import `SymbolicExpr` and so cannot catch a +# macro that emits an unqualified reference to it. +module DownstreamScope +using AbstractPPL +using Test + +@testset "@of expands in a using-only scope" begin + # Plain symbolic dimensions (resolved at runtime, no injected type name). + Tsym = @of(n = of(Int; constant=true), data = of(Array, n, 2)) + @test of(Tsym; n=4) isa Type + + # Arithmetic dimensions inject `SymbolicExpr`, which must be emitted fully qualified so + # it resolves even though it is only `public`, not exported. + Texpr = @of( + n = of(Int; constant=true), + a = of(Array, n + 1), + b = of(Array, 2 * n, n), + c = of(Array, (n + 1) * 2) + ) + CT = of(Texpr; n=5) + @test size(rand(CT).a) == (6,) + @test size(rand(CT).b) == (10, 5) + @test size(rand(CT).c) == (12,) + + # Symbolic bounds likewise resolve in a using-only scope. + Tbound = @of(lo = of(Real; constant=true), x = of(Real, lo, nothing)) + @test of(Tbound; lo=0.0) isa Type + + # The generated calls are qualified to AbstractPPL, not captured by caller locals. + let of = _ -> error("shadowed") + Tshadow = @of(n = of(Int; constant=true), data = of(Array, n)) + @test AbstractPPL.of(Tshadow; n=2) isa Type + end +end +end # module DownstreamScope + +@testset "show is safe for non-concrete types" begin + # Rendering a free-typevar `of`-type (method signatures, stacktraces, Documenter) must not + # touch the static params; previously this threw UndefVarError, fatal mid-backtrace. + @test sprint(show, Tuple{Type{OfArray{T,N,D}}} where {T,N,D}) isa String + @test sprint(show, Tuple{Type{OfReal{T,L,U}}} where {T,L,U}) isa String + @test sprint(show, Tuple{Type{OfInt{L,U}}} where {L,U}) isa String + @test sprint(show, Tuple{Type{OfNamedTuple{Names,Types}}} where {Names,Types}) isa + String + # Listing the methods of a function with `of`-typed signatures must not crash. + @test sprint(show, methods(rand)) isa String + + # Concrete types still print in the pretty `of(...)` form. + @test string(of(Array, 2, 3)) == "of(Array, 2, 3)" + @test string(of(Array, Int)) == "of(Array, Int64)" # 0-dim: no trailing comma +end + +@testset "flatten/unflatten numeric contract" begin + @testset "concrete, promoted element type" begin + @test flatten(of(Int), 3) isa Vector{Int} + @test flatten(of(Float64), 1.5) isa Vector{Float64} + T = @of(i = of(Int), x = of(Real)) + @test flatten(T, (i=2, x=1.5)) isa Vector{Float64} # promote(Int, Float64) + end + + @testset "AD/wide eltypes flow through unflatten" begin + # BigFloat stands in for ForwardDiff.Dual: any `<:Real` wider than the declared type + # must survive unflatten without being coerced back to Float64. + T = @of(x = of(Real), data = of(Array, 2, 2)) + flat = BigFloat[big"1.0", big"2.0", big"3.0", big"4.0", big"5.0"] + r = unflatten(T, flat) + @test r.x isa BigFloat + @test r.data isa Matrix{BigFloat} + @test r.data == BigFloat[2 4; 3 5] + end + + @testset "declared type is a precision floor" begin + # Narrow input widens up to the declared float type. + @test unflatten(of(Array, 2, 2), [1, 3, 2, 4]) isa Matrix{Float64} + @test unflatten(of(Float64), Float32[0.5]) isa Float64 + # Integer leaves round to Int regardless of input eltype. + @test unflatten(of(Int), [2.0]) === 2 + end + + @testset "type stability on the sampler path" begin + T = @of( + x = of(Real), + n = of(Int), + data = of(Array, 2, 2), + inner = @of(a = of(Real), v = of(Array, 3)) + ) + v = collect(1.0:10.0) + nt = @inferred unflatten(T, v) + @inferred flatten(T, nt) + @inferred length(T) + @inferred size(T) + @test flatten(T, unflatten(T, v)) == v + end + + @testset "length / count errors" begin + T = @of(x = of(Real), data = of(Array, 2, 2)) + @test length(T) == 5 + @test_throws ErrorException unflatten(T, [1.0, 2.0]) # too few + @test_throws ErrorException unflatten(T, collect(1.0:6.0)) # too many + end +end + +@testset "rand threads an explicit RNG" begin + T = @of( + x = of(Real, 0, 1), n = of(Int, 1, 10), v = of(Array, 3), inner = @of(a = of(Real)) + ) + @test rand(MersenneTwister(42), T) == rand(MersenneTwister(42), T) + @test rand(MersenneTwister(1), of(Float64, 0, 1)) == + rand(MersenneTwister(1), of(Float64, 0, 1)) + @test rand(MersenneTwister(1), of(Array, 2, 2)) == + rand(MersenneTwister(1), of(Array, 2, 2)) + # The RNG-accepting method is what downstream samplers dispatch on. + @test hasmethod(rand, Tuple{MersenneTwister,Type{OfReal{Float64,Nothing,Nothing}}}) + @test (@inferred rand(MersenneTwister(1), T)) isa NamedTuple +end + +@testset "symbolic bounds are detected" begin + # has_symbolic_dims / get_unresolved_symbols must see symbolic *bounds*, not just dims, + # so zero/rand/flatten fail with a clean message instead of a raw MethodError. + T = @of(lo = of(Real; constant=true), x = of(Real, lo, nothing)) + @test has_symbolic_dims(T) + @test :lo in get_unresolved_symbols(T) + @test_throws ErrorException zero(T) + @test_throws ErrorException rand(T) + @test_throws ErrorException flatten(T, (lo=0.0, x=1.0)) + + # Expression bounds too. + T2 = @of(base = of(Int; constant=true), x = of(Int, base, base * 2)) + @test has_symbolic_dims(T2) + @test :base in get_unresolved_symbols(T2) +end + +@testset "top-level constants are rejected, never silent" begin + # A bare OfConstantWrapper is not flattenable; both ops must throw, not return `nothing`. + @test has_symbolic_dims(of(Int; constant=true)) + @test_throws ErrorException flatten(of(Int; constant=true), 5) + @test_throws ErrorException unflatten(of(Int; constant=true), Float64[]) + @test_throws ErrorException unflatten(of(Int; constant=true), missing) +end + +@testset "symbolic division guards the result" begin + T = @of(n = of(Int; constant=true), data = of(Array, n / 2)) + @test get_dims(get_types(of(T; n=10)).parameters[1]) == (5,) + # A non-integer quotient raises the dedicated error, not a raw InexactError. + err = try + of(T; n=11) + nothing + catch e + e + end + @test err isa ErrorException + @test occursin("not an integer", err.msg) +end + +@testset "leaf types are specifications, not instantiable" begin + @test_throws ErrorException OfReal{Float64,Nothing,Nothing}() + @test_throws ErrorException OfInt{Nothing,Nothing}() + @test_throws ErrorException OfArray{Float64,1,Tuple{3}}() + @test_throws ErrorException OfConstantWrapper{OfInt{Nothing,Nothing}}() +end + +@testset "eval_symbolic_expr operations and errors" begin + eval_expr = AbstractPPL.eval_symbolic_expr + @test eval_expr((:+, (:*, :n, 2), 1), (n=3,)) == 7 # nested + @test eval_expr((:-, :n), (n=3,)) == -3 # unary minus + @test eval_expr((:-, :n, 1), (n=3,)) == 2 # binary minus + @test eval_expr((:*, :n, 2), (n=3,)) == 6 + @test eval_expr((:/, :n, 2), (n=6,)) == 3 # exact division + @test_throws ErrorException eval_expr((:+,), (n=3,)) # too few elements + @test_throws ErrorException eval_expr((:^, :n, 2), (n=3,)) # unsupported op + @test_throws ErrorException eval_expr((:+, :m, 1), (n=3,)) # unknown symbol + @test_throws ErrorException eval_expr((:-, :n, 1, 2), (n=3,)) # too many minus args + @test_throws ErrorException eval_expr((:/, :n, 2), (n=5,)) # non-integer quotient + @test_throws ErrorException eval_expr((:/, :n, 2, 1), (n=6,)) # division needs 2 args +end + +@testset "rand covers single-sided and unbounded branches" begin + rng = MersenneTwister(0) + @test rand(rng, of(Real, 0.0, nothing)) >= 0.0 # lower-only: shifted exponential + @test rand(rng, of(Real, nothing, 1.0)) <= 1.0 # upper-only + @test rand(rng, of(Int, 5, nothing)) >= 5 # lower-only + @test rand(rng, of(Int, nothing, 5)) <= 5 # upper-only + @test rand(rng, of(Int)) isa Int # unbounded +end + +@testset "zero respects the far bound" begin + @test zero(of(Real, nothing, -1.0)) == -1.0 # upper < 0 + @test zero(of(Real, 2.0, 5.0)) == 2.0 # lower > 0 + @test zero(of(Int, 5, 10)) == 5 # lower > 0 + @test zero(of(Int, nothing, -3)) == -3 # upper < 0 +end + +@testset "value validation rejects mismatches" begin + @test_throws ErrorException (@of(x = of(Real)))(; x="not real") + @test_throws ErrorException (@of(i = of(Int)))(; i=3.5) # non-integer Real + @test_throws ErrorException (@of(i = of(Int)))(; i="x") # non-Real + @test (@of(i = of(Int)))(; i=3.0).i === 3 # whole-number Real converts + @test_throws ErrorException flatten(@of(x = of(Real), y = of(Real)), (x=1.0,)) # missing field + @test_throws ErrorException flatten(@of(x = of(Real)), (x=1.0, y=2.0)) # extra field + @test_throws ErrorException of(@of(n = of(Int; constant=true)); n=3) # all fields constant + @test_throws ErrorException (@of(n = of(Int; constant=true), d = of(Array, n)))() # missing constant + @test_throws ErrorException (@of(n = of(Int; constant=true), d = of(Array, n)))(; + n=3, datta=ones(3) + ) + @test_throws ErrorException AbstractPPL._create_with_default( + OfConstantWrapper{OfInt{Nothing,Nothing}}, 0 + ) +end + +@testset "spec validation and concretization regressions" begin + @test_throws ErrorException of(Real, 2.0, 1.0) + @test_throws ErrorException of(Int, 5, 3) + @test_throws ErrorException of(Array, -1) + @test_throws ErrorException of(Array, 2.5) + @test_throws ErrorException of(Array, String, 2) + @test of(Array, ComplexF64, 2) == OfArray{ComplexF64,1,Tuple{2}} # any Number element + + T = of(Real, :lo, :hi) + @test of(T; lo=0.0, hi=1.0) == of(Real, 0.0, 1.0) + + TI = of(Int, :lo, :hi) + @test of(TI; lo=1, hi=3) == of(Int, 1, 3) + @test of(TI; lo=1.0, hi=3.0) == of(Int, 1, 3) + @test_throws ErrorException of(TI; lo=1.5, hi=3.5) + + TA = @of( + a = of(Int; constant=true), b = of(Int; constant=true), data = of(Array, a + 1, b) + ) + CTA = of(TA; b=3) + data_type = get_types(CTA).parameters[2] + @test has_symbolic_dims(CTA) + @test :a in get_unresolved_symbols(CTA) + @test get_dims(data_type) == (SymbolicExpr{(:+, :a, 1)}, 3) + + TC = @of(n = of(Int, 1, 5; constant=true), data = of(Array, n)) + @test_throws ErrorException of(TC; n=100) + @test_throws ErrorException TC(; n=100) + @test get_dims(get_types(of(TC; n=3)).parameters[1]) == (3,) + + Inner = @of(n = of(Int; constant=true)) + Outer = @of(inner = Inner, x = of(Real)) + @test of(Outer; n=2) == @of(x = of(Real)) +end + +@testset "@of reference semantics are constant-only and ordered" begin + @test_throws ErrorException macroexpand( + @__MODULE__, :(@of(lo = of(Real), x = of(Real, lo, nothing))) + ) + @test_throws ErrorException macroexpand( + @__MODULE__, :(@of(data = of(Array, n), n = of(Int; constant=true))) + ) + @test_throws ErrorException macroexpand( + @__MODULE__, :(@of(lo = of(Real), x = of(Real, :lo, nothing))) + ) + @test_throws ErrorException macroexpand( + @__MODULE__, :(@of(data = of(Array, :n), n = of(Int; constant=true))) + ) + @test_throws ErrorException macroexpand( + @__MODULE__, + :(@of(lo = of(Real; constant=true), x = of(Real, max(lo, 0.0), nothing))), + ) +end + +@testset "default_value initialises non-constant fields" begin + T = @of(n = of(Int; constant=true), x = of(Real), arr = of(Array, n)) + inst = T(missing; n=2) + @test inst.x === missing + @test all(ismissing, inst.arr) +end + +@testset "symbol collection across dims and bounds" begin + T = @of( + n = of(Int; constant=true), + s = of(Real; constant=true), + lo = of(Real, 0.0, 1.0), + bnd = of(Real, n, nothing), # symbolic bound + a = of(Array, n, 2 * n), # symbolic + arithmetic dims + ) + syms = get_unresolved_symbols(T) + @test :n in syms + @test :s in syms + @test has_symbolic_dims(T) + @test !has_symbolic_dims(@of(x = of(Real), y = of(Array, 2, 2))) + @test get_unresolved_symbols(of(Int; constant=true)) isa Vector{Symbol} +end + +@testset "symbolic bounds resolve under concretization" begin + TR = @of(n = of(Int; constant=true), r = of(Real, n, nothing)) + @test get_lower(get_types(of(TR; n=2)).parameters[1]) == 2 + TI = @of(m = of(Int; constant=true), k = of(Int, m, nothing)) + @test get_lower(get_types(of(TI; m=4)).parameters[1]) == 4 + TE = @of(n = of(Int; constant=true), e = of(Real, n + 1, nothing)) + @test get_lower(get_types(of(TE; n=5)).parameters[1]) == 6 + @test has_symbolic_dims(of(TR)) # nothing resolved -> reference is retained +end + +@testset "unflatten with missing builds a missing-filled structure" begin + P = @of(x = of(Real), a = of(Array, 2, 2)) + um = unflatten(P, missing) + @test um.x === missing + @test all(ismissing, um.a) +end + +@testset "expression formatting for display" begin + fmt = AbstractPPL.expr_tuple_to_string + @test fmt((:*, (:+, :n, 1), 2)) == "(n + 1) * 2" + @test fmt((:/, 2, (:+, :n, 1))) == "2 / (n + 1)" + @test fmt((:+, :n, 1)) == "n + 1" + @test fmt((:+,)) isa String # too short + @test fmt((:+, 1, 2, 3)) isa String # not a binary op + @test fmt((:%, :n, 2)) isa String # unsupported op +end + +@testset "show covers colored, symbolic, and fallback paths" begin + # Non-color array with symbolic and arithmetic dimensions. + s = sprint(show, @of(n = of(Int; constant=true), a = of(Array, n, n + 1))) + @test occursin("n", s) + @test occursin("n + 1", s) + + # A colored render of a rich spec exercises the highlighting branches. + Tc = @of( + n = of(Int; constant=true), # constant Int + s = of(Real; constant=true), # constant Real + bc = of(Int, 1, 9; constant=true), # bounded constant + lo = of(Real, 0.0, 1.0), + bnd = of(Real, n, nothing), # symbolic bound referencing a constant + a = of(Array, n, 2 * n), # symbolic + arithmetic dims + ) + cs = sprint(show, Tc; context=(:color => true)) + @test occursin("@of(", cs) + @test occursin("constant=true", cs) + + # Colored bounded scalars and constants. + @test occursin("Int", sprint(show, of(Int, 0, 10); context=(:color => true))) + @test occursin( + "constant", sprint(show, of(Int, 0, 10; constant=true); context=(:color => true)) + ) + @test occursin( + "constant", sprint(show, of(Int; constant=true); context=(:color => true)) + ) + @test occursin( + "constant", sprint(show, of(Real; constant=true); context=(:color => true)) + ) + + # Constant-wrapper show fallbacks for shapes the public API never produces. + @test occursin( + "OfConstantWrapper", sprint(show, OfConstantWrapper{OfArray{Float64,1,Tuple{3}}}) + ) + @test occursin( + "OfConstantWrapper", + sprint( + show, + OfConstantWrapper{OfNamedTuple{(:x,),Tuple{OfReal{Float64,Nothing,Nothing}}}}, + ), + ) + + # Colored symbolic bounds: constant-wrapped (both ends) and a symbolic-expression bound. + Tb = @of( + n = of(Int; constant=true), + cc = of(Real, n, n; constant=true), # constant field, both bounds symbolic + be = of(Real, n + 1, nothing), # symbolic-expression bound + ) + @test occursin("of(", sprint(show, Tb; context=(:color => true))) + # A symbolic bound whose symbol is not a known constant still renders. + @test occursin("foo", sprint(show, of(Real, :foo, nothing); context=(:color => true))) +end + +@testset "type inference from NamedTuple of values" begin + T = of((a=1, b=2.0, c=[1.0, 2.0])) + @test T <: OfNamedTuple + @test get_names(T) == (:a, :b, :c) + @test get_types(T).parameters[1] == OfInt{Nothing,Nothing} + @test get_types(T).parameters[2] == OfReal{Float64,Nothing,Nothing} +end + +@testset "missing default fills integer fields too" begin + T = @of(n = of(Int; constant=true), i = of(Int), arr = of(Array, n)) + inst = T(missing; n=2) + @test inst.i === missing + @test all(ismissing, inst.arr) +end + +@testset "@of rejects malformed field specifications" begin + # `@of` errors during macro expansion for non-`field = spec` arguments. + @test ( + try + macroexpand(@__MODULE__, :(@of(of(Real)))) + false + catch + true + end + ) + # A non-symbol field name is rejected. + @test ( + try + macroexpand(@__MODULE__, :(@of(a.b = of(Real)))) + false + catch + true + end + ) +end + +@testset "nested arithmetic dimensions collect inner symbols" begin + T = @of(n = of(Int; constant=true), m = of(Array, (n + 1) * 2)) + @test :n in get_unresolved_symbols(T) + @test get_dims(get_types(of(T; n=3)).parameters[1]) == (8,) +end diff --git a/test/runtests.jl b/test/runtests.jl index ad97e108..db3e8823 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ const GROUP = get(ENV, "GROUP", "All") include("varname/hasvalue.jl") include("varname/leaves.jl") include("varname/serialize.jl") + include("of.jl") end if GROUP == "All" || GROUP == "Doctests"