Skip to content

Commit d741d88

Browse files
committed
Add default inference of input bindings
1 parent c52a541 commit d741d88

5 files changed

Lines changed: 175 additions & 3 deletions

File tree

docs/src/API/API_public.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ Period conversion detail:
4343
Trait-based inference detail:
4444
- If `TimeStepModel(...)` is omitted, `timestep_hint(::Type{<:Model})` may provide:
4545
: fixed period (`Dates.Day(1)`) or required range (`(Dates.Minute(1), Dates.Hour(4))`).
46+
- If `InputBindings(...)` is omitted, same-name sources are inferred automatically from
47+
: unique producers (same scale first, then cross-scale). Ambiguous cases require explicit bindings.
4648
- If `MeteoBindings(...)` / `MeteoWindow(...)` are omitted, `meteo_hint(::Type{<:Model})`
4749
: may provide `(; bindings=..., window=...)`.
4850
- Explicit mapping-level configuration always overrides hints.

docs/src/model_execution.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ the runtime can infer defaults from model traits:
2525
- `timestep_hint(::Type{<:MyModel})`
2626
- `meteo_hint(::Type{<:MyModel})`
2727

28+
If users do not provide `InputBindings(...)`, runtime infers same-name bindings:
29+
- first from a unique producer at the same scale;
30+
- otherwise from a unique producer at another scale;
31+
- if no producer exists, input stays unresolved (so initialization/forced values can be used);
32+
- if multiple producers are possible, runtime errors and asks for explicit `InputBindings(...)`.
33+
2834
For timestep hints:
2935
- `Dates.FixedPeriod` sets a fixed inferred timestep, e.g. `Dates.Day(1)`.
3036
- `(min_period, max_period)` sets a required range. For models with only range hints,
@@ -39,7 +45,7 @@ For meteo hints:
3945
Inspection helpers:
4046
- `resolved_model_specs(mapping)` returns resolved specs after inference/validation.
4147
- `explain_model_specs(mapping_or_sim)` prints a compact summary (`timestep`,
42-
`meteo_bindings`, `meteo_window`) for each model process.
48+
`input_bindings`, `meteo_bindings`, `meteo_window`) for each model process.
4349

4450
Policy parameterization:
4551
- `Integrate()` defaults to `SumReducer()`; you can pass another reducer, e.g. `Integrate(MeanReducer())` or `Integrate(vals -> maximum(vals) - minimum(vals))`.

src/mtg/model_spec_inference.jl

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,102 @@ function _infer_timestep_hints!(model_specs)
217217
return nothing
218218
end
219219

