Skip to content

Commit 1b97a5e

Browse files
committed
Add model spec validation
1 parent a254503 commit 1b97a5e

4 files changed

Lines changed: 284 additions & 3 deletions

File tree

src/PlantSimEngine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ include("mtg/mapping/getters.jl")
7575
include("mtg/mapping/mapping.jl")
7676
include("mtg/mapping/compute_mapping.jl")
7777
include("mtg/mapping/reverse_mapping.jl")
78+
include("mtg/model_spec_validation.jl")
7879
include("mtg/initialisation.jl")
7980
include("mtg/save_results.jl")
8081
include("mtg/add_organ.jl")

src/mtg/initialisation.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ function init_simulation(mtg, mapping; nsteps=1, outputs=nothing, type_promotion
314314
@assert false "Error : Mapping status at $organ_with_vector level contains a vector. If this was intentional, call the function generate_models_from_status_vectors on your mapping before calling run!. And bear in mind this is not meant for production. If this wasn't intentional, then it's likely an issue on the mapping definition, or an unusual model."
315315
end
316316

317+
models = Dict(first(m) => parse_models(get_models(last(m))) for m in mapping)
318+
model_specs = Dict(first(m) => parse_model_specs(last(m)) for m in mapping)
319+
validate_model_specs_configuration(model_specs)
320+
317321
soft_dep_graphs_roots, hard_dep_dict = hard_dependencies(mapping; verbose=false)
318322

319323
# Get the status of each node by node type, pre-initialised considering multi-scale variables:
@@ -346,9 +350,6 @@ function init_simulation(mtg, mapping; nsteps=1, outputs=nothing, type_promotion
346350
@info "Models given for $model_no_node, but no node with this symbol was found in the MTG." maxlog = 1
347351
end
348352

349-
models = Dict(first(m) => parse_models(get_models(last(m))) for m in mapping)
350-
model_specs = Dict(first(m) => parse_model_specs(last(m)) for m in mapping)
351-
352353
outputs = pre_allocate_outputs(statuses, status_templates, reverse_multiscale_mapping, vars_need_init, outputs, nsteps, type_promotion=type_promotion, check=check)
353354

354355
outputs_index = Dict{String, Int}(s => 1 for s in keys(outputs))

src/mtg/model_spec_validation.jl

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
const _INPUT_BINDING_FIELDS = (:process, :var, :scale, :policy)
2+
3+
function _validate_timestep_spec(scale::String, process::Symbol, spec::ModelSpec)
4+
ts = timestep(spec)
5+
isnothing(ts) && return nothing
6+
7+
if ts isa ClockSpec
8+
float(ts.dt) > 0 || error(
9+
"Invalid timestep for process `$(process)` at scale `$(scale)`: ",
10+
"`ClockSpec.dt` must be > 0, got $(ts.dt)."
11+
)
12+
return nothing
13+
end
14+
15+
if ts isa Real
16+
float(ts) > 0 || error(
17+
"Invalid timestep for process `$(process)` at scale `$(scale)`: ",
18+
"numeric timestep must be > 0, got $(ts)."
19+
)
20+
return nothing
21+
end
22+
23+
error(
24+
"Invalid timestep for process `$(process)` at scale `$(scale)`: ",
25+
"expected `Real` or `ClockSpec`, got `$(typeof(ts))`."
26+
)
27+
end
28+
29+
function _validate_binding_policy(scale::String, process::Symbol, input_var::Symbol, policy)
30+
if policy isa DataType
31+
policy <: SchedulePolicy || error(
32+
"Invalid policy for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
33+
"expected a `SchedulePolicy` type or instance, got `$(policy)`."
34+
)
35+
return nothing
36+
end
37+
38+
policy isa SchedulePolicy || error(
39+
"Invalid policy for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
40+
"expected a `SchedulePolicy` type or instance, got `$(typeof(policy))`."
41+
)
42+
43+
return nothing
44+
end
45+
46+
function _validate_binding_target(
47+
scale::String,
48+
process::Symbol,
49+
input_var::Symbol,
50+
source_process::Symbol,
51+
source_scale,
52+
model_specs,
53+
known_processes::Set{Symbol}
54+
)
55+
source_process in known_processes || error(
56+
"Unknown source process `$(source_process)` for input `$(input_var)` in process `$(process)` at scale `$(scale)`."
57+
)
58+
59+
isnothing(source_scale) && return nothing
60+
src_scale = string(source_scale)
61+
haskey(model_specs, src_scale) || error(
62+
"Unknown source scale `$(src_scale)` for input `$(input_var)` in process `$(process)` at scale `$(scale)`."
63+
)
64+
source_process in keys(model_specs[src_scale]) || error(
65+
"Source process `$(source_process)` for input `$(input_var)` in process `$(process)` ",
66+
"is not declared at scale `$(src_scale)`."
67+
)
68+
return nothing
69+
end
70+
71+
function _validate_input_binding(
72+
scale::String,
73+
process::Symbol,
74+
input_var::Symbol,
75+
binding,
76+
model_specs,
77+
known_processes::Set{Symbol}
78+
)
79+
source_process = nothing
80+
source_scale = nothing
81+
policy = HoldLast()
82+
83+
if binding isa Symbol
84+
source_process = binding
85+
elseif binding isa Pair{Symbol,Symbol}
86+
source_process = first(binding)
87+
elseif binding isa NamedTuple
88+
extra = setdiff(collect(keys(binding)), collect(_INPUT_BINDING_FIELDS))
89+
isempty(extra) || error(
90+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
91+
"unsupported fields $(extra)."
92+
)
93+
haskey(binding, :process) || error(
94+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
95+
"field `process` is required."
96+
)
97+
binding.process isa Symbol || error(
98+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
99+
"`process` must be a Symbol, got `$(typeof(binding.process))`."
100+
)
101+
source_process = binding.process
102+
103+
if haskey(binding, :var)
104+
isnothing(binding.var) || binding.var isa Symbol || error(
105+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
106+
"`var` must be a Symbol or `nothing`, got `$(typeof(binding.var))`."
107+
)
108+
end
109+
110+
if haskey(binding, :scale)
111+
isnothing(binding.scale) || binding.scale isa Symbol || binding.scale isa AbstractString || error(
112+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
113+
"`scale` must be a Symbol, String or `nothing`, got `$(typeof(binding.scale))`."
114+
)
115+
source_scale = binding.scale
116+
end
117+
118+
policy = haskey(binding, :policy) ? binding.policy : HoldLast()
119+
else
120+
error(
121+
"Invalid input binding for input `$(input_var)` in process `$(process)` at scale `$(scale)`: ",
122+
"unsupported binding type `$(typeof(binding))`."
123+
)
124+
end
125+
126+
_validate_binding_policy(scale, process, input_var, policy)
127+
_validate_binding_target(scale, process, input_var, source_process, source_scale, model_specs, known_processes)
128+
return nothing
129+
end
130+
131+
function _validate_input_bindings_for_spec(
132+
scale::String,
133+
process::Symbol,
134+
spec::ModelSpec,
135+
model_specs,
136+
known_processes::Set{Symbol}
137+
)
138+
bindings = input_bindings(spec)
139+
bindings isa NamedTuple || error(
140+
"InputBindings for process `$(process)` at scale `$(scale)` must be a NamedTuple, got `$(typeof(bindings))`."
141+
)
142+
143+
model_inputs = Set(keys(inputs_(model_(spec))))
144+
for (input_var, binding) in pairs(bindings)
145+
input_var isa Symbol || error(
146+
"InputBindings key for process `$(process)` at scale `$(scale)` must be a Symbol, got `$(typeof(input_var))`."
147+
)
148+
input_var in model_inputs || error(
149+
"InputBindings for process `$(process)` at scale `$(scale)` declares binding for input `$(input_var)`, ",
150+
"but model inputs are $(collect(model_inputs))."
151+
)
152+
_validate_input_binding(scale, process, input_var, binding, model_specs, known_processes)
153+
end
154+
return nothing
155+
end
156+
157+
function _validate_output_routing_for_spec(scale::String, process::Symbol, spec::ModelSpec)
158+
routing = output_routing(spec)
159+
routing isa NamedTuple || error(
160+
"OutputRouting for process `$(process)` at scale `$(scale)` must be a NamedTuple, got `$(typeof(routing))`."
161+
)
162+
163+
model_outputs = Set(keys(outputs_(model_(spec))))
164+
for (out_var, mode) in pairs(routing)
165+
out_var isa Symbol || error(
166+
"OutputRouting key for process `$(process)` at scale `$(scale)` must be a Symbol, got `$(typeof(out_var))`."
167+
)
168+
out_var in model_outputs || error(
169+
"OutputRouting for process `$(process)` at scale `$(scale)` declares routing for output `$(out_var)`, ",
170+
"but model outputs are $(collect(model_outputs))."
171+
)
172+
173+
mode_sym = mode isa Symbol ? mode : (mode isa AbstractString ? Symbol(mode) : nothing)
174+
isnothing(mode_sym) && error(
175+
"OutputRouting mode for output `$(out_var)` in process `$(process)` at scale `$(scale)` ",
176+
"must be `:canonical` or `:stream_only`."
177+
)
178+
mode_sym in (:canonical, :stream_only) || error(
179+
"OutputRouting mode `$(mode_sym)` for output `$(out_var)` in process `$(process)` at scale `$(scale)` ",
180+
"is invalid. Allowed values: `:canonical`, `:stream_only`."
181+
)
182+
end
183+
184+
return nothing
185+
end
186+
187+
"""
188+
validate_model_specs_configuration(model_specs)
189+
190+
Validate mapping-level `ModelSpec` configuration before simulation runtime starts.
191+
This catches invalid timestep declarations, input bindings and output routing early.
192+
"""
193+
function validate_model_specs_configuration(model_specs)
194+
known_processes = Set{Symbol}()
195+
for specs_at_scale in values(model_specs)
196+
union!(known_processes, keys(specs_at_scale))
197+
end
198+
199+
for (scale, specs_at_scale) in pairs(model_specs)
200+
for (process, spec) in pairs(specs_at_scale)
201+
_validate_timestep_spec(scale, process, spec)
202+
_validate_input_bindings_for_spec(scale, process, spec, model_specs, known_processes)
203+
_validate_output_routing_for_spec(scale, process, spec)
204+
end
205+
end
206+
207+
return nothing
208+
end

test/test-multirate-runtime.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,25 @@ function PlantSimEngine.run!(::MRAggConsumerModel, models, status, meteo, consta
142142
status.YA = status.XA
143143
end
144144

145+
PlantSimEngine.@process "mrdailysource" verbose = false
146+
struct MRDailySourceModel <: AbstractMrdailysourceModel
147+
n::Base.RefValue{Int}
148+
end
149+
PlantSimEngine.inputs_(::MRDailySourceModel) = NamedTuple()
150+
PlantSimEngine.outputs_(::MRDailySourceModel) = (XD=-Inf,)
151+
function PlantSimEngine.run!(m::MRDailySourceModel, models, status, meteo, constants=nothing, extra=nothing)
152+
m.n[] += 1
153+
status.XD = float(m.n[])
154+
end
155+
156+
PlantSimEngine.@process "mrhourlyfromdailyconsumer" verbose = false
157+
struct MRHourlyFromDailyConsumerModel <: AbstractMrhourlyfromdailyconsumerModel end
158+
PlantSimEngine.inputs_(::MRHourlyFromDailyConsumerModel) = (XD=-Inf,)
159+
PlantSimEngine.outputs_(::MRHourlyFromDailyConsumerModel) = (YD=-Inf,)
160+
function PlantSimEngine.run!(::MRHourlyFromDailyConsumerModel, models, status, meteo, constants=nothing, extra=nothing)
161+
status.YD = status.XD
162+
end
163+
145164
@testset "Multi-rate runtime: HoldLast and conflict validation" begin
146165
mtg = Node(MultiScaleTreeGraph.NodeMTG("/", "Scene", 1, 0))
147166
plant = Node(mtg, MultiScaleTreeGraph.NodeMTG("+", "Plant", 1, 1))
@@ -219,6 +238,9 @@ end
219238
st_clock = status(sim_clock_trait)["Leaf"][1]
220239
@test st_clock.X == 4.0
221240
@test st_clock.Y == 3.0
241+
scope = ScopeId(:global, 1)
242+
@test sim_clock_trait.temporal_state.last_run[ModelKey(scope, "Leaf", :mrclocksource)] == 4.0
243+
@test sim_clock_trait.temporal_state.last_run[ModelKey(scope, "Leaf", :mrclockconsumer)] == 3.0
222244

223245
# Expectation 7: TimeStepModel override takes precedence over model timespec.
224246
source_counter_2 = Ref(0)
@@ -235,6 +257,8 @@ end
235257
st_clock_override = status(sim_clock_override)["Leaf"][1]
236258
@test st_clock_override.X == 4.0
237259
@test st_clock_override.Y == 3.0
260+
@test sim_clock_override.temporal_state.last_run[ModelKey(scope, "Leaf", :mrclocksource)] == 4.0
261+
@test sim_clock_override.temporal_state.last_run[ModelKey(scope, "Leaf", :mrclockconsumer)] == 3.0
238262

239263
# Expectation 8: cross-scale hold-last resolution works with different clocks.
240264
# Leaf producer runs each step; Plant consumer runs every 2 steps (1, 3) and reads Leaf XS through multiscale mapping.
@@ -296,4 +320,51 @@ end
296320
out_agg_df = convert_outputs(out_agg, DataFrame)
297321
@test out_agg_df["Leaf"][:, :YA] == [1.0, 1.0, 2.5, 2.5]
298322
@test status(sim_agg)["Leaf"][1].YA == 2.5
323+
324+
# Expectation 11: daily producer to hourly consumer within same day uses hold-last.
325+
# Source runs at t=1 and t=25 (ClockSpec(24,1)), consumer runs every step.
326+
# YD should stay at 1 for t=1..24, then switch to 2 at t=25.
327+
daily_counter = Ref(0)
328+
mapping_daily_hourly = Dict(
329+
"Leaf" => (
330+
ModelSpec(MRDailySourceModel(daily_counter)) |> TimeStepModel(ClockSpec(24.0, 1.0)),
331+
ModelSpec(MRHourlyFromDailyConsumerModel()) |>
332+
TimeStepModel(1.0) |>
333+
InputBindings(; XD=(process=:mrdailysource, var=:XD)),
334+
),
335+
)
336+
meteo26 = Weather(repeat([Atmosphere(T=20.0, Wind=1.0, Rh=0.65)], 26))
337+
sim_daily_hourly = PlantSimEngine.GraphSimulation(mtg, mapping_daily_hourly, nsteps=26, check=true, outputs=Dict("Leaf" => (:YD,)))
338+
out_daily_hourly = run!(sim_daily_hourly, meteo26, multirate=true, executor=SequentialEx())
339+
out_daily_hourly_df = convert_outputs(out_daily_hourly, DataFrame)
340+
@test out_daily_hourly_df["Leaf"][1:24, :YD] == fill(1.0, 24)
341+
@test out_daily_hourly_df["Leaf"][25:26, :YD] == [2.0, 2.0]
342+
@test sim_daily_hourly.temporal_state.last_run[ModelKey(scope, "Leaf", :mrdailysource)] == 25.0
343+
@test sim_daily_hourly.temporal_state.last_run[ModelKey(scope, "Leaf", :mrhourlyfromdailyconsumer)] == 26.0
344+
345+
# Expectation 12: invalid mapping-level API configuration fails during GraphSimulation init.
346+
mapping_bad_input = Dict(
347+
"Leaf" => (
348+
MRSourceModel(),
349+
ModelSpec(MRConsumerModel()) |>
350+
InputBindings(; Z=(process=:mrsource, var=:S)),
351+
),
352+
)
353+
@test_throws "declares binding for input `Z`" PlantSimEngine.GraphSimulation(mtg, mapping_bad_input, nsteps=1, check=true, outputs=Dict("Leaf" => (:B,)))
354+
355+
mapping_bad_process = Dict(
356+
"Leaf" => (
357+
ModelSpec(MRConsumerModel()) |>
358+
InputBindings(; C=(process=:unknown_process, var=:S)),
359+
),
360+
)
361+
@test_throws "Unknown source process `unknown_process`" PlantSimEngine.GraphSimulation(mtg, mapping_bad_process, nsteps=1, check=true, outputs=Dict("Leaf" => (:B,)))
362+
363+
mapping_bad_routing = Dict(
364+
"Leaf" => (
365+
ModelSpec(MRSourceModel()) |>
366+
OutputRouting(; Z=:stream_only),
367+
),
368+
)
369+
@test_throws "declares routing for output `Z`" PlantSimEngine.GraphSimulation(mtg, mapping_bad_routing, nsteps=1, check=true, outputs=Dict("Leaf" => (:S,)))
299370
end

0 commit comments

Comments
 (0)