Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ export get_irfs, get_irf, get_IRF, simulate, get_simulation, get_simulations, ge
export get_conditional_forecast
export get_solution, get_first_order_solution, get_perturbation_solution, get_second_order_solution, get_third_order_solution
export get_steady_state, get_SS, get_ss, get_non_stochastic_steady_state, get_stochastic_steady_state, get_SSS, steady_state, SS, SSS, ss, sss
export get_mean_at_approximation_point
export get_non_stochastic_steady_state_residuals, get_residuals, check_residuals
export get_moments, get_statistics, get_covariance, get_standard_deviation, get_variance, get_var, get_std, get_stdev, get_cov, var, std, stdev, cov, get_mean #, mean
export get_autocorrelation, get_correlation, get_variance_decomposition, get_corr, get_autocorr, get_var_decomp, corr, autocorr
Expand Down Expand Up @@ -6654,6 +6655,132 @@ function rrule(::typeof(calculate_third_order_stochastic_steady_state),
return (x, solved), third_order_stochastic_steady_state_pullback
end


"""
calculate_pruned_second_order_mean_at_point(parameters::Vector{M},
𝓂::ℳ;
approximation_point::Union{Nothing, Vector{M}} = nothing,
opts::CalculationOptions = merge_calculation_options()) where M

Calculate the mean of the ergodic distribution implied by the pruned second-order perturbation solution
around a specified approximation point.

If `approximation_point` is `nothing`, the mean is calculated around the non-stochastic steady state (NSSS).
Otherwise, the approximation point specifies the values of the state variables around which to linearize.

# Arguments
- `parameters`: Vector of model parameters
- `𝓂`: Model object
- `approximation_point`: Optional vector of values for variables around which to compute perturbation (same length as SS_and_pars)
- `opts`: Calculation options

# Returns
- Tuple of (mean_of_variables::Vector, solved::Bool, SS_and_pars::Vector, solution_error, ∇₁, ∇₂, 𝐒₁, 𝐒₂)
"""
function calculate_pruned_second_order_mean_at_point(parameters::Vector{M},
𝓂::ℳ;
approximation_point::Union{Nothing, Vector{M}} = nothing,
opts::CalculationOptions = merge_calculation_options())::Tuple{Vector{M}, Bool, Vector{M}, M, AbstractMatrix{M}, SparseMatrixCSC{M, Int}, AbstractMatrix{M}, SparseMatrixCSC{M, Int}} where M

# Get NSSS and check for errors
SS_and_pars, (solution_error, iters) = get_NSSS_and_parameters(𝓂, parameters, opts = opts)

if solution_error > opts.tol.NSSS_acceptance_tol || isnan(solution_error)
if opts.verbose println("NSSS not found") end
return zeros(𝓂.timings.nVars), false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

# Use the provided approximation point or NSSS
approx_point = isnothing(approximation_point) ? SS_and_pars : approximation_point

# Calculate derivatives at the approximation point
∇₁ = calculate_jacobian(parameters, approx_point, 𝓂)

𝐒₁, qme_sol, solved = calculate_first_order_solution(∇₁;
T = 𝓂.timings,
opts = opts,
initial_guess = 𝓂.solution.perturbation.qme_solution)

if solved 𝓂.solution.perturbation.qme_solution = qme_sol end

if !solved
if opts.verbose println("1st order solution not found") end
return approx_point[1:𝓂.timings.nVars], false, SS_and_pars, solution_error, zeros(0,0), spzeros(0,0), zeros(0,0), spzeros(0,0)
end

∇₂ = calculate_hessian(parameters, approx_point, 𝓂)

𝐒₂, solved2 = calculate_second_order_solution(∇₁, ∇₂, 𝐒₁,
𝓂.solution.perturbation.second_order_auxiliary_matrices,
𝓂.caches;
T = 𝓂.timings,
initial_guess = 𝓂.solution.perturbation.second_order_solution,
opts = opts)

if !solved2
if opts.verbose println("2nd order solution not found") end
return approx_point[1:𝓂.timings.nVars], false, SS_and_pars, solution_error, ∇₁, spzeros(0,0), 𝐒₁, spzeros(0,0)
end

if eltype(𝐒₂) == Float64 𝓂.solution.perturbation.second_order_solution = 𝐒₂ end

𝐒₂ *= 𝓂.solution.perturbation.second_order_auxiliary_matrices.𝐔₂

if !(typeof(𝐒₂) <: AbstractSparseMatrix)
𝐒₂ = sparse(𝐒₂)
end

nᵉ = 𝓂.timings.nExo
nˢ = 𝓂.timings.nPast_not_future_and_mixed

s_in_s⁺ = BitVector(vcat(ones(Bool, nˢ), zeros(Bool, nᵉ + 1)))
e_in_s⁺ = BitVector(vcat(zeros(Bool, nˢ + 1), ones(Bool, nᵉ)))
v_in_s⁺ = BitVector(vcat(zeros(Bool, nˢ), 1, zeros(Bool, nᵉ)))

kron_states = ℒ.kron(s_in_s⁺, s_in_s⁺)
kron_shocks = ℒ.kron(e_in_s⁺, e_in_s⁺)
kron_volatility = ℒ.kron(v_in_s⁺, v_in_s⁺)

# first order
states_to_variables¹ = sparse(𝐒₁[:,1:nˢ])
states_to_states¹ = 𝐒₁[𝓂.timings.past_not_future_and_mixed_idx, 1:nˢ]
shocks_to_states¹ = 𝐒₁[𝓂.timings.past_not_future_and_mixed_idx, (nˢ + 1):end]

# second order
states_to_variables² = 𝐒₂[:, kron_states]
shocks_to_variables² = 𝐒₂[:, kron_shocks]
volatility_to_variables² = 𝐒₂[:, kron_volatility]

states_to_states² = 𝐒₂[𝓂.timings.past_not_future_and_mixed_idx, kron_states] |> collect
shocks_to_states² = 𝐒₂[𝓂.timings.past_not_future_and_mixed_idx, kron_shocks]
volatility_to_states² = 𝐒₂[𝓂.timings.past_not_future_and_mixed_idx, kron_volatility]

kron_states_to_states¹ = ℒ.kron(states_to_states¹, states_to_states¹) |> collect
kron_shocks_to_states¹ = ℒ.kron(shocks_to_states¹, shocks_to_states¹)

n_sts = nˢ

# Set up in pruned state transition matrices
pruned_states_to_pruned_states = [ states_to_states¹ zeros(M, n_sts, n_sts) zeros(M, n_sts, n_sts^2)
zeros(M, n_sts, n_sts) states_to_states¹ states_to_states² / 2
zeros(M, n_sts^2, 2 * n_sts) kron_states_to_states¹ ]

pruned_states_to_variables = [states_to_variables¹ states_to_variables¹ states_to_variables² / 2]

pruned_states_vol_and_shock_effect = [ zeros(M, n_sts)
vec(volatility_to_states²) / 2 + shocks_to_states² / 2 * vec(ℒ.I(nᵉ))
kron_shocks_to_states¹ * vec(ℒ.I(nᵉ))]

variables_vol_and_shock_effect = (vec(volatility_to_variables²) + shocks_to_variables² * vec(ℒ.I(nᵉ))) / 2

## First-order moments, ie mean of variables
mean_of_pruned_states = (ℒ.I(size(pruned_states_to_pruned_states, 1)) - pruned_states_to_pruned_states) \ pruned_states_vol_and_shock_effect
mean_of_variables = approx_point[1:𝓂.timings.nVars] + pruned_states_to_variables * mean_of_pruned_states + variables_vol_and_shock_effect

return mean_of_variables, true, SS_and_pars, solution_error, ∇₁, ∇₂, 𝐒₁, 𝐒₂
end


@stable default_mode = "disable" begin

function solve!(𝓂::ℳ;
Expand Down
117 changes: 117 additions & 0 deletions src/get_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,123 @@ sss(args...; kwargs...) = get_steady_state(args...; kwargs..., stochastic = true



"""
$(SIGNATURES)
Return the mean of the pruned perturbation solution computed at a specified approximation point.

This function allows computing the mean of the ergodic distribution when the perturbation
solution is computed around a different point than the non-stochastic steady state.

# Arguments
- $MODEL®
# Keyword Arguments
- $PARAMETERS®
- `approximation_point` [Default: `nothing`, Type: `Union{Nothing, KeyedArray}`]: the point around which to compute the perturbation solution. If `nothing`, uses the non-stochastic steady state.
- `algorithm` [Default: `:pruned_second_order`, Type: `Symbol`]: algorithm to use. Currently only `:pruned_second_order` is supported.
- $QME®
- $SYLVESTER®
- $TOLERANCES®
- $VERBOSE®

# Returns
- `KeyedArray` with the mean values.

# Examples
```jldoctest
using MacroModelling

@model RBC begin
1 / c[0] = (β / c[1]) * (α * exp(z[1]) * k[0]^(α - 1) + (1 - δ))
c[0] + k[0] = (1 - δ) * k[-1] + q[0]
q[0] = exp(z[0]) * k[-1]^α
z[0] = ρ * z[-1] + std_z * eps_z[x]
end

@parameters RBC begin
std_z = 0.01
ρ = 0.2
δ = 0.02
α = 0.5
β = 0.95
end

# Mean computed around NSSS (default)
get_mean_at_approximation_point(RBC)
# output
1-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓ Variables ∈ 4-element Vector{Symbol}
And data, 4-element Vector{Float64}:
(:c) 5.936894477029751
(:k) 47.39592121534711
(:q) 6.884812901336693
(:z) 0.0
```
"""
function get_mean_at_approximation_point(𝓂::ℳ;
parameters::ParameterType = nothing,
approximation_point::Union{Nothing, KeyedArray} = nothing,
algorithm::Symbol = :pruned_second_order,
verbose::Bool = DEFAULT_VERBOSE,
tol::Tolerances = Tolerances(),
quadratic_matrix_equation_algorithm::Symbol = DEFAULT_QME_ALGORITHM,
sylvester_algorithm::Union{Symbol,Vector{Symbol},Tuple{Symbol,Vararg{Symbol}}} = DEFAULT_SYLVESTER_SELECTOR(𝓂))::KeyedArray

@assert algorithm ∈ [:pruned_second_order] "Mean at approximation point currently only supports :pruned_second_order algorithm."

opts = merge_calculation_options(tol = tol, verbose = verbose,
quadratic_matrix_equation_algorithm = quadratic_matrix_equation_algorithm,
sylvester_algorithm² = isa(sylvester_algorithm, Symbol) ? sylvester_algorithm : sylvester_algorithm[1],
sylvester_algorithm³ = (isa(sylvester_algorithm, Symbol) || length(sylvester_algorithm) < 2) ? :bicgstab : sylvester_algorithm[2])

# Need to solve with second order algorithm to set up second-order derivatives
solve!(𝓂, parameters = parameters, algorithm = algorithm, opts = opts)

# Convert approximation_point KeyedArray to Vector if provided
approx_point_vec = nothing
if !isnothing(approximation_point) && approximation_point isa KeyedArray
# Get NSSS to understand the structure
SS_and_pars, _ = get_NSSS_and_parameters(𝓂, 𝓂.parameter_values, opts = opts)

# Get variable labels
NSSS_labels = [sort(union(𝓂.exo_present,𝓂.var))...,𝓂.calibration_equations_parameters...]

# Convert KeyedArray to a Vector matching SS_and_pars structure
approx_point_vec = copy(SS_and_pars)

# Get the keys from the approximation_point KeyedArray
approx_keys = axiskeys(approximation_point, 1)

# Map each key to the corresponding position in SS_and_pars
for (i, key) in enumerate(approx_keys)
key_sym = key isa Symbol ? key : Symbol(key)
idx = findfirst(x -> x == key_sym, NSSS_labels)
if !isnothing(idx)
approx_point_vec[idx] = approximation_point[i]
end
end
end

mean_vars, solved, _, _, _, _, _, _ = calculate_pruned_second_order_mean_at_point(𝓂.parameter_values, 𝓂,
approximation_point = approx_point_vec,
opts = opts)

@assert solved "Could not compute mean at the specified approximation point."

vars_in_ss_equations = sort(collect(setdiff(reduce(union,get_symbols.(𝓂.ss_aux_equations)),union(𝓂.parameters_in_equations,𝓂.➕_vars))))

var_idx = indexin([vars_in_ss_equations...], [𝓂.var...,𝓂.calibration_equations_parameters...])

axis1 = vars_in_ss_equations

if any(x -> contains(string(x), "◖"), axis1)
axis1_decomposed = decompose_name.(axis1)
axis1 = [length(a) > 1 ? string(a[1]) * "{" * join(a[2],"}{") * "}" * (a[end] isa Symbol ? string(a[end]) : "") : string(a[1]) for a in axis1_decomposed]
end

return KeyedArray(mean_vars[var_idx]; Variables = axis1)
end


"""
See [`get_steady_state`](@ref)
"""
Expand Down