220+
function _format_candidate_list(candidates)
221+
isempty(candidates) && return "(none)"
222+
return join(["$(c.scale)/$(c.process)" for c in candidates], ", ")
223+
end
224+
225+
function _input_candidates_for_var(model_specs, consumer_scale::String, consumer_process::Symbol, input_var::Symbol)
226+
same_scale = NamedTuple[]
227+
cross_scale = NamedTuple[]
228+
229+
for (scale, specs_at_scale) in pairs(model_specs)
230+
for (process, spec) in pairs(specs_at_scale)
231+
scale == consumer_scale && process == consumer_process && continue
232+
input_var in keys(outputs_(model_(spec))) || continue
233+
c = (scale=scale, process=process, var=input_var)
234+
if scale == consumer_scale
235+
push!(same_scale, c)
236+
else
237+
push!(cross_scale, c)
238+
end
239+
end
240+
end
241+
242+
return same_scale, cross_scale
243+
end
244+
245+
function _infer_input_binding_for_var(model_specs, scale::String, process::Symbol, input_var::Symbol)
246+
same_scale, cross_scale = _input_candidates_for_var(model_specs, scale, process, input_var)
247+
248+
if length(same_scale) == 1
249+
c = only(same_scale)
250+
return (process=c.process, var=c.var, policy=HoldLast())
251+
elseif length(same_scale) > 1
252+
error(
253+
"Ambiguous inferred producer for input `$(input_var)` in process `$(process)` at scale `$(scale)`. ",
254+
"Multiple same-scale candidates were found: $(_format_candidate_list(same_scale)). ",
255+
"Please provide explicit `InputBindings(...)`."
256+
)
257+
end
258+
259+
if length(cross_scale) == 1
260+
c = only(cross_scale)
261+
return (process=c.process, var=c.var, scale=c.scale, policy=HoldLast())
262+
elseif length(cross_scale) > 1
263+
by_process = Dict{Symbol,Vector{NamedTuple}}()
264+
for c in cross_scale
265+
push!(get!(by_process, c.process, NamedTuple[]), c)
266+
end
267+
268+
if length(by_process) == 1
269+
proc = only(keys(by_process))
270+
scales = unique(c.scale for c in by_process[proc])
271+
if length(scales) == 1
272+
return (process=proc, var=input_var, scale=only(scales), policy=HoldLast())
273+
end
274+
# Same process name appears at multiple scales (common in multiscale
275+
# mappings). Keep scale unresolved so runtime resolves through parent links.
276+
return (process=proc, var=input_var, policy=HoldLast())
277+
end
278+
279+
error(
280+
"Ambiguous inferred producer for input `$(input_var)` in process `$(process)` at scale `$(scale)`. ",
281+
"Multiple cross-scale candidates were found: $(_format_candidate_list(cross_scale)). ",
282+
"Please provide explicit `InputBindings(...)`."
283+
)
284+
end
285+
286+
# No producer found. Keep input unresolved so user-provided initialization/forced
287+
# values can still drive the model.
288+
return nothing
289+
end
290+
291+
function _infer_input_bindings!(model_specs)
292+
for (scale, specs_at_scale) in pairs(model_specs)
293+
for (process, spec) in pairs(specs_at_scale)
294+
current_bindings = input_bindings(spec)
295+
current_bindings isa NamedTuple || continue
296+
297+
inferred = Pair{Symbol,Any}[]
298+
model_inputs = keys(inputs_(model_(spec)))
299+
300+
for input_var in model_inputs
301+
input_var in keys(current_bindings) && continue
302+
inferred_binding = _infer_input_binding_for_var(model_specs, scale, process, input_var)
303+
isnothing(inferred_binding) && continue
304+
push!(inferred, input_var => inferred_binding)
305+
end
306+
307+
isempty(inferred) && continue
308+
merged = (; pairs(current_bindings)..., inferred...)
309+
specs_at_scale[process] = ModelSpec(spec; input_bindings=merged)
310+
end
311+
end
312+
313+
return nothing
314+
end
315+
220316
function _normalize_meteo_hint(scale::String, process::Symbol, hint)
221317
isnothing(hint) && return (bindings=nothing, window=nothing)
222318

@@ -261,10 +357,13 @@ end
261357
"""
262358
infer_model_specs_configuration!(model_specs)
263359
264-
Fill missing `ModelSpec` fields from model-level hint traits.
360+
Fill missing `ModelSpec` fields from inference:
361+
- auto input bindings from unique same-name producers
362+
- model-level hint traits (`timestep_hint`, `meteo_hint`)
265363
Explicit `ModelSpec` user values always take precedence over inferred values.
266364
"""
267365
function infer_model_specs_configuration!(model_specs)
366+
_infer_input_bindings!(model_specs)
268367
_infer_timestep_hints!(model_specs)
269368
_infer_meteo_hints!(model_specs)
270369
return model_specs
@@ -310,6 +409,7 @@ function _model_specs_rows(model_specs)
310409
process=process,
311410
model=typeof(model_(spec)),
312411
timestep=timestep(spec),
412+
input_bindings=input_bindings(spec),
313413
meteo_bindings=meteo_bindings(spec),
314414
meteo_window=meteo_window(spec),
315415
))
@@ -329,6 +429,7 @@ Summary fields:
329429
- `process`
330430
- `model`
331431
- `timestep`
432+
- `input_bindings`
332433
- `meteo_bindings`
333434
- `meteo_window`
334435
"""
@@ -344,6 +445,7 @@ function explain_model_specs(target; io::IO=stdout, infer::Bool=true, validate::
344445

345446
for row in rows
346447
timestep_desc = isnothing(row.timestep) ? "(timespec(model))" : _stringify_compact(row.timestep)
448+
input_bindings_desc = (row.input_bindings isa NamedTuple && isempty(keys(row.input_bindings))) ? "(none)" : _stringify_compact(row.input_bindings)
347449
meteo_bindings_desc = (row.meteo_bindings isa NamedTuple && isempty(keys(row.meteo_bindings))) ? "(none)" : _stringify_compact(row.meteo_bindings)
348450
meteo_window_desc = isnothing(row.meteo_window) ? "(default rolling)" : _stringify_compact(row.meteo_window)
349451
println(
@@ -356,6 +458,8 @@ function explain_model_specs(target; io::IO=stdout, infer::Bool=true, validate::
356458
row.model,
357459
"]: timestep=",
358460
timestep_desc,
461+
", input_bindings=",
462+
input_bindings_desc,
359463
", meteo_bindings=",
360464
meteo_bindings_desc,
361465
", meteo_window=",

test/test-multirate-runtime.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,22 @@ function PlantSimEngine.run!(::MRHourlyFromDailyConsumerModel, models, status, m
166166
status.YD = status.XD
167167
end
168168

169+
PlantSimEngine.@process "mrzconsumer" verbose = false
170+
struct MRZConsumerModel <: AbstractMrzconsumerModel end
171+
PlantSimEngine.inputs_(::MRZConsumerModel) = (Z=-Inf,)
172+
PlantSimEngine.outputs_(::MRZConsumerModel) = (ZZ=-Inf,)
173+
function PlantSimEngine.run!(::MRZConsumerModel, models, status, meteo, constants=nothing, extra=nothing)
174+
status.ZZ = status.Z
175+
end
176+
177+
PlantSimEngine.@process "mrmissinginputconsumer" verbose = false
178+
struct MRMissingInputConsumerModel <: AbstractMrmissinginputconsumerModel end
179+
PlantSimEngine.inputs_(::MRMissingInputConsumerModel) = (U=-Inf,)
180+
PlantSimEngine.outputs_(::MRMissingInputConsumerModel) = (OU=-Inf,)
181+
function PlantSimEngine.run!(::MRMissingInputConsumerModel, models, status, meteo, constants=nothing, extra=nothing)
182+
status.OU = status.U
183+
end
184+
169185
PlantSimEngine.@process "mrmeteodailyconsumer" verbose = false
170186
struct MRMeteoDailyConsumerModel <: AbstractMrmeteodailyconsumerModel end
171187
PlantSimEngine.inputs_(::MRMeteoDailyConsumerModel) = NamedTuple()
@@ -268,6 +284,8 @@ PlantSimEngine.meteo_hint(::Type{<:MRMeteoHintConsumerModel}) = (
268284
@test input_bindings(specs_leaf[:mrconsumer]).C.var == :S
269285
@test input_bindings(specs_leaf[:mrconsumer]).C.policy isa HoldLast
270286
@test output_routing(specs_leaf[:mroverwrite]).C == :stream_only
287+
@test input_bindings(specs_leaf[:mrautosamename]).S.process == :mrsource
288+
@test input_bindings(specs_leaf[:mrautosamename]).S.var == :S
271289

272290
st_leaf = status(sim_ok)["Leaf"][1]
273291
# Expectation 1: consumer :C input is remapped from mrsource/:S via mapping-level InputBindings.
@@ -388,6 +406,26 @@ PlantSimEngine.meteo_hint(::Type{<:MRMeteoHintConsumerModel}) = (
388406
@test st_leaf_cross.XS == 4.0
389407
@test st_plant_cross.XP == 3.0
390408

409+
# Expectation 8a: cross-scale producer is inferred automatically when unique.
410+
source_counter_3_auto = Ref(0)
411+
mapping_cross_auto = Dict(
412+
"Leaf" => (
413+
ModelSpec(MRCrossSourceModel(source_counter_3_auto)) |> TimeStepModel(1.0),
414+
),
415+
"Plant" => (
416+
ModelSpec(MRCrossConsumerModel()) |>
417+
MultiScaleModel([:XS => ["Leaf"]]) |>
418+
TimeStepModel(ClockSpec(2.0, 1.0)),
419+
),
420+
)
421+
sim_cross_auto = PlantSimEngine.GraphSimulation(mtg, mapping_cross_auto, nsteps=4, check=true, outputs=Dict("Leaf" => (:XS,), "Plant" => (:XP,)))
422+
run!(sim_cross_auto, meteo4, multirate=true, executor=SequentialEx())
423+
st_plant_cross_auto = status(sim_cross_auto)["Plant"][1]
424+
@test st_plant_cross_auto.XP == 3.0
425+
spec_cross_auto = PlantSimEngine.get_model_specs(sim_cross_auto)["Plant"][:mrcrossconsumer]
426+
@test input_bindings(spec_cross_auto).XS.process == :mrcrosssource
427+
@test input_bindings(spec_cross_auto).XS.scale == "Leaf"
428+
391429
# Expectation 8b: scope partitioning isolates producer streams between plants.
392430
scene2 = Node(MultiScaleTreeGraph.NodeMTG("/", "Scene", 1, 0))
393431
plant2_a = Node(scene2, MultiScaleTreeGraph.NodeMTG("+", "Plant", 1, 1))
@@ -752,7 +790,28 @@ PlantSimEngine.meteo_hint(::Type{<:MRMeteoHintConsumerModel}) = (
752790
@test_throws "No period available" run!(sim_meteo_calendar_prev_strict, meteo_calendar, multirate=true, executor=SequentialEx())
753791
end
754792

755-
# Expectation 24: invalid mapping-level API configuration fails during GraphSimulation init.
793+
# Expectation 24: ambiguous same-name inferred producer is rejected at initialization.
794+
mapping_ambiguous_infer = Dict(
795+
"Leaf" => (
796+
MRConflict1Model(),
797+
MRConflict2Model(),
798+
MRZConsumerModel(),
799+
),
800+
)
801+
@test_throws "Ambiguous inferred producer for input `Z`" PlantSimEngine.GraphSimulation(mtg, mapping_ambiguous_infer, nsteps=1, check=true, outputs=Dict("Leaf" => (:ZZ,)))
802+
803+
# Expectation 25: missing producer remains allowed; model can rely on initialized/forced inputs.
804+
mapping_missing_input = Dict(
805+
"Leaf" => (
806+
MRMissingInputConsumerModel(),
807+
Status(U=42.0),
808+
),
809+
)
810+
sim_missing_input = PlantSimEngine.GraphSimulation(mtg, mapping_missing_input, nsteps=1, check=true, outputs=Dict("Leaf" => (:OU,)))
811+
run!(sim_missing_input, meteo, multirate=true, executor=SequentialEx())
812+
@test status(sim_missing_input)["Leaf"][1].OU == 42.0
813+
814+
# Expectation 26: invalid mapping-level API configuration fails during GraphSimulation init.
756815
mapping_bad_input = Dict(
757816
"Leaf" => (
758817
MRSourceModel(),

test/test-multirate-scaffolding.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using Test
2828
@test explained[1].process == :process1
2929
@test occursin("Resolved model specs:", explain_txt)
3030
@test occursin("Leaf/process1", explain_txt)
31+
@test occursin("input_bindings=", explain_txt)
3132

3233
spec = ModelSpec(m) |>
3334
TimeStepModel(24.0) |>

0 commit comments

Comments
 (0)